├── .gitignore ├── LICENSE ├── README.md ├── configs ├── .gitkeep ├── model=bedsrnet │ └── config.yaml └── model=benet │ └── config.yaml ├── infer.py ├── libs ├── checkpoint.py ├── config.py ├── dataset.py ├── dataset_csv.py ├── device.py ├── helper.py ├── helper_bedsrnet.py ├── logger.py ├── loss.py ├── loss_fn │ └── __init__.py ├── meter.py ├── metric.py ├── models │ ├── __init__.py │ ├── cam.py │ ├── fix_weight_dict.py │ └── models.py ├── seed.py ├── transformer.py └── visualize_grid.py ├── make_dataset.py ├── pretrained └── .gitkeep ├── train_bedsrnet.py ├── train_benet.py └── utils └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | csv/ 2 | dataset/ 3 | configs/ 4 | pretrained/ 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | pip-wheel-metadata/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # IPython 86 | profile_default/ 87 | ipython_config.py 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 100 | __pypackages__/ 101 | 102 | # Celery stuff 103 | celerybeat-schedule 104 | celerybeat.pid 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Nick Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BEDSR-Net 2 | 3 | This repository is unofficial implementation of [BEDSR-Net: A Deep Shadow Removal Network From a Single Document Image](https://openaccess.thecvf.com/content_CVPR_2020/html/Lin_BEDSR-Net_A_Deep_Shadow_Removal_Network_From_a_Single_Document_CVPR_2020_paper.html) [Lin+, CVPR 2020] with PyTorch. 4 | 5 | A refined version of [IsHYuhi's implementation](https://github.com/IsHYuhi/BEDSR-Net_A_Deep_Shadow_Removal_Network_from_a_Single_Document_Image). 6 | 7 | ## Fix several problems 8 | 1. nn.ConvTranspose2d compatible with higher version of Pytorch 9 | 2. gradcam uses too much vram, use [pytorch-grad-cam](https://github.com/jacobgil/pytorch-grad-cam) instead 10 | 3. provide default correct training config 11 | 4. provide easy inference code 12 | 13 | ## Dependencies 14 | Pytorch, torchvision, matplotlib, wandb, albumentations, pytorch-grad-cam 15 | 16 | ## Dataset Structure 17 | 18 | The dataset should be formatted like below, train.csv and test.csv can be generated using 19 | 20 | ```python 21 | python make_dataset.py 22 | ``` 23 | ``` 24 | . 25 | ├── csv/ 26 | │ └── Jung/ 27 | │ ├── train.csv 28 | │ └── test.csv 29 | └── dataset/ 30 | └── Jung/ 31 | ├── train/ 32 | │ ├── input/ 33 | │ │ ├── *.jpg 34 | │ │ └── ... 35 | │ └── target/ 36 | │ ├── *.jpg 37 | │ └── ... 38 | └── test/ 39 | ├── input/ 40 | │ ├── *.jpg 41 | │ └── ... 42 | └── target/ 43 | ├── *.jpg 44 | └── ... 45 | ``` 46 | 47 | 48 | ## Training 49 | 50 | Training BE-Net 51 | ```python 52 | python3 train_benet.py ./configs/model\=benet/config.yaml 53 | ``` 54 | 55 | Training BEDSR-Net 56 | ```python 57 | python3 train_bedsrnet.py ./configs/model\=bedsrnet/config.yaml 58 | ``` 59 | 60 | You can use W&B by ```--use_wandb```. 
61 | 62 | ## Infer 63 | 64 | mask sure put all your model state_dict into pretrained directory 65 | 66 | ```python 67 | python infer.py 68 | ``` 69 | 70 | result images will be produced in results folder 71 | -------------------------------------------------------------------------------- /configs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-Reimplementation/BEDSR-Net-Reimplementation/8004a9abf8cc2e8550a46c7d1b894075d7a1b752/configs/.gitkeep -------------------------------------------------------------------------------- /configs/model=bedsrnet/config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 1 2 | beta1: 0.5 3 | beta2: 0.999 4 | dataset_name: Adobe 5 | height: 512 6 | lambda1: 1.0 7 | lambda2: 0.01 8 | learning_rate: 0.003 9 | loss_function_name: GAN 10 | max_epoch: 300 11 | model: bedsrnet 12 | num_workers: 8 13 | pretrained: false 14 | width: 512 15 | -------------------------------------------------------------------------------- /configs/model=benet/config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 1 2 | beta1: 0.5 3 | beta2: 0.999 4 | dataset_name: Adobe 5 | height: 512 6 | lambda1: 1.0 7 | lambda2: 0.01 8 | learning_rate: 0.003 9 | loss_function_name: L1 10 | max_epoch: 300 11 | model: benet 12 | num_workers: 8 13 | pretrained: false 14 | width: 512 15 | -------------------------------------------------------------------------------- /infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import torch.optim as optim 5 | from albumentations import Compose, Normalize, Resize 6 | from albumentations.pytorch import ToTensorV2 7 | 8 | from libs.dataset import get_dataloader 9 | from libs.device import get_device 10 | from libs.helper_bedsrnet import infer 11 | from libs.loss_fn import 
get_criterion 12 | from libs.models import get_model 13 | from libs.seed import set_seed 14 | 15 | if __name__ == '__main__': 16 | os.makedirs('results', exist_ok=True) 17 | result_path = os.path.dirname('results') 18 | 19 | set_seed() 20 | device = get_device(allow_only_gpu=False) 21 | 22 | val_transform = Compose( 23 | [ 24 | Resize(512, 512), 25 | Normalize(mean=(0.5,), std=(0.5,)), 26 | ToTensorV2(), 27 | ] 28 | ) 29 | 30 | val_loader = get_dataloader( 31 | 'Adobe', 32 | 'bedsrnet', 33 | "test", 34 | batch_size=1, 35 | shuffle=False, 36 | num_workers=8, 37 | pin_memory=True, 38 | transform=val_transform, 39 | ) 40 | 41 | lambda_dict = {"lambda1": 1.0, "lambda2": 0.01} 42 | criterion = get_criterion('GAN', device) 43 | 44 | benet = get_model("benet", in_channels=3, pretrained=True) 45 | srnet = get_model("srnet", pretrained=True) 46 | generator, discriminator = srnet[0], srnet[1] 47 | 48 | benet.eval() 49 | benet.to(device) 50 | generator.to(device) 51 | discriminator.to(device) 52 | 53 | infer( 54 | val_loader, generator, discriminator, benet, criterion, lambda_dict, device 55 | ) 56 | 57 | -------------------------------------------------------------------------------- /libs/checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from logging import getLogger 3 | from typing import Tuple 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | 9 | logger = getLogger(__name__) 10 | 11 | 12 | def save_checkpoint( 13 | result_path: str, 14 | epoch: int, 15 | model: nn.Module, 16 | optimizer: optim.Optimizer, 17 | best_loss: float, 18 | ) -> None: 19 | 20 | save_states = { 21 | "epoch": epoch, 22 | "state_dict": model.state_dict(), 23 | "optimizer": optimizer.state_dict(), 24 | "best_loss": best_loss, 25 | } 26 | 27 | torch.save(save_states, os.path.join(result_path, "checkpoint.pth")) 28 | logger.debug("successfully saved the ckeckpoint.") 29 | 30 | 31 | def 
save_checkpoint_BEDSRNet( 32 | result_path: str, 33 | epoch: int, 34 | generator: nn.Module, 35 | discriminator: nn.Module, 36 | optimizerG: optim.Optimizer, 37 | optimizerD: optim.Optimizer, 38 | best_g_loss: float, 39 | best_d_loss: float, 40 | ) -> None: 41 | 42 | save_states = { 43 | "epoch": epoch, 44 | "state_dictG": generator.state_dict(), 45 | "optimizerG": optimizerG.state_dict(), 46 | "best_g_loss": best_g_loss, 47 | } 48 | 49 | torch.save(save_states, os.path.join(result_path, "g_checkpoint.pth")) 50 | logger.debug("successfully saved the generator's ckeckpoint.") 51 | 52 | save_states = { 53 | "epoch": epoch, 54 | "state_dictD": discriminator.state_dict(), 55 | "optimizerD": optimizerD.state_dict(), 56 | "best_d_loss": best_d_loss, 57 | } 58 | 59 | torch.save(save_states, os.path.join(result_path, "d_checkpoint.pth")) 60 | logger.debug("successfully saved the discriminator's ckeckpoint.") 61 | 62 | 63 | def resume( 64 | resume_path: str, model: nn.Module, optimizer: optim.Optimizer 65 | ) -> Tuple[int, nn.Module, optim.Optimizer, float]: 66 | try: 67 | checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage) 68 | logger.info("loading checkpoint {}".format(resume_path)) 69 | except FileNotFoundError( 70 | "there is no checkpoint at the result folder." 
71 | ) as e: # type: ignore 72 | logger.exception(f"{e}") 73 | 74 | begin_epoch = checkpoint["epoch"] 75 | best_loss = checkpoint["best_loss"] 76 | model.load_state_dict(checkpoint["state_dict"]) 77 | 78 | optimizer.load_state_dict(checkpoint["optimizer"]) 79 | 80 | logger.info("training will start from {} epoch".format(begin_epoch)) 81 | 82 | return begin_epoch, model, optimizer, best_loss 83 | 84 | 85 | def resume_BEDSRNet( 86 | resume_path: str, 87 | generator: nn.Module, 88 | discriminator: nn.Module, 89 | optimizerG: optim.Optimizer, 90 | optimizerD: optim.Optimizer, 91 | ) -> Tuple[int, nn.Module, nn.Module, optim.Optimizer, optim.Optimizer, float, float]: 92 | try: 93 | checkpoint_g = torch.load( 94 | os.path.join(resume_path + "g_checkpoint.pth"), 95 | map_location=lambda storage, loc: storage, 96 | ) 97 | logger.info( 98 | "loading checkpoint {}".format( 99 | os.path.join(resume_path + "g_checkpoint.pth") 100 | ) 101 | ) 102 | checkpoint_d = torch.load( 103 | os.path.join(resume_path + "d_checkpoint.pth"), 104 | map_location=lambda storage, loc: storage, 105 | ) 106 | logger.info( 107 | "loading checkpoint {}".format( 108 | os.path.join(resume_path + "d_checkpoint.pth") 109 | ) 110 | ) 111 | except FileNotFoundError( 112 | "there is no checkpoint at the result folder." 
113 | ) as e: # type: ignore 114 | logger.exception(f"{e}") 115 | 116 | begin_epoch = checkpoint_g["epoch"] 117 | best_g_loss = checkpoint_g["best_g_loss"] 118 | best_d_loss = checkpoint_d["best_d_loss"] 119 | 120 | generator.load_state_dict(checkpoint_g["state_dict"]) 121 | discriminator.load_state_dict(checkpoint_d["state_dict"]) 122 | 123 | optimizerG.load_state_dict(checkpoint_g["optimizer"]) 124 | optimizerD.load_state_dict(checkpoint_d["optimizer"]) 125 | 126 | logger.info("training will start from {} epoch".format(begin_epoch)) 127 | 128 | return ( 129 | begin_epoch, 130 | generator, 131 | discriminator, 132 | optimizerG, 133 | optimizerD, 134 | best_g_loss, 135 | best_d_loss, 136 | ) 137 | -------------------------------------------------------------------------------- /libs/config.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from logging import getLogger 3 | from pprint import pformat 4 | from typing import Any, Dict, Tuple 5 | 6 | import yaml 7 | 8 | from .dataset_csv import DATASET_CSVS 9 | 10 | __all__ = ["get_config"] 11 | 12 | logger = getLogger(__name__) 13 | 14 | 15 | @dataclasses.dataclass(frozen=True) 16 | class Config: 17 | """Experimental configuration class.""" 18 | 19 | model: str = "bedsrnet " 20 | pretrained: bool = True 21 | 22 | batch_size: int = 32 23 | 24 | width: int = 256 25 | height: int = 256 26 | 27 | num_workers: int = 2 28 | max_epoch: int = 50 29 | 30 | learning_rate: float = 0.003 31 | 32 | dataset_name: str = "Jung" 33 | 34 | loss_function_name: str = "L1" 35 | 36 | lambda1: float = 1.0 37 | lambda2: float = 0.01 38 | 39 | beta1: float = 0.5 40 | beta2: float = 0.999 41 | 42 | def __post_init__(self) -> None: 43 | self._type_check() 44 | self._value_check() 45 | 46 | logger.info( 47 | "Experiment Configuration\n" + pformat(dataclasses.asdict(self), width=1) 48 | ) 49 | 50 | def _value_check(self) -> None: 51 | if self.dataset_name not in DATASET_CSVS: 52 | 
message = ( 53 | f"dataset_name should be selected from {list(DATASET_CSVS.keys())}." 54 | ) 55 | logger.error(message) 56 | raise ValueError(message) 57 | 58 | if self.max_epoch <= 0: 59 | message = "max_epoch must be positive." 60 | logger.error(message) 61 | raise ValueError(message) 62 | 63 | def _type_check(self) -> None: 64 | """Reference: 65 | https://qiita.com/obithree/items/1c2b43ca94e4fbc3aa8d 66 | """ 67 | 68 | _dict = dataclasses.asdict(self) 69 | 70 | for field, field_type in self.__annotations__.items(): 71 | # if you use type annotation class provided by `typing`, 72 | # you should convert it to the type class used in python. 73 | # e.g.) Tuple[int] -> tuple 74 | # https://stackoverflow.com/questions/51171908/extracting-data-from-typing-types 75 | 76 | # check the instance is Tuple or not. 77 | # https://github.com/zalando/connexion/issues/739 78 | if hasattr(field_type, "__origin__"): 79 | # e.g.) Tuple[int].__args__[0] -> `int` 80 | element_type = field_type.__args__[0] 81 | 82 | # e.g.) Tuple[int].__origin__ -> `tuple` 83 | field_type = field_type.__origin__ 84 | 85 | self._type_check_element(field, _dict[field], element_type) 86 | 87 | # bool is the subclass of int, 88 | # so need to use `type() is` instead of `isinstance` 89 | if type(_dict[field]) is not field_type: 90 | message = f"The type of '{field}' field is supposed to be {field_type}." 91 | logger.error(message) 92 | raise TypeError(message) 93 | 94 | def _type_check_element( 95 | self, field: str, vals: Tuple[Any], element_type: type 96 | ) -> None: 97 | for val in vals: 98 | if type(val) is not element_type: 99 | message = ( 100 | f"The element of '{field}' field is supposed to be {element_type}." 101 | ) 102 | logger.error(message) 103 | raise TypeError(message) 104 | 105 | 106 | def convert_list2tuple(_dict: Dict[str, Any]) -> Dict[str, Any]: 107 | # cannot use list in dataclass because mutable defaults are not allowed. 
108 | for key, val in _dict.items(): 109 | if isinstance(val, list): 110 | _dict[key] = tuple(val) 111 | 112 | logger.debug("converted list to tuple in dictionary.") 113 | return _dict 114 | 115 | 116 | def get_config(config_path: str) -> Config: 117 | with open(config_path, "r") as f: 118 | config_dict = yaml.safe_load(f) 119 | 120 | config_dict = convert_list2tuple(config_dict) 121 | config = Config(**config_dict) 122 | 123 | logger.info("successfully loaded configuration.") 124 | return config 125 | -------------------------------------------------------------------------------- /libs/dataset.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | from typing import Any, Dict, Optional 3 | 4 | import albumentations as A 5 | import cv2 6 | import numpy as np 7 | import pandas as pd 8 | import torch 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from .dataset_csv import DATASET_CSVS 12 | 13 | __all__ = ["get_dataloader"] 14 | 15 | logger = getLogger(__name__) 16 | 17 | 18 | def get_dataloader( 19 | dataset_name: str, 20 | train_model: str, 21 | split: str, 22 | batch_size: int, 23 | shuffle: bool, 24 | num_workers: int, 25 | pin_memory: bool, 26 | drop_last: bool = False, 27 | transform: Optional[A.Compose] = None, 28 | ) -> DataLoader: 29 | if dataset_name not in DATASET_CSVS: 30 | message = f"dataset_name should be selected from {list(DATASET_CSVS.keys())}." 31 | logger.error(message) 32 | raise ValueError(message) 33 | 34 | if train_model not in ["benet", "bedsrnet", "stcgan", "stcgan-be"]: 35 | message = "dataset_name should be selected from\ 36 | ['benet', 'bedsrnet', 'stcgan', 'stcgan-be']." 37 | logger.error(message) 38 | raise ValueError(message) 39 | 40 | if split not in ["train", "val", "test"]: 41 | message = "split should be selected from ['train', 'val', 'test']." 
42 | logger.error(message) 43 | raise ValueError(message) 44 | 45 | logger.info(f"Dataset: {dataset_name}\tSplit: {split}\tBatch size: {batch_size}.") 46 | 47 | data: Dataset 48 | csv_file = getattr(DATASET_CSVS[dataset_name], split) 49 | 50 | if train_model == "benet": 51 | data = BackGroundDataset(csv_file, transform=transform) 52 | elif ( 53 | train_model == "bedsrnet" 54 | or train_model == "stcgan" 55 | or train_model == "stcgan_be" 56 | ): 57 | data = ShadowDocumentDataset(csv_file, transform=transform) 58 | 59 | dataloader = DataLoader( 60 | data, 61 | batch_size=batch_size, 62 | shuffle=shuffle, 63 | num_workers=num_workers, 64 | pin_memory=pin_memory, 65 | drop_last=drop_last, 66 | ) 67 | 68 | return dataloader 69 | 70 | 71 | class BackGroundDataset(Dataset): 72 | def __init__(self, csv_file: str, transform: Optional[A.Compose] = None) -> None: 73 | super().__init__() 74 | 75 | try: 76 | self.df = pd.read_csv(csv_file) 77 | except FileNotFoundError("csv file not found.") as e: # type: ignore 78 | logger.exception(f"{e}") 79 | 80 | self.transform = transform 81 | 82 | logger.info(f"the number of samples: {len(self.df)}") 83 | 84 | def __len__(self) -> int: 85 | return len(self.df) 86 | 87 | def __getitem__(self, idx: int) -> Dict[str, Any]: 88 | img_path = self.df.iloc[idx]["input"] 89 | 90 | img = cv2.imread(img_path) 91 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 92 | 93 | if self.transform is not None: 94 | img = self.transform(image=img) 95 | 96 | rgb = ( 97 | torch.Tensor( 98 | [self.df.iloc[idx]["R"], self.df.iloc[idx]["G"], self.df.iloc[idx]["B"]] 99 | ) 100 | / 255 101 | ) 102 | rgb = (rgb - 0.5) / 0.5 103 | 104 | sample = {"img": img["image"], "rgb": rgb, "img_path": img_path} 105 | 106 | return sample 107 | 108 | 109 | class ShadowDocumentDataset(Dataset): 110 | def __init__(self, csv_file: str, transform: Optional[A.Compose] = None) -> None: 111 | super().__init__() 112 | 113 | try: 114 | self.df = pd.read_csv(csv_file) 115 | except 
FileNotFoundError("csv file not found.") as e: # type: ignore 116 | logger.exception(f"{e}") 117 | 118 | self.transform = transform 119 | 120 | logger.info(f"the number of samples: {len(self.df)}") 121 | 122 | def __len__(self) -> int: 123 | return len(self.df) 124 | 125 | def __getitem__(self, idx: int) -> Dict[str, Any]: 126 | img_path = self.df.iloc[idx]["input"] 127 | gt_path = self.df.iloc[idx]["target"] 128 | 129 | img = cv2.imread(img_path) 130 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 131 | gt = cv2.imread(gt_path) 132 | gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB) 133 | 134 | images = np.concatenate([img, gt], axis=2) 135 | 136 | if self.transform is not None: 137 | res = self.transform(image=images)["image"] 138 | img = res[0:3, :, :] 139 | gt = res[3:, :, :] 140 | 141 | sample = {"img": img, "gt": gt, "img_path": img_path} 142 | 143 | return sample 144 | -------------------------------------------------------------------------------- /libs/dataset_csv.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from logging import getLogger 3 | 4 | logger = getLogger(__name__) 5 | 6 | __all__ = ["DATASET_CSVS"] 7 | 8 | 9 | @dataclasses.dataclass(frozen=True) 10 | class DatasetCSV: 11 | train: str 12 | test: str 13 | 14 | 15 | DATASET_CSVS = { 16 | # paths from `src` directory 17 | "Jung": DatasetCSV( 18 | train="./csv/Jung/train.csv", 19 | test="./csv/Jung/test.csv", 20 | ), 21 | "Kligler": DatasetCSV( 22 | train="./csv/Kligler/train.csv", 23 | test="./csv/Kligler/test.csv", 24 | ), 25 | "Shadoc": DatasetCSV( 26 | train="./csv/Shadoc/train.csv", 27 | test="./csv/Shadoc/test.csv", 28 | ), 29 | "Adobe": DatasetCSV( 30 | train="./csv/Adobe/train.csv", 31 | test="./csv/Adobe/test.csv", 32 | ), 33 | "HS": DatasetCSV( 34 | train="./csv/HS/train.csv", 35 | test="./csv/HS/test.csv", 36 | ), 37 | "Shadoc": DatasetCSV( 38 | train="./csv/Shadoc/train.csv", 39 | test="./csv/Shadoc/test.csv", 40 | ), 41 | } 42 | 
-------------------------------------------------------------------------------- /libs/device.py: -------------------------------------------------------------------------------- 1 | from logging import getLogger 2 | 3 | import torch 4 | 5 | logger = getLogger(__name__) 6 | 7 | 8 | def get_device(allow_only_gpu: bool = True) -> str: 9 | if torch.cuda.is_available(): 10 | device = "cuda" 11 | torch.backends.cudnn.benchmark = True 12 | else: 13 | if allow_only_gpu: 14 | message = ( 15 | "You can use only cpu while you don't" 16 | "allow the use of cpu alone during training." 17 | ) 18 | logger.error(message) 19 | raise ValueError(message) 20 | 21 | device = "cpu" 22 | logger.warning( 23 | "CPU will be used for training. It is better to use GPUs instead" 24 | "because training CNN is computationally expensive." 25 | ) 26 | 27 | return device 28 | -------------------------------------------------------------------------------- /libs/helper.py: -------------------------------------------------------------------------------- 1 | import time 2 | from logging import getLogger 3 | from typing import Any, Dict, Optional, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from torch.utils.data import DataLoader 10 | from tqdm import tqdm 11 | 12 | from .meter import AverageMeter, ProgressMeter 13 | 14 | __all__ = ["train", "evaluate"] 15 | 16 | logger = getLogger(__name__) 17 | 18 | 19 | def do_one_iteration( 20 | sample: Dict[str, Any], 21 | model: nn.Module, 22 | criterion: Any, 23 | device: str, 24 | iter_type: str, 25 | optimizer: Optional[optim.Optimizer] = None, 26 | ) -> Tuple[int, float, np.ndarray, np.ndarray]: 27 | 28 | if iter_type not in ["train", "evaluate"]: 29 | message = "iter_type must be either 'train' or 'evaluate'." 30 | logger.error(message) 31 | raise ValueError(message) 32 | 33 | if iter_type == "train" and optimizer is None: 34 | message = "optimizer must be set during training." 
35 | logger.error(message) 36 | raise ValueError(message) 37 | 38 | x = sample["img"].to(device) 39 | t = sample["rgb"].to(device) 40 | 41 | batch_size = x.shape[0] 42 | 43 | # compute output and loss 44 | output = model(x) 45 | loss = criterion(output, t) 46 | 47 | # keep predicted results and gts for calculate F1 Score 48 | gt = t.to("cpu").numpy() 49 | pred = output.detach().to("cpu").numpy() 50 | 51 | if iter_type == "train" and optimizer is not None: 52 | # compute gradient and do SGD step 53 | optimizer.zero_grad() 54 | loss.backward() 55 | optimizer.step() 56 | 57 | return batch_size, loss.item(), gt, pred 58 | 59 | 60 | def train( 61 | loader: DataLoader, 62 | model: nn.Module, 63 | criterion: Any, 64 | optimizer: optim.Optimizer, 65 | epoch: int, 66 | device: str, 67 | interval_of_progress: int = 50, 68 | ) -> float: 69 | 70 | batch_time = AverageMeter("Time", ":6.3f") 71 | data_time = AverageMeter("Data", ":6.3f") 72 | losses = AverageMeter("Loss", ":.4e") 73 | 74 | progress = ProgressMeter( 75 | len(loader), 76 | [batch_time, data_time, losses], # , top1 77 | prefix="Epoch: [{}]".format(epoch), 78 | ) 79 | 80 | # keep predicted results and gts for calculate F1 Score 81 | gts = [] 82 | preds = [] 83 | 84 | # switch to train mode 85 | model.train() 86 | 87 | end = time.time() 88 | for i, sample in enumerate(tqdm(loader)): 89 | # measure data loading time 90 | data_time.update(time.time() - end) 91 | 92 | batch_size, loss, gt, pred = do_one_iteration( 93 | sample, model, criterion, device, "train", optimizer 94 | ) 95 | 96 | losses.update(loss, batch_size) 97 | 98 | # save the ground truths and predictions in lists 99 | gts += list(gt) 100 | preds += list(pred) 101 | 102 | # measure elapsed time 103 | batch_time.update(time.time() - end) 104 | end = time.time() 105 | 106 | # show progress bar per 50 iteration 107 | if i != 0 and i % interval_of_progress == 0: 108 | progress.display(i) 109 | 110 | return losses.get_average() 111 | 112 | 113 | def evaluate( 
114 | loader: DataLoader, model: nn.Module, criterion: Any, device: str 115 | ) -> float: 116 | losses = AverageMeter("Loss", ":.4e") 117 | 118 | # keep predicted results and gts for calculate F1 Score 119 | gts = [] 120 | preds = [] 121 | 122 | # switch to evaluate mode 123 | model.eval() 124 | 125 | with torch.no_grad(): 126 | for sample in tqdm(loader): 127 | batch_size, loss, gt, pred = do_one_iteration( 128 | sample, model, criterion, device, "evaluate" 129 | ) 130 | 131 | losses.update(loss, batch_size) 132 | 133 | # keep predicted results and gts for calculate F1 Score 134 | gts += list(gt) 135 | preds += list(pred) 136 | 137 | return losses.get_average() 138 | -------------------------------------------------------------------------------- /libs/helper_bedsrnet.py: -------------------------------------------------------------------------------- 1 | import time 2 | from logging import getLogger 3 | from typing import Any, Dict, List, Optional, Tuple 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from torch.autograd import Variable 11 | from torch.utils.data import DataLoader 12 | from torchvision.utils import save_image 13 | 14 | from .meter import AverageMeter, ProgressMeter 15 | from .metric import calc_psnr, calc_ssim 16 | from .visualize_grid import make_grid, unnormalize 17 | 18 | from pytorch_grad_cam import GradCAM 19 | from tqdm import tqdm 20 | import os 21 | 22 | __all__ = ["train", "evaluate"] 23 | 24 | logger = getLogger(__name__) 25 | 26 | 27 | def set_requires_grad(nets, requires_grad=False): 28 | for net in nets: 29 | if net is not None: 30 | for param in net.parameters(): 31 | param.requires_grad = requires_grad 32 | 33 | 34 | def do_one_iteration( 35 | infer: bool, 36 | sample: Dict[str, Any], 37 | generator: nn.Module, 38 | discriminator: nn.Module, 39 | benet: nn.Module, 40 | grad_cam: Any, 41 | criterion: Any, 42 | device: str, 43 | iter_type: str, 44 | 
lambda_dict: Dict, 45 | optimizerG: Optional[optim.Optimizer] = None, 46 | optimizerD: Optional[optim.Optimizer] = None, 47 | ) -> Tuple[ 48 | int, 49 | float, 50 | float, 51 | np.ndarray, 52 | np.ndarray, 53 | np.ndarray, 54 | np.ndarray, 55 | np.ndarray, 56 | float, 57 | float, 58 | ]: 59 | 60 | if iter_type not in ["train", "evaluate"]: 61 | message = "iter_type must be either 'train' or 'evaluate'." 62 | logger.error(message) 63 | raise ValueError(message) 64 | 65 | if iter_type == "train" and (optimizerG is None or optimizerD is None): 66 | message = "optimizer must be set during training." 67 | logger.error(message) 68 | raise ValueError(message) 69 | 70 | Tensor = ( 71 | torch.cuda.FloatTensor # type: ignore 72 | if device != torch.device("cpu") 73 | else torch.FloatTensor 74 | ) 75 | 76 | x = sample["img"].to(device) 77 | gt = sample["gt"].to(device) 78 | 79 | 80 | batch_size, c, h, w = x.shape 81 | 82 | # compute output and loss 83 | # train discriminator 84 | if iter_type == "train" and optimizerD is not None: 85 | set_requires_grad([discriminator], True) 86 | optimizerD.zero_grad() 87 | 88 | with torch.set_grad_enabled(True): 89 | cams = [] 90 | back_grounds = [] 91 | for i in range(batch_size): 92 | # color, cam, _ = benet(x[i].unsqueeze(dim=0)) 93 | color = benet(x[i].unsqueeze(dim=0)) 94 | cam = torch.from_numpy(grad_cam(x[i].unsqueeze(dim=0))).unsqueeze(dim=0) 95 | cam = (cam - 0.5) / 0.5 # clamp [-1.0, 1.0] 96 | cam = torch.nan_to_num(cam, nan=0.0) 97 | cams.append(cam.detach()) 98 | back_color = color.detach().repeat_interleave(h*w).reshape(c, h, w) 99 | back_grounds.append(back_color.unsqueeze(0)) 100 | 101 | attention_map = torch.cat(cams, dim=0) 102 | back_ground = torch.cat(back_grounds, dim=0) 103 | 104 | attention_map = attention_map.to(device) 105 | back_ground = back_ground.to(device) 106 | 107 | input = torch.cat([x, attention_map, back_ground], dim=1) 108 | 109 | shadow_removal_image = generator(input.to(device)) 110 | 111 | fake = 
torch.cat([x, shadow_removal_image], dim=1) 112 | real = torch.cat([x, gt], dim=1) 113 | 114 | out_D_fake = discriminator(fake.detach()) 115 | out_D_real = discriminator(real.detach()) 116 | 117 | label_D_fake = Variable(Tensor(np.zeros(out_D_fake.size())), requires_grad=True) 118 | label_D_real = Variable(Tensor(np.ones(out_D_fake.size())), requires_grad=True) 119 | 120 | loss_D_fake = criterion[1](out_D_fake, label_D_fake) 121 | loss_D_real = criterion[1](out_D_real, label_D_real) 122 | 123 | D_L_GAN = loss_D_fake + loss_D_real 124 | 125 | D_loss = lambda_dict["lambda2"] * D_L_GAN 126 | 127 | if iter_type == "train" and optimizerD is not None: 128 | D_loss.backward() 129 | optimizerD.step() 130 | 131 | # train generator 132 | if iter_type == "train" and optimizerG is not None: 133 | set_requires_grad([discriminator], False) 134 | optimizerG.zero_grad() 135 | 136 | fake = torch.cat([x, shadow_removal_image], dim=1) 137 | out_D_fake = discriminator(fake.detach()) 138 | 139 | G_L_GAN = criterion[1](out_D_fake, label_D_real) 140 | G_L_data = criterion[0](gt, shadow_removal_image) 141 | 142 | G_loss = lambda_dict["lambda1"] * G_L_data + lambda_dict["lambda2"] * G_L_GAN 143 | 144 | if iter_type == "train" and optimizerG is not None: 145 | G_loss.backward() 146 | optimizerG.step() 147 | 148 | x = x.detach().to("cpu").numpy() 149 | gt = gt.detach().to("cpu").numpy() 150 | pred = shadow_removal_image.detach().to("cpu").numpy() 151 | 152 | if infer: 153 | name = os.path.basename(sample['img_path'][0]) 154 | image = (unnormalize(pred[0]) * 255).astype('uint8') 155 | cv2.imwrite(os.path.join('results', name), image) 156 | 157 | attention_map = attention_map.detach().to("cpu").numpy() 158 | back_ground = back_ground.detach().to("cpu").numpy() 159 | out_D_fake.detach() 160 | out_D_real.detach() 161 | label_D_fake.detach() 162 | label_D_real.detach() 163 | 164 | psnr_score = calc_psnr(list(gt), list(pred)) 165 | ssim_score = calc_ssim(list(gt), list(pred)) 166 | 167 | return ( 
168 | batch_size, 169 | G_loss.item(), 170 | D_loss.item(), 171 | x, 172 | gt, 173 | pred, 174 | attention_map, 175 | back_ground, 176 | psnr_score, 177 | ssim_score, 178 | ) 179 | 180 | 181 | def train( 182 | loader: DataLoader, 183 | generator: nn.Module, 184 | discriminator: nn.Module, 185 | benet: nn.Module, 186 | criterion: Any, 187 | lambda_dict: Dict, 188 | optimizerG: optim.Optimizer, 189 | optimizerD: optim.Optimizer, 190 | epoch: int, 191 | device: str, 192 | interval_of_progress: int = 50, 193 | ) -> Tuple[float, float, float, float, np.ndarray]: 194 | 195 | batch_time = AverageMeter("Time", ":6.3f") 196 | data_time = AverageMeter("Data", ":6.3f") 197 | g_losses = AverageMeter("Loss", ":.4e") 198 | d_losses = AverageMeter("Loss", ":.4e") 199 | psnr_scores = AverageMeter("PSNR", ":.4e") 200 | ssim_scores = AverageMeter("SSIM", ":.4e") 201 | 202 | progress = ProgressMeter( 203 | len(loader), 204 | [batch_time, data_time, g_losses, d_losses, psnr_scores, ssim_scores], 205 | prefix="Epoch: [{}]".format(epoch), 206 | ) 207 | 208 | # keep predicted results and gts for calculate F1 Score 209 | inputs: List[np.ndarary] = [] 210 | gts: List[np.ndarray] = [] 211 | preds: List[np.ndarray] = [] 212 | attention_maps: List[np.ndarray] = [] 213 | back_grounds: List[np.ndarray] = [] 214 | 215 | # switch to train mode 216 | generator.train() 217 | discriminator.train() 218 | 219 | target_layers = [benet.features[3]] 220 | grad_cam = GradCAM(model=benet, target_layers=target_layers, use_cuda=True) 221 | 222 | end = time.time() 223 | for i, sample in enumerate(tqdm(loader)): 224 | # measure data loading time 225 | data_time.update(time.time() - end) 226 | 227 | ( 228 | batch_size, 229 | g_loss, 230 | d_loss, 231 | input, 232 | gt, 233 | pred, 234 | attention_map, 235 | back_ground, 236 | psnr_score, 237 | ssim_score, 238 | ) = do_one_iteration( 239 | False, 240 | sample, 241 | generator, 242 | discriminator, 243 | benet, 244 | grad_cam, 245 | criterion, 246 | device, 247 | 
"train", 248 | lambda_dict, 249 | optimizerG, 250 | optimizerD, 251 | ) 252 | 253 | g_losses.update(g_loss, batch_size) 254 | d_losses.update(d_loss, batch_size) 255 | psnr_scores.update(psnr_score, batch_size) 256 | ssim_scores.update(ssim_score, batch_size) 257 | 258 | # save the ground truths and predictions in lists 259 | if len(inputs) <= 10: 260 | inputs += list(input) 261 | gts += list(gt) 262 | preds += list(pred) 263 | attention_maps += list(attention_map) 264 | back_grounds += list(back_ground) 265 | 266 | # measure elapsed time 267 | batch_time.update(time.time() - end) 268 | end = time.time() 269 | 270 | # show progress bar per 50 iteration 271 | if i != 0 and i % interval_of_progress == 0: 272 | progress.display(i) 273 | 274 | result_images = make_grid( 275 | [inputs[:5], preds[:5], gts[:5], attention_maps[:5], back_grounds[:5]] 276 | ) 277 | 278 | return ( 279 | g_losses.get_average(), 280 | d_losses.get_average(), 281 | psnr_scores.get_average(), 282 | ssim_scores.get_average(), 283 | result_images, 284 | ) 285 | 286 | 287 | def evaluate( 288 | loader: DataLoader, 289 | generator: nn.Module, 290 | discriminator: nn.Module, 291 | benet: nn.Module, 292 | criterion: Any, 293 | lambda_dict: Dict, 294 | device: str, 295 | ) -> Tuple[float, float, float, float, np.ndarray]: 296 | g_losses = AverageMeter("Loss", ":.4e") 297 | d_losses = AverageMeter("Loss", ":.4e") 298 | psnr_scores = AverageMeter("PSNR", ":.4e") 299 | ssim_scores = AverageMeter("SSIM", ":.4e") 300 | 301 | # keep predicted results and gts for calculate F1 Score 302 | inputs: List[np.ndarary] = [] 303 | gts: List[np.ndarray] = [] 304 | preds: List[np.ndarray] = [] 305 | attention_maps: List[np.ndarray] = [] 306 | back_grounds: List[np.ndarray] = [] 307 | 308 | # switch to evaluate mode 309 | generator.eval() 310 | discriminator.eval() 311 | 312 | target_layers = [benet.features[3]] 313 | grad_cam = GradCAM(model=benet, target_layers=target_layers, use_cuda=True) 314 | 315 | with 
torch.no_grad(): 316 | for sample in tqdm(loader): 317 | ( 318 | batch_size, 319 | g_loss, 320 | d_loss, 321 | input, 322 | gt, 323 | pred, 324 | attention_map, 325 | back_ground, 326 | psnr_score, 327 | ssim_score, 328 | ) = do_one_iteration( 329 | False, 330 | sample, 331 | generator, 332 | discriminator, 333 | benet, 334 | grad_cam, 335 | criterion, 336 | device, 337 | "evaluate", 338 | lambda_dict, 339 | ) 340 | 341 | g_losses.update(g_loss, batch_size) 342 | d_losses.update(d_loss, batch_size) 343 | psnr_scores.update(psnr_score, batch_size) 344 | ssim_scores.update(ssim_score, batch_size) 345 | 346 | # save the ground truths and predictions in lists 347 | if len(inputs) <= 10: 348 | inputs += list(input) 349 | gts += list(gt) 350 | preds += list(pred) 351 | attention_maps += list(attention_map) 352 | back_grounds += list(back_ground) 353 | 354 | result_images = make_grid( 355 | [inputs[:5], preds[:5], gts[:5], attention_maps[:5], back_grounds[:5]] 356 | ) 357 | 358 | return ( 359 | g_losses.get_average(), 360 | d_losses.get_average(), 361 | psnr_scores.get_average(), 362 | ssim_scores.get_average(), 363 | result_images, 364 | ) 365 | 366 | def infer( 367 | loader: DataLoader, 368 | generator: nn.Module, 369 | discriminator: nn.Module, 370 | benet: nn.Module, 371 | criterion: Any, 372 | lambda_dict: Dict, 373 | device: str, 374 | ): 375 | 376 | # switch to evaluate mode 377 | generator.eval() 378 | discriminator.eval() 379 | 380 | target_layers = [benet.features[3]] 381 | grad_cam = GradCAM(model=benet, target_layers=target_layers, use_cuda=True) 382 | 383 | with torch.no_grad(): 384 | for sample in tqdm(loader): 385 | do_one_iteration( 386 | True, 387 | sample, 388 | generator, 389 | discriminator, 390 | benet, 391 | grad_cam, 392 | criterion, 393 | device, 394 | "evaluate", 395 | lambda_dict, 396 | ) 397 | -------------------------------------------------------------------------------- /libs/logger.py: 
# --------------------------------------------------------------------------------
from logging import getLogger

import pandas as pd

logger = getLogger(__name__)


class _CSVTrainLogger(object):
    """Keep per-epoch metrics in a DataFrame and mirror them to a CSV file.

    Subclasses define ``_COLUMNS`` (the CSV header, in row order).
    """

    _COLUMNS: tuple = ()

    def __init__(self, log_path: str, resume: bool) -> None:
        self.log_path = log_path
        self.columns = list(self._COLUMNS)

        if resume:
            self.df = self._load_log()
        else:
            self.df = pd.DataFrame(columns=self.columns)

    def _load_log(self) -> pd.DataFrame:
        try:
            df = pd.read_csv(self.log_path)
            logger.info("successfully loaded log csv file.")
            return df
        except FileNotFoundError as err:
            logger.exception(f"{err}")
            raise err

    def _save_log(self) -> None:
        self.df.to_csv(self.log_path, index=False)
        logger.debug("training logs are saved.")

    def _append_row(self, values: list) -> None:
        """Append ``values`` (ordered like ``self.columns``) as a new row and save.

        The row is built as a one-row DataFrame: concatenating a bare Series
        onto a DataFrame would add it as a new *column*, not a row.
        """
        row = pd.DataFrame([values], columns=self.columns)
        if self.df.empty:
            # Avoid concatenating with an empty frame (dtype/FutureWarning issues).
            self.df = row
        else:
            self.df = pd.concat([self.df, row], ignore_index=True)
        self._save_log()


class TrainLogger(_CSVTrainLogger):
    """CSV logger for single-model training (one loss per split)."""

    _COLUMNS = (
        "epoch",
        "lr",
        "train_time[sec]",
        "train_loss",
        "val_time[sec]",
        "val_loss",
    )

    def update(
        self,
        epoch: int,
        lr: float,
        train_time: int,
        train_loss: float,
        val_time: int,
        val_loss: float,
    ) -> None:
        """Record one epoch, persist the CSV and log a summary line."""
        self._append_row([epoch, lr, train_time, train_loss, val_time, val_loss])

        logger.info(
            f"epoch: {epoch}\tepoch time[sec]: {train_time + val_time}\tlr: {lr}\t"
            f"train loss: {train_loss:.4f}\tval loss: {val_loss:.4f}\t"
        )


class TrainLoggerBEDSRNet(_CSVTrainLogger):
    """CSV logger for BEDSR-Net GAN training (G/D losses, PSNR and SSIM)."""

    _COLUMNS = (
        "epoch",
        "lrG",
        "lrD",
        "train_time[sec]",
        "train_g_loss",
        "train_d_loss",
        "val_time[sec]",
        "val_g_loss",
        "val_d_loss",
        "train_psnr",
        "train_ssim",
        "val_psnr",
        "val_ssim",
    )

    def update(
        self,
        epoch: int,
        lrG: float,
        lrD: float,
        train_time: int,
        train_g_loss: float,
        train_d_loss: float,
        val_time: int,
        val_g_loss: float,
        val_d_loss: float,
        train_psnr: float,
        train_ssim: float,
        val_psnr: float,
        val_ssim: float,
    ) -> None:
        """Record one epoch, persist the CSV and log a summary line."""
        self._append_row(
            [
                epoch,
                lrG,
                lrD,
                train_time,
                train_g_loss,
                train_d_loss,
                val_time,
                val_g_loss,
                val_d_loss,
                train_psnr,
                train_ssim,
                val_psnr,
                val_ssim,
            ]
        )

        # PSNR/SSIM fields previously printed the discriminator losses.
        logger.info(
            f"epoch: {epoch}\tepoch time[sec]: {train_time + val_time}\tlr: {lrG}\t"
            f"train g loss: {train_g_loss:.4f}\tval g loss: {val_g_loss:.4f}\t"
            f"train d loss: {train_d_loss:.4f}\tval d loss: {val_d_loss:.4f}\t"
            f"train psnr: {train_psnr:.4f}\tval psnr: {val_psnr:.4f}\t"
            f"train ssim: {train_ssim:.4f}\tval ssim: {val_ssim:.4f}\t"
        )

# --------------------------------------------------------------------------------
# /libs/loss.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/CV-Reimplementation/BEDSR-Net-Reimplementation/8004a9abf8cc2e8550a46c7d1b894075d7a1b752/libs/loss.py
# --------------------------------------------------------------------------------
# /libs/loss_fn/__init__.py:
# --------------------------------------------------------------------------------
from logging import getLogger
from typing import List, Optional, Union

import torch.nn as nn

__all__ = ["get_criterion"]
logger = getLogger(__name__)


def get_criterion(
    loss_function_name: Optional[str] = None,
    device: Optional[str] = None,
) -> Union[nn.Module, List[nn.Module]]:
    """Build the loss(es) for ``loss_function_name`` on ``device``.

    Returns:
        ``"GAN"``: ``[L1Loss, BCEWithLogitsLoss]`` (reconstruction, adversarial).
        ``"L1"`` or any other name: a single ``L1Loss``.
    """
    if loss_function_name == "GAN":
        return [nn.L1Loss().to(device), nn.BCEWithLogitsLoss().to(device)]
    return nn.L1Loss().to(device)

# --------------------------------------------------------------------------------
# /libs/meter.py:
# --------------------------------------------------------------------------------
from logging import getLogger
from typing import List

logger = getLogger(__name__)


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name: str, fmt: str = ":f") -> None:
        self.name = name
        self.fmt = fmt
        self._reset()
        logger.debug("Average meter is set up.")

    def _reset(self) -> None:
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1) -> None:
        """Add ``val``, the average over ``n`` samples, to the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def get_average(self) -> float:
        return self.avg

    def __str__(self) -> str:
        fmtstr = "{name} {val" + self.fmt + "} (avg. {avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    """Log a prefixed ``[batch/total]`` line followed by every meter's value."""

    def __init__(
        self, num_batches: int, meters: List[AverageMeter], prefix: str = ""
    ) -> None:
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

        logger.debug("Progress meter is set up.")

    def display(self, batch: int) -> None:
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        # show current values and average values
        entries += [str(meter) for meter in self.meters]
        logger.info("\t".join(entries))

    def _get_batch_fmtstr(self, num_batches: int) -> str:
        # Pad the batch index to the width of the total, e.g. "[  7/120]".
        num_digits = len(str(num_batches))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"

# --------------------------------------------------------------------------------
# /libs/metric.py:
# --------------------------------------------------------------------------------
from typing import List, Tuple

import numpy as np
import torch
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim


def _to_unit_hwc(img: np.ndarray) -> np.ndarray:
    """CHW -> HWC and map via ``x * 0.5 + 0.5`` (assumes inputs in [-1, 1])."""
    return img.transpose([1, 2, 0]) * 0.5 + 0.5


def calc_psnr(gts: List[np.ndarray], preds: List[np.ndarray]) -> float:
    """Mean PSNR over paired CHW images (``data_range=1`` after rescaling)."""
    psnrs: List[float] = [
        psnr(_to_unit_hwc(gt), _to_unit_hwc(pred), data_range=1)
        for gt, pred in zip(gts, preds)
    ]
    return np.mean(psnrs)


def calc_ssim(gts: List[np.ndarray], preds: List[np.ndarray]) -> float:
    """Mean SSIM over paired CHW images (channels last, ``data_range=1``)."""
    ssims: List[float] = [
        ssim(_to_unit_hwc(gt), _to_unit_hwc(pred), channel_axis=2, data_range=1)
        for gt, pred in zip(gts, preds)
    ]
    return np.mean(ssims)


def calc_accuracy(
    output: torch.Tensor, target: torch.Tensor, topk: Tuple[int, ...] = (1,)
) -> List[float]:
    """Computes the accuracy over the k top predictions.

    Args:
        output: (N, C) model output scores.
        target: (N,) ground-truth class indices.
        topk: e.g. (1, 5) computes top-1 and top-5 accuracy.

    Returns:
        Top-k accuracies in percent, one per entry of ``topk``.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, N)
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].contiguous().view(-1)
            correct_k = correct_k.float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size).item())
        return res

# --------------------------------------------------------------------------------
# /libs/models/__init__.py:
# --------------------------------------------------------------------------------
from logging import getLogger
from typing import List, Union

import torch.nn as nn

from . import models

__all__ = ["get_model"]

model_names = ["benet", "cam_benet", "srnet", "stcgan"]
logger = getLogger(__name__)


def get_model(
    name: str, in_channels: int = True, pretrained: bool = True
) -> Union[nn.Module, List[nn.Module]]:
    """Build the model(s) registered under ``name`` (case-insensitive).

    Returns:
        ``srnet``/``stcgan``: ``[generator, discriminator]``.
        ``benet``: a BENet; ``cam_benet``: a GradCAM-wrapped BENet.

    Raises:
        ValueError: if ``name`` is not in ``model_names``.
    """
    # NOTE(review): the default ``in_channels=True`` equals 1 as an int, while
    # BENet's own default is 3 -- confirm that callers always pass in_channels.
    name = name.lower()
    if name not in model_names:
        message = (
            "There is no model appropriate to your choice. "
            "You have to choose benet/cam_benet(BENet), "
            "srnet(SR-Net), stcgan(ST-CGAN, *unimplemented) as a model."
        )
        logger.error(message)
        raise ValueError(message)

    logger.info("{} will be used as a model.".format(name))

    if name in ("srnet", "stcgan"):
        # stcgan is unimplemented and currently falls back to the SR-Net GAN.
        generator = getattr(models, "generator")(pretrained=pretrained)
        discriminator = getattr(models, "discriminator")(pretrained=pretrained)
        return [generator, discriminator]

    # "benet" or "cam_benet"
    return getattr(models, name)(in_channels=in_channels, pretrained=pretrained)

# --------------------------------------------------------------------------------
# /libs/models/cam.py:
# --------------------------------------------------------------------------------
# This module was created by yiskw713
# https://github.com/yiskw713/SmoothGradCAMplusplus
from statistics import mean, mode

import torch
import torch.nn.functional as F


def _min_max_normalize(cam: torch.Tensor) -> torch.Tensor:
    """Shift ``cam`` to a minimum of 0 and scale it to a maximum of 1.

    A constant map (common after ReLU) would otherwise be divided by zero and
    turn into NaN; it is returned as all zeros instead.
    """
    cam = cam - torch.min(cam)
    max_value = torch.max(cam)
    if max_value > 0:
        cam = cam / max_value
    return cam


def _predicted_class(score: torch.Tensor) -> int:
    """Index of the most probable class (``score`` must hold a single sample)."""
    return torch.max(F.softmax(score, dim=1), dim=1)[1].item()


def _gradcam_pp(
    activations: torch.Tensor, gradients: torch.Tensor, class_score: torch.Tensor
) -> torch.Tensor:
    """Grad-CAM++ map of shape (N, 1, H', W'), min-max normalised.

    ``class_score`` is the scalar pre-softmax score the gradients came from.
    """
    n, c, _, _ = gradients.shape

    # calculate alpha
    numerator = gradients.pow(2)
    denominator = 2 * gradients.pow(2)
    ag = activations * gradients.pow(3)
    denominator += ag.view(n, c, -1).sum(-1, keepdim=True).view(n, c, 1, 1)
    denominator = torch.where(
        denominator != 0.0, denominator, torch.ones_like(denominator)
    )
    alpha = numerator / (denominator + 1e-7)

    relu_grad = F.relu(class_score.exp() * gradients)
    weights = (alpha * relu_grad).view(n, c, -1).sum(-1).view(n, c, 1, 1)

    cam = F.relu((weights * activations).sum(1, keepdim=True))
    return _min_max_normalize(cam)


class SaveValues:
    """Record the forward activations and backward output-gradients of ``m``."""

    def __init__(self, m):
        self.activations = None
        self.gradients = None
        self.forward_hook = m.register_forward_hook(self.hook_fn_act)
        # NOTE(review): register_backward_hook is deprecated in recent PyTorch in
        # favour of register_full_backward_hook; kept to preserve hook semantics.
        self.backward_hook = m.register_backward_hook(self.hook_fn_grad)

    def hook_fn_act(self, module, input, output):
        self.activations = output

    def hook_fn_grad(self, module, grad_input, grad_output):
        self.gradients = grad_output[0]

    def remove(self):
        self.forward_hook.remove()
        self.backward_hook.remove()


class CAM(object):
    """Class Activation Mapping.

    Args:
        model: a base model with global pooling and a fully connected
            layer registered as ``fc``.
        target_layer: conv layer before global average pooling.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer

        # save values of activations and gradients in target_layer
        self.values = SaveValues(self.target_layer)

    def forward(self, x, idx=None):
        """Return ``(score, cam, idx)`` for input ``x`` of shape (1, 3, H, W).

        ``idx`` defaults to the predicted class.
        """
        score = self.model(x)

        if idx is None:
            idx = _predicted_class(score)

        # cam can be calculated from the weights of linear layer and activations
        weight_fc = list(self.model._modules.get("fc").parameters())[0].data

        cam = self.getCAM(self.values, weight_fc, idx)

        return score, cam, idx

    def __call__(self, x, idx=None):
        return self.forward(x, idx)

    def getCAM(self, values, weight_fc, idx):
        """Compute the (1, 1, H, W) CAM of class ``idx``.

        values: activations (1, C, H, W) of target_layer.
        weight_fc: fully connected weight (num_classes, C).
        """
        activations = values.activations
        # Keep the fc weight on the activations' device (it was forced to CPU,
        # which fails with F.conv2d when the model runs on a GPU).
        weight = weight_fc.to(activations.device)
        cam = F.conv2d(activations, weight=weight[:, :, None, None])
        _, _, h, w = cam.shape

        # class activation mapping only for the predicted class
        cam = _min_max_normalize(cam[:, idx, :, :])
        cam = cam.view(1, 1, h, w)

        return cam.data


class GradCAM(CAM):
    """Grad-CAM.

    Args:
        model: a base model; global pooling / fc layers are not required.
        target_layer: conv layer you want to visualize.
    """

    def forward(self, x, idx=None):
        """Return ``(score, cam, idx)`` for input ``x`` of shape (1, 3, H, W).

        ``idx`` defaults to the predicted class.
        """
        score = self.model(x)

        if idx is None:
            idx = _predicted_class(score)

        cam = self.getGradCAM(self.values, score, idx)

        return score, cam, idx

    def getGradCAM(self, values, score, idx):
        """Backpropagate ``score[0, idx]`` and build a (1, 1, H', W') map."""
        self.model.zero_grad()

        score[0, idx].backward(retain_graph=True)

        activations = values.activations
        gradients = values.gradients
        n, c, _, _ = gradients.shape
        alpha = gradients.view(n, c, -1).mean(2).view(n, c, 1, 1)

        cam = F.relu((alpha * activations).sum(dim=1, keepdim=True))
        return _min_max_normalize(cam).data


class GradCAMpp(CAM):
    """Grad-CAM++.

    Args:
        model: a base model.
        target_layer: conv layer you want to visualize.
    """

    def forward(self, x, idx=None):
        """Return ``(score, cam, idx)`` for input ``x`` of shape (1, 3, H, W)."""
        score = self.model(x)

        if idx is None:
            idx = _predicted_class(score)

        cam = self.getGradCAMpp(self.values, score, idx)

        return score, cam, idx

    def getGradCAMpp(self, values, score, idx):
        """Backpropagate ``score[0, idx]`` and build a (1, 1, H', W') map."""
        self.model.zero_grad()

        score[0, idx].backward(retain_graph=True)

        return _gradcam_pp(values.activations, values.gradients, score[0, idx]).data


class SmoothGradCAMpp(CAM):
    """Smooth Grad-CAM++: Grad-CAM++ averaged over noisy copies of the input.

    Args:
        model: a base model.
        target_layer: conv layer you want to visualize.
        n_samples: number of noisy samples (must be >= 1).
        stdev_spread: noise standard deviation relative to the input range.
    """

    def __init__(self, model, target_layer, n_samples=25, stdev_spread=0.15):
        super().__init__(model, target_layer)
        if n_samples < 1:
            raise ValueError("n_samples must be >= 1")
        self.n_samples = n_samples
        self.stdev_spread = stdev_spread

    def forward(self, x, idx=None):
        """Return ``(cam, idx)`` averaged over ``n_samples`` noisy inputs.

        When ``idx`` is None, the class predicted for the first noisy sample
        is used for every sample.
        """
        # SmoothGrad noise level: sigma = spread * (x_max - x_min). The spread
        # was previously divided by the range, which inverts the scaling.
        stdev = self.stdev_spread * (x.max() - x.min())
        std_tensor = torch.ones_like(x) * stdev

        indices = []
        probs = []
        total_cams = None

        for _ in range(self.n_samples):
            self.model.zero_grad()

            x_with_noise = torch.normal(mean=x, std=std_tensor)
            x_with_noise.requires_grad_()

            score = self.model(x_with_noise)
            prob = F.softmax(score, dim=1)

            if idx is None:
                idx = torch.max(prob, dim=1)[1].item()
            # Always record the target-class probability so that mean() below
            # never sees an empty list (it crashed when idx was passed in).
            probs.append(prob[0, idx].item())
            indices.append(idx)

            score[0, idx].backward(retain_graph=True)

            cam = _gradcam_pp(
                self.values.activations, self.values.gradients, score[0, idx]
            )
            total_cams = cam.clone() if total_cams is None else total_cams + cam

        total_cams /= self.n_samples
        idx = mode(indices)
        prob = mean(probs)

        print("predicted class ids {}\t probability {}".format(idx, prob))

        return total_cams.data, idx

# --------------------------------------------------------------------------------
# /libs/models/fix_weight_dict.py:
/libs/models/fix_weight_dict.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | 4 | def fix_model_state_dict(state_dict) -> OrderedDict: 5 | """ 6 | remove 'module.' of dataparallel 7 | """ 8 | new_state_dict = OrderedDict() 9 | for k, v in state_dict.items(): 10 | name = k 11 | if name.startswith("module."): 12 | name = name[7:] 13 | new_state_dict[name] = v 14 | return new_state_dict 15 | -------------------------------------------------------------------------------- /libs/models/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, Callable 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from .cam import GradCAM 8 | from .fix_weight_dict import fix_model_state_dict 9 | 10 | 11 | def weights_init(init_type="gaussian") -> Callable: 12 | def init_fun(m): 13 | classname = m.__class__.__name__ 14 | if (classname.find("Conv") == 0 or classname.find("Linear") == 0) and hasattr( 15 | m, "weight" 16 | ): 17 | if init_type == "gaussian": 18 | nn.init.normal_(m.weight, 0.0, 0.02) 19 | elif init_type == "xavier": 20 | nn.init.xavier_normal_(m.weight, gain=math.sqrt(2)) 21 | elif init_type == "kaiming": 22 | nn.init.kaiming_normal_(m.weight, a=0, mode="fan_in") 23 | elif init_type == "orthogonal": 24 | nn.init.orthogonal_(m.weight, gain=math.sqrt(2)) 25 | elif init_type == "default": 26 | pass 27 | else: 28 | assert 0, "Unsupported initialization: {}".format(init_type) 29 | if hasattr(m, "bias") and m.bias is not None: 30 | nn.init.constant_(m.bias, 0.0) 31 | 32 | return init_fun 33 | 34 | 35 | class BENet(nn.Module): 36 | def __init__(self, in_channels: int = 3, out_channels: int = 3) -> None: 37 | super(BENet, self).__init__() 38 | self.features = nn.Sequential( 39 | nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1), 40 | nn.Conv2d(32, 32, kernel_size=3, padding=1), 41 | nn.Conv2d(32, 128, 
kernel_size=3, padding=1), 42 | nn.Conv2d(128, 128, kernel_size=3, padding=1), 43 | ) 44 | self.global_maxpool = nn.AdaptiveMaxPool2d((1, 1)) 45 | self.classifier = nn.Sequential(nn.Linear(128, out_channels), nn.Tanh()) 46 | 47 | def forward(self, x: torch.Tensor) -> torch.Tensor: 48 | x = self.features(x) 49 | x = self.global_maxpool(x) 50 | x = torch.flatten(x, 1) 51 | x = self.classifier(x) 52 | return x 53 | 54 | 55 | class Cvi(nn.Module): 56 | def __init__( 57 | self, 58 | in_channels: int, 59 | out_channels: int, 60 | before: str = None, 61 | after: str = None, 62 | kernel_size: int = 4, 63 | stride: int = 2, 64 | padding: int = 1, 65 | dilation: int = 1, 66 | groups: int = 1, 67 | bias: bool = False, 68 | ) -> None: 69 | super(Cvi, self).__init__() 70 | self.conv = nn.Conv2d( 71 | in_channels, 72 | out_channels, 73 | kernel_size, 74 | stride, 75 | padding, 76 | dilation, 77 | groups, 78 | bias, 79 | ) 80 | self.after: Any[Callable] 81 | self.before: Any[Callable] 82 | self.conv.apply(weights_init("gaussian")) 83 | if after == "BN": 84 | self.after = nn.BatchNorm2d(out_channels) 85 | elif after == "Tanh": 86 | self.after = torch.tanh 87 | elif after == "sigmoid": 88 | self.after = torch.sigmoid 89 | 90 | if before == "ReLU": 91 | self.before = nn.ReLU(inplace=True) 92 | elif before == "LReLU": 93 | self.before = nn.LeakyReLU(negative_slope=0.2, inplace=True) 94 | 95 | def forward(self, x: torch.Tensor) -> torch.Tensor: 96 | 97 | if hasattr(self, "before"): 98 | x = self.before(x) 99 | 100 | x = self.conv(x) 101 | 102 | if hasattr(self, "after"): 103 | x = self.after(x) 104 | 105 | return x 106 | 107 | 108 | class CvTi(nn.Module): 109 | def __init__( 110 | self, 111 | in_channels: int, 112 | out_channels: int, 113 | before: str = None, 114 | after: str = None, 115 | kernel_size: int = 4, 116 | stride: int = 2, 117 | padding: int = 1, 118 | dilation: int = 1, 119 | groups: int = 1, 120 | bias: bool = False, 121 | ) -> None: 122 | super(CvTi, self).__init__() 
123 | self.after: Any[Callable] 124 | self.before: Any[Callable] 125 | self.conv = nn.ConvTranspose2d( 126 | in_channels, out_channels, kernel_size, stride, padding, bias=bias 127 | ) 128 | self.conv.apply(weights_init("gaussian")) 129 | if after == "BN": 130 | self.after = nn.BatchNorm2d(out_channels) 131 | elif after == "Tanh": 132 | self.after = torch.tanh 133 | elif after == "sigmoid": 134 | self.after = torch.sigmoid 135 | 136 | if before == "ReLU": 137 | self.before = nn.ReLU(inplace=True) 138 | elif before == "LReLU": 139 | self.before = nn.LeakyReLU(negative_slope=0.2, inplace=True) 140 | 141 | def forward(self, x: torch.Tensor) -> torch.Tensor: 142 | 143 | if hasattr(self, "before"): 144 | x = self.before(x) 145 | 146 | x = self.conv(x) 147 | 148 | if hasattr(self, "after"): 149 | x = self.after(x) 150 | 151 | return x 152 | 153 | 154 | class Generator(nn.Module): 155 | def __init__(self, in_channels: int = 7, out_channels: int = 3) -> None: 156 | super(Generator, self).__init__() 157 | 158 | self.Cv0 = Cvi(in_channels, 64) 159 | 160 | self.Cv1 = Cvi(64, 128, before="LReLU", after="BN") 161 | 162 | self.Cv2 = Cvi(128, 256, before="LReLU", after="BN") 163 | 164 | self.Cv3 = Cvi(256, 512, before="LReLU", after="BN") 165 | 166 | self.Cv4 = Cvi(512, 512, before="LReLU", after="BN") 167 | 168 | self.Cv5 = Cvi(512, 512, before="LReLU") 169 | 170 | self.CvT6 = CvTi(512, 512, before="ReLU", after="BN") 171 | 172 | self.CvT7 = CvTi(1024, 512, before="ReLU", after="BN") 173 | 174 | self.CvT8 = CvTi(1024, 256, before="ReLU", after="BN") 175 | 176 | self.CvT9 = CvTi(512, 128, before="ReLU", after="BN") 177 | 178 | self.CvT10 = CvTi(256, 64, before="ReLU", after="BN") 179 | 180 | self.CvT11 = CvTi(128, out_channels, before="ReLU", after="Tanh") 181 | 182 | def forward(self, input: torch.Tensor) -> torch.Tensor: 183 | # encoder 184 | x0 = self.Cv0(input) 185 | x1 = self.Cv1(x0) 186 | x2 = self.Cv2(x1) 187 | x3 = self.Cv3(x2) 188 | x4_1 = self.Cv4(x3) 189 | x4_2 = 
self.Cv4(x4_1) 190 | x4_3 = self.Cv4(x4_2) 191 | x5 = self.Cv5(x4_3) 192 | 193 | # decoder 194 | x6 = self.CvT6(x5) 195 | 196 | cat1_1 = torch.cat([x6, x4_3], dim=1) 197 | x7_1 = self.CvT7(cat1_1) 198 | cat1_2 = torch.cat([x7_1, x4_2], dim=1) 199 | x7_2 = self.CvT7(cat1_2) 200 | cat1_3 = torch.cat([x7_2, x4_1], dim=1) 201 | x7_3 = self.CvT7(cat1_3) 202 | 203 | cat2 = torch.cat([x7_3, x3], dim=1) 204 | x8 = self.CvT8(cat2) 205 | 206 | cat3 = torch.cat([x8, x2], dim=1) 207 | x9 = self.CvT9(cat3) 208 | 209 | cat4 = torch.cat([x9, x1], dim=1) 210 | x10 = self.CvT10(cat4) 211 | 212 | cat5 = torch.cat([x10, x0], dim=1) 213 | out = self.CvT11(cat5) 214 | 215 | return out 216 | 217 | 218 | class Discriminator(nn.Module): 219 | def __init__(self, in_channels=6) -> None: 220 | super(Discriminator, self).__init__() 221 | 222 | self.Cv0 = Cvi(in_channels, 64) 223 | 224 | self.Cv1 = Cvi(64, 128, before="LReLU", after="BN") 225 | 226 | self.Cv2 = Cvi(128, 256, before="LReLU", after="BN") 227 | 228 | self.Cv3 = Cvi(256, 512, before="LReLU", after="BN") 229 | 230 | self.Cv4 = Cvi(512, 1, before="LReLU", after="sigmoid") 231 | 232 | def forward(self, input: torch.Tensor) -> torch.Tensor: 233 | x0 = self.Cv0(input) 234 | x1 = self.Cv1(x0) 235 | x2 = self.Cv2(x1) 236 | x3 = self.Cv3(x2) 237 | out = self.Cv4(x3) 238 | 239 | return out 240 | 241 | 242 | def benet(pretrained: bool = False, **kwargs: Any) -> BENet: 243 | model = BENet(**kwargs) 244 | if pretrained: 245 | state_dict = torch.load("./configs/model=benet/pretrained_benet.prm") # map_location 246 | model.load_state_dict(fix_model_state_dict(state_dict)) 247 | return model 248 | 249 | 250 | def cam_benet(pretrained: bool = False, **kwargs: Any) -> GradCAM: 251 | model = BENet(**kwargs) 252 | if pretrained: 253 | state_dict = torch.load("./configs/model=benet/pretrained_benet.prm") # map_location 254 | model.load_state_dict(fix_model_state_dict(state_dict)) 255 | model.eval() 256 | target_layer = model.features[3] 257 | 
wrapped_model = GradCAM(model, target_layer) 258 | return wrapped_model 259 | 260 | 261 | def generator(pretrained: bool = False, **kwargs: Any) -> Generator: 262 | model = Generator(**kwargs) 263 | if pretrained: 264 | state_dict = torch.load("./configs/model=bedsrnet/pretrained_g_srnet.prm") 265 | model.load_state_dict(fix_model_state_dict(state_dict)) 266 | return model 267 | 268 | 269 | def discriminator(pretrained: bool = False, **kwargs: Any) -> Discriminator: 270 | model = Discriminator(**kwargs) 271 | if pretrained: 272 | state_dict = torch.load("./configs/model=bedsrnet/pretrained_d_srnet.prm") 273 | model.load_state_dict(fix_model_state_dict(state_dict)) 274 | return model 275 | -------------------------------------------------------------------------------- /libs/seed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from logging import getLogger 4 | 5 | import numpy as np 6 | import torch 7 | 8 | logger = getLogger(__name__) 9 | 10 | 11 | def set_seed(seed: int = 42) -> None: 12 | random.seed(seed) 13 | os.environ["PYTHONHASHSEED"] = str(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.backends.cudnn.deterministic = True 18 | 19 | logger.info("Finished setting up seed.") 20 | -------------------------------------------------------------------------------- /libs/transformer.py: -------------------------------------------------------------------------------- 1 | # if you want to use your own transformer,,, 2 | -------------------------------------------------------------------------------- /libs/visualize_grid.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | 7 | def unnormalize(img: np.ndarray, to_colormap: bool = True) -> np.ndarray: 8 | c, h, w = img.shape 9 | unnormalized_img = img.transpose([1, 2, 0]) * 0.5 + 0.5 
10 | if c == 1 and to_colormap: 11 | unnormalized_img = ( 12 | cv2.applyColorMap( 13 | np.array((1-unnormalized_img) * 255, dtype=np.uint8), cv2.COLORMAP_JET 14 | ) 15 | / 255 16 | ) 17 | return unnormalized_img 18 | 19 | 20 | def make_grid(images: List[List[np.ndarray]], pad: int = 3) -> np.ndarray: 21 | for k in range(len(images)): 22 | c, h, w = images[k][0].shape 23 | dtype = images[k][0].dtype 24 | row_image = np.array(unnormalize(images[k][0]), dtype=dtype) 25 | for i in range(1, len(images[k])): 26 | add_image = cv2.hconcat( 27 | [ 28 | np.zeros([h, pad, 3], dtype=dtype), 29 | np.array(unnormalize(images[k][i]), dtype=dtype), 30 | ] 31 | ) 32 | row_image = cv2.hconcat([row_image, add_image]) 33 | 34 | if k == 0: 35 | grid_image = row_image 36 | else: 37 | h, w, c = row_image.shape 38 | add_image = cv2.vconcat([np.zeros([pad, w, c], dtype=dtype), row_image]) 39 | grid_image = cv2.vconcat([grid_image, add_image]) 40 | 41 | return grid_image 42 | 43 | 44 | def make_grid_gray(images: List[List[np.ndarray]], pad: int = 3) -> np.ndarray: 45 | for k in range(len(images)): 46 | _, h, w = images[k][0].shape 47 | dtype = images[k][0].dtype 48 | row_image = images[k][0].transpose([1, 2, 0]) * 0.5 + 0.5 49 | for i in range(1, len(images[k])): 50 | add_image = cv2.hconcat( 51 | [ 52 | np.zeros([h, pad, 1], dtype=dtype), 53 | images[k][i].transpose([1, 2, 0]) * 0.5 + 0.5, 54 | ] 55 | ) 56 | row_image = cv2.hconcat([row_image, add_image]) 57 | 58 | if k == 0: 59 | grid_image = row_image 60 | else: 61 | h, w = row_image.shape 62 | add_image = cv2.vconcat([np.zeros([pad, w], dtype=dtype), row_image]) 63 | grid_image = cv2.vconcat([grid_image, add_image]) 64 | 65 | return grid_image[:, :, np.newaxis] 66 | -------------------------------------------------------------------------------- /make_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import no_type_check 2 | import numpy as np 3 | from sklearn import mixture 4 | 
from sklearn.cluster import KMeans 5 | import cv2 6 | import pandas as pd 7 | import os 8 | import multiprocessing as mp 9 | from tqdm import tqdm 10 | 11 | def get_average_color(x): 12 | b, g, r = x[:, 0], x[:, 1], x[:, 2] 13 | 14 | return np.array([np.mean(b), np.mean(g), np.mean(r)]) 15 | 16 | df = pd.DataFrame() 17 | phase = 'train' 18 | dataset = 'RDD' 19 | img_path = os.path.join('dataset', dataset, phase, 'input') 20 | root_path = os.path.join('dataset', dataset, phase, 'target') 21 | paths = os.listdir(root_path) 22 | paths.sort() 23 | img_paths = [] 24 | gt_paths = [] 25 | background_colors = [[], [], []] 26 | 27 | def process_img(path): 28 | img_paths.append(os.path.join(img_path, path)) 29 | gt_paths.append(os.path.join(root_path, path)) 30 | 31 | x = cv2.imread(os.path.join(root_path, path)) 32 | h, w, c = x.shape 33 | x = x.flatten().reshape(h*w, c) 34 | gmm = mixture.GaussianMixture(n_components=2, covariance_type='full') 35 | gmm.fit(x) 36 | #km = KMeans(n_clusters=2) 37 | #km.fit(x) 38 | 39 | cls = gmm.predict(x.flatten().reshape(h*w, c)) 40 | #cls = km.predict(x.flatten().reshape(h*w, c)) 41 | cls0_colors = x[cls == 0] 42 | cls1_colors = x[cls == 1] 43 | 44 | cls0_avg_color = get_average_color(cls0_colors) 45 | cls1_avg_color = get_average_color(cls1_colors) 46 | 47 | 48 | if np.sum(cls0_avg_color)>=np.sum(cls1_avg_color): 49 | background_color = cls0_avg_color 50 | #cls = 1 - cls 51 | else: 52 | background_color = cls1_avg_color 53 | 54 | gmm_out = np.array([cls0_avg_color if i == 0 else cls1_avg_color for i in cls]) 55 | # cv2.imwrite('../dataset/Jung/'+phase+'/gmm/gmm_{:s}.jpg'.format(path), gmm_out.reshape(h, w, c)) 56 | #cv2.imwrite('../dataset/Jung/'+phase+'/kmeans/km_{:s}.jpg'.format(path), gmm_out.reshape(h, w, c)) 57 | #cv2.imwrite('../dataset/Jung/'+phase+'/background/background_{:s}.jpg'.format(path), np.full_like(x, background_color).reshape(h, w, c)) 58 | #cv2.imwrite('gmm/{:s}'.format(path), cls.reshape(h, w)*255) 59 | for i in 
range(3): 60 | background_colors[i].append(background_color[i]) 61 | 62 | 63 | 64 | if __name__ == '__main__': 65 | 66 | for path in tqdm(paths): 67 | process_img(path) 68 | 69 | df['input'] = img_paths 70 | df['target'] = gt_paths 71 | df['B'], df['G'], df['R'] = background_colors[0], background_colors[1], background_colors[2] 72 | os.makedirs(os.path.join('csv', dataset), exist_ok=True) 73 | df.to_csv(os.path.join('csv', dataset, phase + '.csv')) 74 | 75 | -------------------------------------------------------------------------------- /pretrained/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CV-Reimplementation/BEDSR-Net-Reimplementation/8004a9abf8cc2e8550a46c7d1b894075d7a1b752/pretrained/.gitkeep -------------------------------------------------------------------------------- /train_bedsrnet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import time 5 | from logging import DEBUG, INFO, basicConfig, getLogger 6 | 7 | import torch 8 | import torch.optim as optim 9 | from albumentations import Affine # noqa 10 | from albumentations import CoarseDropout # noqa 11 | from albumentations import ColorJitter # noqa 12 | from albumentations import Rotate # noqa 13 | from albumentations import Transpose # noqa 14 | from albumentations import VerticalFlip # noqa 15 | from albumentations import Compose, HorizontalFlip, Normalize, RandomResizedCrop, Resize 16 | from albumentations.pytorch import ToTensorV2 17 | 18 | import wandb 19 | from libs.checkpoint import resume_BEDSRNet, save_checkpoint_BEDSRNet 20 | from libs.config import get_config 21 | from libs.dataset import get_dataloader 22 | from libs.device import get_device 23 | from libs.helper_bedsrnet import evaluate, train 24 | from libs.logger import TrainLoggerBEDSRNet 25 | from libs.loss_fn import get_criterion 26 | from libs.models import 
get_model 27 | from libs.seed import set_seed 28 | 29 | logger = getLogger(__name__) 30 | 31 | 32 | def get_arguments() -> argparse.Namespace: 33 | """parse all the arguments from command line inteface return a list of 34 | parsed arguments.""" 35 | 36 | parser = argparse.ArgumentParser( 37 | description=""" 38 | train a network for image classification with Flowers Recognition Dataset. 39 | """ 40 | ) 41 | parser.add_argument("--config", type=str, help="path of a config file", default='./configs/model=bedsrnet/config.yaml') 42 | parser.add_argument( 43 | "--resume", 44 | action="store_true", 45 | help="Add --resume option if you start training from checkpoint.", 46 | ) 47 | parser.add_argument( 48 | "--use_wandb", 49 | action="store_true", 50 | help="Add --use_wandb option if you want to use wandb.", 51 | ) 52 | parser.add_argument( 53 | "--debug", 54 | action="store_true", 55 | help="Add --debug option if you want to see debug-level logs.", 56 | ) 57 | parser.add_argument( 58 | "--seed", 59 | type=int, 60 | default=42, 61 | help="random seed", 62 | ) 63 | 64 | return parser.parse_args() 65 | 66 | 67 | def main() -> None: 68 | args = get_arguments() 69 | 70 | # save log files in the directory which contains config file. 
71 | result_path = os.path.dirname(args.config) 72 | experiment_name = os.path.basename(result_path) 73 | 74 | # setting logger configuration 75 | logname = os.path.join(result_path, f"{datetime.datetime.now():%Y-%m-%d}_train.log") 76 | basicConfig( 77 | level=DEBUG if args.debug else INFO, 78 | format="[%(asctime)s] %(name)s %(levelname)s: %(message)s", 79 | datefmt="%Y-%m-%d %H:%M:%S", 80 | filename=logname, 81 | ) 82 | 83 | # fix seed 84 | set_seed() 85 | 86 | # configuration 87 | config = get_config(args.config) 88 | 89 | # cpu or cuda 90 | device = get_device(allow_only_gpu=False) 91 | 92 | # Dataloader 93 | train_transform = Compose( 94 | [ 95 | RandomResizedCrop(config.height, config.width), 96 | HorizontalFlip(), 97 | Normalize(mean=(0.5,), std=(0.5,)), 98 | ToTensorV2(), 99 | ] 100 | ) 101 | 102 | val_transform = Compose( 103 | [ 104 | Resize(config.height, config.width), 105 | Normalize(mean=(0.5,), std=(0.5,)), 106 | ToTensorV2(), 107 | ] 108 | ) 109 | 110 | train_loader = get_dataloader( 111 | config.dataset_name, 112 | config.model, 113 | "train", 114 | batch_size=config.batch_size, 115 | shuffle=True, 116 | num_workers=config.num_workers, 117 | pin_memory=True, 118 | drop_last=True, 119 | transform=train_transform, 120 | ) 121 | 122 | val_loader = get_dataloader( 123 | config.dataset_name, 124 | config.model, 125 | "test", 126 | batch_size=1, 127 | shuffle=False, 128 | num_workers=config.num_workers, 129 | pin_memory=True, 130 | transform=val_transform, 131 | ) 132 | 133 | # define a model 134 | benet = get_model("benet", in_channels=3, pretrained=True) 135 | srnet = get_model("srnet", pretrained=config.pretrained) 136 | generator, discriminator = srnet[0], srnet[1] 137 | 138 | # send the model to cuda/cpu 139 | benet.to(device) 140 | generator.to(device) 141 | discriminator.to(device) 142 | 143 | benet.eval() 144 | 145 | optimizerG = optim.Adam( 146 | generator.parameters(), 147 | lr=config.learning_rate, 148 | betas=(config.beta1, config.beta2), 149 
| ) 150 | optimizerD = optim.Adam( 151 | discriminator.parameters(), 152 | lr=config.learning_rate, 153 | betas=(config.beta1, config.beta2), 154 | ) 155 | 156 | lambda_dict = {"lambda1": config.lambda1, "lambda2": config.lambda2} 157 | 158 | # keep training and validation log 159 | begin_epoch = 0 160 | best_g_loss = float("inf") 161 | best_d_loss = float("inf") 162 | 163 | # resume if you want 164 | if args.resume: 165 | resume_path = os.path.join(result_path, "checkpoint.pth") 166 | ( 167 | begin_epoch, 168 | generator, 169 | discriminator, 170 | optimizerG, 171 | optimizerD, 172 | best_g_loss, 173 | best_d_loss, 174 | ) = resume_BEDSRNet( 175 | resume_path, generator, discriminator, optimizerG, optimizerD 176 | ) 177 | 178 | log_path = os.path.join(result_path, "log.csv") 179 | train_logger = TrainLoggerBEDSRNet(log_path, resume=args.resume) 180 | 181 | # criterion for loss 182 | criterion = get_criterion(config.loss_function_name, device) 183 | 184 | # Weights and biases 185 | if args.use_wandb: 186 | wandb.init( 187 | name=experiment_name, 188 | config=config, 189 | project="BEDSR-Net", 190 | job_type="training", 191 | ) 192 | # Magic 193 | wandb.watch(generator, log="all") 194 | wandb.watch(discriminator, log="all") 195 | 196 | # train and validate model 197 | logger.info("Start training.") 198 | 199 | for epoch in range(begin_epoch, config.max_epoch): 200 | # training 201 | start = time.time() 202 | train_g_loss, train_d_loss, train_psnr, train_ssim, train_result_images = train( 203 | train_loader, 204 | generator, 205 | discriminator, 206 | benet, 207 | criterion, 208 | lambda_dict, 209 | optimizerG, 210 | optimizerD, 211 | epoch, 212 | device, 213 | ) 214 | train_time = int(time.time() - start) 215 | 216 | # validation 217 | start = time.time() 218 | val_g_loss, val_d_loss, val_psnr, val_ssim, val_result_images = evaluate( 219 | val_loader, generator, discriminator, benet, criterion, lambda_dict, device 220 | ) 221 | val_time = int(time.time() - start) 
222 | 223 | # save the models if the generator validation loss is lower than ever 224 | if best_g_loss > val_g_loss: 225 | best_g_loss = val_g_loss 226 | best_d_loss = val_d_loss 227 | torch.save( 228 | generator.state_dict(), 229 | os.path.join(result_path, "pretrained_g_srnet.prm"), 230 | ) 231 | torch.save( 232 | discriminator.state_dict(), 233 | os.path.join(result_path, "pretrained_d_srnet.prm"), 234 | ) 235 | 236 | # save checkpoint every epoch 237 | save_checkpoint_BEDSRNet( 238 | result_path, 239 | epoch, 240 | generator, 241 | discriminator, 242 | optimizerG, 243 | optimizerD, 244 | best_g_loss, 245 | best_d_loss, 246 | ) 247 | 248 | # write logs to dataframe and csv file 249 | train_logger.update( 250 | epoch, 251 | optimizerG.param_groups[0]["lr"], 252 | optimizerD.param_groups[0]["lr"], 253 | train_time, 254 | train_g_loss, 255 | train_d_loss, 256 | val_time, 257 | val_g_loss, 258 | val_d_loss, 259 | train_psnr, 260 | train_ssim, 261 | val_psnr, 262 | val_ssim, 263 | ) 264 | 265 | # save logs to wandb 266 | if args.use_wandb: 267 | wandb.log( 268 | { 269 | "lrG": optimizerG.param_groups[0]["lr"], 270 | "lrD": optimizerD.param_groups[0]["lr"], 271 | "train_time[sec]": train_time, 272 | "train_g_loss": train_g_loss, 273 | "train_d_loss": train_d_loss, 274 | "val_time[sec]": val_time, 275 | "val_g_loss": val_g_loss, 276 | "val_d_loss": val_d_loss, 277 | "train_psnr": train_psnr, 278 | "val_psnr": val_psnr, 279 | "train_ssim": train_ssim, 280 | "val_ssim": val_ssim, 281 | "train_image": wandb.Image(train_result_images, caption="train"), 282 | "val_image": wandb.Image(val_result_images, caption="val"), 283 | }, 284 | step=epoch, 285 | ) 286 | 287 | # save models 288 | torch.save(generator.state_dict(), os.path.join(result_path, "g_final.prm")) 289 | torch.save(discriminator.state_dict(), os.path.join(result_path, "d_final.prm")) 290 | 291 | # delete checkpoint 292 | os.remove(os.path.join(result_path, "g_checkpoint.pth")) 293 | os.remove(os.path.join(result_path,
"d_checkpoint.pth")) 294 | 295 | logger.info("Done") 296 | 297 | 298 | if __name__ == "__main__": 299 | main() 300 | -------------------------------------------------------------------------------- /train_benet.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import time 5 | from logging import DEBUG, INFO, basicConfig, getLogger 6 | 7 | import torch 8 | import torch.optim as optim 9 | from albumentations import Affine # noqa 10 | from albumentations import CoarseDropout # noqa 11 | from albumentations import ColorJitter # noqa 12 | from albumentations import Rotate # noqa 13 | from albumentations import Transpose # noqa 14 | from albumentations import VerticalFlip # noqa 15 | from albumentations import Compose, HorizontalFlip, Normalize, RandomResizedCrop, Resize 16 | from albumentations.pytorch import ToTensorV2 17 | 18 | import wandb 19 | from libs.checkpoint import resume, save_checkpoint 20 | from libs.config import get_config 21 | from libs.dataset import get_dataloader 22 | from libs.device import get_device 23 | from libs.helper import evaluate, train 24 | from libs.logger import TrainLogger 25 | from libs.loss_fn import get_criterion 26 | from libs.models import get_model 27 | from libs.seed import set_seed 28 | 29 | logger = getLogger(__name__) 30 | 31 | 32 | def get_arguments() -> argparse.Namespace: 33 | """parse all the arguments from command line inteface return a list of 34 | parsed arguments.""" 35 | 36 | parser = argparse.ArgumentParser( 37 | description=""" 38 | train a network for image classification with Flowers Recognition Dataset. 
39 | """ 40 | ) 41 | parser.add_argument("--config", type=str, help="path of a config file", default='./configs/model=benet/config.yaml') 42 | parser.add_argument( 43 | "--resume", 44 | action="store_true", 45 | help="Add --resume option if you start training from checkpoint.", 46 | ) 47 | parser.add_argument( 48 | "--use_wandb", 49 | action="store_true", 50 | help="Add --use_wandb option if you want to use wandb.", 51 | ) 52 | parser.add_argument( 53 | "--debug", 54 | action="store_true", 55 | help="Add --debug option if you want to see debug-level logs.", 56 | ) 57 | parser.add_argument( 58 | "--seed", 59 | type=int, 60 | default=42, 61 | help="random seed", 62 | ) 63 | 64 | return parser.parse_args() 65 | 66 | 67 | def main() -> None: 68 | args = get_arguments() 69 | 70 | # save log files in the directory which contains config file. 71 | result_path = os.path.dirname(args.config) 72 | experiment_name = os.path.basename(result_path) 73 | 74 | # setting logger configuration 75 | logname = os.path.join(result_path, f"{datetime.datetime.now():%Y-%m-%d}_train.log") 76 | basicConfig( 77 | level=DEBUG if args.debug else INFO, 78 | format="[%(asctime)s] %(name)s %(levelname)s: %(message)s", 79 | datefmt="%Y-%m-%d %H:%M:%S", 80 | filename=logname, 81 | ) 82 | 83 | # fix seed 84 | set_seed() 85 | 86 | # configuration 87 | config = get_config(args.config) 88 | 89 | # cpu or cuda 90 | device = get_device(allow_only_gpu=False) 91 | 92 | # Dataloader 93 | train_transform = Compose( 94 | [ 95 | RandomResizedCrop(config.height, config.width), 96 | HorizontalFlip(), 97 | Normalize(mean=(0.5,), std=(0.5,)), 98 | ToTensorV2(), 99 | ] 100 | ) 101 | 102 | val_transform = Compose( 103 | [ 104 | Resize(config.height, config.width), 105 | Normalize(mean=(0.5,), std=(0.5,)), 106 | ToTensorV2(), 107 | ] 108 | ) 109 | 110 | train_loader = get_dataloader( 111 | config.dataset_name, 112 | config.model, 113 | "train", 114 | batch_size=config.batch_size, 115 | shuffle=True, 116 | 
num_workers=config.num_workers, 117 | pin_memory=True, 118 | drop_last=True, 119 | transform=train_transform, 120 | ) 121 | 122 | val_loader = get_dataloader( 123 | config.dataset_name, 124 | config.model, 125 | "test", 126 | batch_size=1, 127 | shuffle=False, 128 | num_workers=config.num_workers, 129 | pin_memory=True, 130 | transform=val_transform, 131 | ) 132 | 133 | # define a model 134 | model = get_model(config.model, in_channels=3, pretrained=config.pretrained) 135 | 136 | # send the model to cuda/cpu 137 | model.to(device) 138 | 139 | optimizer = optim.Adam(model.parameters(), lr=config.learning_rate) 140 | 141 | # keep training and validation log 142 | begin_epoch = 0 143 | best_loss = float("inf") 144 | 145 | # resume if you want 146 | if args.resume: 147 | resume_path = os.path.join(result_path, "checkpoint.pth") 148 | begin_epoch, model, optimizer, best_loss = resume(resume_path, model, optimizer) 149 | 150 | log_path = os.path.join(result_path, "log.csv") 151 | train_logger = TrainLogger(log_path, resume=args.resume) 152 | 153 | # criterion for loss 154 | criterion = get_criterion(config.loss_function_name, device) 155 | 156 | # Weights and biases 157 | if args.use_wandb: 158 | wandb.init( 159 | name=experiment_name, 160 | config=config, 161 | project="BEDSR-Net", 162 | job_type="training", 163 | ) 164 | # Magic 165 | wandb.watch(model, log="all") 166 | 167 | # train and validate model 168 | logger.info("Start training.") 169 | 170 | for epoch in range(begin_epoch, config.max_epoch): 171 | # training 172 | start = time.time() 173 | train_loss = train(train_loader, model, criterion, optimizer, epoch, device) 174 | train_time = int(time.time() - start) 175 | 176 | # validation 177 | start = time.time() 178 | val_loss = evaluate(val_loader, model, criterion, device) 179 | val_time = int(time.time() - start) 180 | 181 | # save the model if the validation loss is lower than ever 182 | if best_loss > val_loss: 183 | best_loss = val_loss 184 | torch.save( 185 |
model.state_dict(), 186 | os.path.join(result_path, "pretrained_benet.prm"), 187 | ) 188 | 189 | # save checkpoint every epoch 190 | save_checkpoint(result_path, epoch, model, optimizer, best_loss) 191 | 192 | # write logs to dataframe and csv file 193 | train_logger.update( 194 | epoch, 195 | optimizer.param_groups[0]["lr"], 196 | train_time, 197 | train_loss, 198 | val_time, 199 | val_loss, 200 | ) 201 | 202 | # save logs to wandb 203 | if args.use_wandb: 204 | wandb.log( 205 | { 206 | "lr": optimizer.param_groups[0]["lr"], 207 | "train_time[sec]": train_time, 208 | "train_loss": train_loss, 209 | "val_time[sec]": val_time, 210 | "val_loss": val_loss, 211 | }, 212 | step=epoch, 213 | ) 214 | 215 | # save models 216 | torch.save(model.state_dict(), os.path.join(result_path, "final.prm")) 217 | 218 | # delete checkpoint 219 | os.remove(os.path.join(result_path, "checkpoint.pth")) 220 | 221 | logger.info("Done") 222 | 223 | 224 | if __name__ == "__main__": 225 | main() 226 | -------------------------------------------------------------------------------- /utils/visualize.py: -------------------------------------------------------------------------------- 1 | # This module criated by yiskw713 2 | # https://github.com/yiskw713/SmoothGradCAMplusplus/blob/master/utils/visualize.py 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | import numpy as np 8 | import cv2 9 | 10 | 11 | def reverse_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 12 | x[:, 0, :, :] = x[:, 0, :, :] * std[0] + mean[0] 13 | x[:, 1, :, :] = x[:, 1, :, :] * std[1] + mean[1] 14 | x[:, 2, :, :] = x[:, 2, :, :] * std[2] + mean[2] 15 | return x 16 | 17 | 18 | def visualize(img, cam): 19 | """ 20 | Synthesize an image with CAM to make a result image. 
21 | Args: 22 | img: (Tensor) shape => (1, 3, H, W) 23 | cam: (Tensor) shape => (1, 1, H', W') 24 | Return: 25 | synthesized image (Tensor): shape =>(1, 3, H, W) 26 | """ 27 | 28 | _, _, H, W = img.shape 29 | cam = F.interpolate(cam, size=(H, W), mode='bilinear', align_corners=False) 30 | cam = 255 * cam.squeeze() 31 | heatmap = cv2.applyColorMap(np.uint8(cam), cv2.COLORMAP_JET) 32 | heatmap = torch.from_numpy(heatmap.transpose(2, 0, 1)) 33 | heatmap = heatmap.float() / 255 34 | b, g, r = heatmap.split(1) 35 | heatmap = torch.cat([r, g, b]) 36 | 37 | result = heatmap + img.cpu() 38 | result = result.div(result.max()) 39 | 40 | return result --------------------------------------------------------------------------------