├── .gitignore ├── LICENSE.md ├── README.md ├── applications ├── iih_enhancement │ ├── .gitignore │ ├── README.md │ ├── data │ │ ├── __init__.py │ │ ├── adobe5k_dataset.py │ │ ├── base_dataset.py │ │ ├── image_folder.py │ │ ├── loader.py │ │ └── test_dataset.py │ ├── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── harmony_networks.py │ │ ├── iih_base_gd_model.py │ │ └── networks.py │ ├── options │ │ ├── __init__.py │ │ ├── base_options.py │ │ ├── test_options.py │ │ └── train_options.py │ ├── test.py │ ├── train.py │ ├── train_net.py │ └── util │ │ ├── distributed.py │ │ ├── evaluation.py │ │ ├── html.py │ │ ├── misc.py │ │ ├── multiprocessing.py │ │ ├── ssim.py │ │ ├── tools.py │ │ ├── util.py │ │ └── visualizer.py ├── iih_mef │ ├── .gitignore │ ├── README.md │ ├── data │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ ├── image_folder.py │ │ ├── loader.py │ │ ├── mef_dataset.py │ │ └── test_dataset.py │ ├── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── harmony_networks.py │ │ ├── iih_base_gd_model.py │ │ └── networks.py │ ├── options │ │ ├── __init__.py │ │ ├── base_options.py │ │ ├── test_options.py │ │ └── train_options.py │ ├── test.py │ ├── train.py │ ├── train_net.py │ └── util │ │ ├── distributed.py │ │ ├── evaluation.py │ │ ├── html.py │ │ ├── misc.py │ │ ├── multiprocessing.py │ │ ├── ssim.py │ │ ├── tools.py │ │ ├── util.py │ │ └── visualizer.py └── iih_relighting │ ├── .gitignore │ ├── README.md │ ├── data │ ├── __init__.py │ ├── base_dataset.py │ ├── dpr_dataset.py │ ├── dprtransfer_dataset.py │ ├── image_folder.py │ ├── loader.py │ └── test_dataset.py │ ├── models │ ├── __init__.py │ ├── base_model.py │ ├── iih_base_lt_model.py │ ├── networks.py │ └── relighting_networks.py │ ├── options │ ├── __init__.py │ ├── base_options.py │ ├── test_options.py │ └── train_options.py │ ├── test.py │ ├── train.py │ ├── train_net.py │ └── util │ ├── distributed.py │ ├── evaluation.py │ ├── html.py │ ├── misc.py │ ├── multiprocessing.py │ ├── ssim.py │ ├── tools.py │ ├── util.py │ └── visualizer.py ├── data ├── __init__.py ├── base_dataset.py ├── ihd_dataset.py ├── image_folder.py ├── loader.py └── real_dataset.py ├── evaluation ├── ih_evaluation.py └── pytorch_ssim.py ├── models ├── __init__.py ├── base_model.py ├── harmony_networks.py ├── iih_base_gd_model.py ├── iih_base_lt_gd_model.py ├── iih_base_lt_model.py ├── iih_base_model.py └── networks.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── requirements.txt ├── test.py ├── train.py ├── train_net.py └── util ├── distributed.py ├── html.py ├── misc.py ├── multiprocessing.py ├── ssim.py ├── tools.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .vscode/* 132 | checkpoints/* 133 | results/* 134 | applications/iih_enhancement/evaluation/enhance_evaluation.py 135 | applications/iih_enhancement/evaluation/pytorch_ssim.py 136 | applications/iih_mef/evaluation.py 137 | applications/iih_relighting/evaluation/relighting.py 138 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 AI @ OUC 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /applications/iih_enhancement/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | checkpoints/* 3 | results/* 4 | */__pycache__/* 5 | tmp/* 6 | __pycache__/distribute.cpython-37.pyc 7 | __pycache__/options.cpython-37.pyc 8 | __pycache__/train_net.cpython-37.pyc 9 | __pycache__/train.cpython-37.pyc -------------------------------------------------------------------------------- /applications/iih_enhancement/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # Image Enhancement
5 | 
6 | Here we provide the PyTorch implementation and the pre-trained model of our latest version.
7 | 
8 | ## Prerequisites
9 | 
10 | - Linux
11 | - Python 3
12 | - CPU or NVIDIA GPU + CUDA CuDNN
13 | 
14 | ## Base Model with Guiding
15 | - Download the MIT-Adobe-5K-UPE dataset.
16 | 
17 | - Train
18 | ```bash
19 | CUDA_VISIBLE_DEVICES=0 python train.py --model iih_base_gd --name base_gd_adobe5k_test --dataset_root <dataset_dir> --dataset_name Adobe5k --batch_size xx --init_port xxxx
20 | ```
21 | - Test
22 | ```bash
23 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_gd --name base_gd_adobe5k_test --dataset_root <dataset_dir> --dataset_name Adobe5k --batch_size xx --init_port xxxx
24 | ```
25 | - Apply pre-trained model
26 | 
27 | Download the pre-trained model from [Google Drive](https://drive.google.com/file/d/1h9EG2kZnYi3GI4nAsqnJb1HHBv8GeNf7/view?usp=sharing) or [BaiduCloud](https://pan.baidu.com/s/1mhAxHjetfIvZv-O-kqeHTA) (access code: 0r0k), put `latest_net_G.pth` in the directory `checkpoints/base_gd_enhancement`, and run:
28 | ```bash
29 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_gd --name base_gd_enhancement --dataset_root <dataset_dir> --dataset_name Adobe5k --batch_size xx --init_port xxxx
30 | ```
31 | 
--------------------------------------------------------------------------------
/applications/iih_enhancement/data/adobe5k_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 | 
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <DatasetMode>Dataset.py
8 | You need to implement the following functions:
9 |     -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 |     -- <__init__>: Initialize this dataset class.
11 |     -- <__getitem__>: Return a data point and its metadata information.
12 |     -- <__len__>: Return the number of images.
13 | """
14 | import os.path
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | import torch.nn.functional as F
18 | from data.base_dataset import BaseDataset, get_transform
19 | from data.image_folder import make_dataset
20 | from PIL import Image
21 | import numpy as np
22 | import torchvision.transforms as transforms
23 | from util import util
24 | 
25 | class Adobe5kDataset(BaseDataset):
26 |     @staticmethod
27 |     def modify_commandline_options(parser, is_train):
28 |         parser.add_argument('--is_train', type=bool, default=True, help='whether in the training phase')
29 |         parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0)  # specify dataset-specific default values
30 |         return parser
31 | 
32 |     def __init__(self, opt):
33 |         """Initialize this dataset class.
34 | 
35 |         Parameters:
36 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
37 | 
38 |         A few things can be done here.
39 |         - save the options (have been done in BaseDataset)
40 |         - get image paths and meta information of the dataset.
41 |         - define the image transformation.
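
        A minimal usage sketch (illustrative only; assumes `opt` comes from the
        options package and carries `dataset_root`, `dataset_name`, `crop_size`
        and `isTrain`, with the Adobe5k list files present under `dataset_root`):
            >>> dataset = Adobe5kDataset(opt)
            >>> sample = dataset[0]
            >>> sample['fake'].shape   # input image tensor, (3, crop_size, crop_size)
            >>> sample['real'].shape   # Expert-C retouched target, same shape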
42 |         """
43 |         # save the option and dataset root
44 |         BaseDataset.__init__(self, opt)
45 |         self.fake_image_paths = []
46 |         self.image_paths = []
47 |         self.isTrain = opt.isTrain
48 |         self.image_size = opt.crop_size
49 | 
50 |         if opt.isTrain==True:
51 |             #self.real_ext='.jpg'
52 |             print('loading training file')
53 |             self.trainfile = opt.dataset_root+opt.dataset_name+'_train.txt'
54 |             with open(self.trainfile,'r') as f:
55 |                 for line in f.readlines():
56 |                     self.image_paths.append(os.path.join(opt.dataset_root,'UPEresize',line.rstrip()))
57 |         elif opt.isTrain==False:
58 |             #self.real_ext='.jpg'
59 |             print('loading test file')
60 |             self.trainfile = opt.dataset_root+opt.dataset_name+'_test.txt'
61 |             with open(self.trainfile,'r') as f:
62 |                 for line in f.readlines():
63 |                     self.image_paths.append(os.path.join(opt.dataset_root,'test_set/input/',line.rstrip()))
64 |         # get the image paths of your dataset;
65 |         # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
66 |         # define the default transform function. You can use <base_dataset.get_transform>; you can also define your custom transform function
67 |         transform_list = [
68 |             transforms.ToTensor(),
69 |             transforms.Normalize((0, 0, 0), (1, 1, 1))
70 |         ]
71 |         self.transforms = transforms.Compose(transform_list)
72 | 
73 |     def __getitem__(self, index):
74 |         """Return a data point and its metadata information.
75 | 
76 |         Parameters:
77 |             index -- a random integer for data indexing
78 | 
79 |         Returns:
80 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
81 | 
82 |         Step 1: get a random image path: e.g., path = self.image_paths[index]
83 |         Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
84 |         Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
85 |         Step 4: return a data point as a dictionary.
86 |         """
87 |         path = self.image_paths[index]
88 |         name_parts=path.split('/')
89 |         if self.isTrain:
90 |             target_path = self.image_paths[index].replace(name_parts[-2],'Expert_C_resize')
91 |         else:
92 |             target_path = self.image_paths[index].replace('input','expertC_gt')
93 | 
94 |         comp = Image.open(path).convert('RGB')
95 |         real = Image.open(target_path).convert('RGB')
96 | 
97 |         if np.random.rand() > 0.5 and self.isTrain:
98 |             comp, real = tf.hflip(comp), tf.hflip(real)
99 |         if comp.size[0] != self.image_size:
100 |             # assert 0
101 |             comp = tf.resize(comp, [self.image_size, self.image_size])
102 |             real = tf.resize(real, [self.image_size,self.image_size])
103 | 
104 |         comp = self.transforms(comp)
105 |         real = self.transforms(real)
106 |         return {'fake': comp, 'real': real,'img_path':path}
107 | 
108 |     def __len__(self):
109 |         """Return the total number of images."""
110 |         return len(self.image_paths)
111 | 
--------------------------------------------------------------------------------
/applications/iih_enhancement/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 | 
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
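
A usage sketch of the helpers below (illustrative only; assumes `opt` carries
`preprocess`, `load_size`, `crop_size` and `no_flip` as defined by the options
package, and a hypothetical image path):

    from PIL import Image
    img = Image.open('example.jpg').convert('RGB')
    params = get_params(opt, img.size)        # shared random crop position / flip decision
    transform = get_transform(opt, params)    # resize/crop/flip + ToTensor + Normalize
    tensor = transform(img)                   # float tensor in [0, 1]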
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 | 
12 | 
13 | class BaseDataset(data.Dataset, ABC):
14 |     """This class is an abstract base class (ABC) for datasets.
15 | 
16 |     To create a subclass, you need to implement the following four functions:
17 |     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 |     -- <__len__>: return the size of dataset.
19 |     -- <__getitem__>: get a data point.
20 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 |     """
22 | 
23 |     def __init__(self, opt):
24 |         """Initialize the class; save the options in the class
25 | 
26 |         Parameters:
27 |             opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 |         """
29 |         self.opt = opt
30 |         # self.root = opt.dataroot
31 |         self.root = opt.dataset_root
32 | 
33 |     @staticmethod
34 |     def modify_commandline_options(parser, is_train):
35 |         """Add new dataset-specific options, and rewrite default values for existing options.
36 | 
37 |         Parameters:
38 |             parser -- original option parser
39 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40 | 
41 |         Returns:
42 |             the modified parser.
43 |         """
44 |         return parser
45 | 
46 |     @abstractmethod
47 |     def __len__(self):
48 |         """Return the total number of images in the dataset."""
49 |         return 0
50 | 
51 |     @abstractmethod
52 |     def __getitem__(self, index):
53 |         """Return a data point and its metadata information.
54 | 
55 |         Parameters:
56 |             index -- a random integer for data indexing
57 | 
58 |         Returns:
59 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
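
        Example (illustrative): a subclass such as Adobe5kDataset above returns
            {'fake': input_tensor, 'real': target_tensor, 'img_path': path}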
60 | """ 61 | pass 62 | 63 | 64 | def get_params(opt, size): 65 | w, h = size 66 | new_h = h 67 | new_w = w 68 | if opt.preprocess == 'resize_and_crop': 69 | new_h = new_w = opt.load_size 70 | elif opt.preprocess == 'scale_width_and_crop': 71 | new_w = opt.load_size 72 | new_h = opt.load_size * h // w 73 | 74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 76 | 77 | flip = random.random() > 0.5 78 | 79 | return {'crop_pos': (x, y), 'flip': flip} 80 | 81 | 82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 83 | transform_list = [] 84 | if grayscale: 85 | transform_list.append(transforms.Grayscale(1)) 86 | if 'resize' in opt.preprocess: 87 | osize = [opt.load_size, opt.load_size] 88 | transform_list.append(transforms.Resize(osize, method)) 89 | elif 'scale_width' in opt.preprocess: 90 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 91 | 92 | if 'crop' in opt.preprocess: 93 | if params is None: 94 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 95 | else: 96 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 97 | 98 | if opt.preprocess == 'none': 99 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 100 | 101 | if not opt.no_flip: 102 | if params is None: 103 | transform_list.append(transforms.RandomHorizontalFlip()) 104 | elif params['flip']: 105 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 106 | 107 | if convert: 108 | transform_list += [transforms.ToTensor()] 109 | if grayscale: 110 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 111 | else: 112 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 113 | transform_list += [transforms.Normalize((0, 0, 0), (1, 1, 1))] 114 | return transforms.Compose(transform_list) 115 | 116 | 117 | def __make_power_2(img, base, method=Image.BICUBIC): 118 | ow, oh = img.size 119 | h = int(round(oh / base) * base) 120 | w = int(round(ow / base) * base) 121 | if (h == oh) and (w == ow): 122 | return img 123 | 124 | __print_size_warning(ow, oh, w, h) 125 | return img.resize((w, h), method) 126 | 127 | 128 | def __scale_width(img, target_width, method=Image.BICUBIC): 129 | ow, oh = img.size 130 | if (ow == target_width): 131 | return img 132 | w = target_width 133 | h = int(target_width * oh / ow) 134 | return img.resize((w, h), method) 135 | 136 | 137 | def __crop(img, pos, size): 138 | ow, oh = img.size 139 | x1, y1 = pos 140 | tw = th = size 141 | if (ow > tw or oh > th): 142 | return img.crop((x1, y1, x1 + tw, y1 + th)) 143 | return img 144 | 145 | 146 | def __flip(img, flip): 147 | if flip: 148 | return img.transpose(Image.FLIP_LEFT_RIGHT) 149 | return img 150 | 151 | 152 | def __print_size_warning(ow, oh, w, h): 153 | """Print warning information about image size(only print once)""" 154 | if not hasattr(__print_size_warning, 'has_printed'): 155 | print("The image size needs to be a multiple of 4. " 156 | "The loaded image size was (%d, %d), so it was adjusted to " 157 | "(%d, %d). 
This adjustment will be done to all images " 158 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 159 | __print_size_warning.has_printed = True 160 | -------------------------------------------------------------------------------- /applications/iih_enhancement/data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 5 | """ 6 | 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_dataset(dir, max_dataset_size=float("inf")): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | 27 | for root, _, fnames in sorted(os.walk(dir)): 28 | for fname in fnames: 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | images.append(path) 32 | return images[:min(max_dataset_size, len(images))] 33 | 34 | 35 | def default_loader(path): 36 | return Image.open(path).convert('RGB') 37 | 38 | 39 | class ImageFolder(data.Dataset): 40 | 41 | def __init__(self, root, transform=None, return_paths=False, 42 | loader=default_loader): 43 | imgs = make_dataset(root) 44 | if len(imgs) == 0: 45 | raise(RuntimeError("Found 0 images in: " + root + "\n" 46 | "Supported image extensions are: " + 47 | ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /applications/iih_enhancement/data/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Data loader.""" 5 | 6 | import itertools 7 | import numpy as np 8 | import torch 9 | from torch.utils.data._utils.collate import default_collate 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torch.utils.data.sampler import RandomSampler 12 | 13 | from slowfast.datasets.multigrid_helper import ShortCycleBatchSampler 14 | 15 | from . import utils as utils 16 | 17 | def build_dataset(cfg): 18 | image_paths = [] 19 | if cfg.phase == 'train': 20 | print('loading training file') 21 | file = cfg.dataset_root+cfg.dataset_name+'_train.txt' 22 | with open(file,'r') as f: 23 | for line in f.readlines(): 24 | image_paths.append(os.path.join(cfg.dataset_root,'composite_images',line.rstrip())) 25 | 26 | 27 | def construct_loader(cfg, split, is_precise_bn=False): 28 | """ 29 | Constructs the data loader for the given dataset. 30 | Args: 31 | cfg (CfgNode): configs. 
Details can be found in
32 |             slowfast/config/defaults.py
33 |         split (str): the split of the data loader. Options include `train`,
34 |             `val`, and `test`.
35 |     """
36 |     assert split in ["train", "val", "test"]
37 |     if split in ["train"]:
38 |         dataset_name = cfg.TRAIN.DATASET
39 |         batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
40 |         shuffle = True
41 |         drop_last = True
42 |     elif split in ["val"]:
43 |         dataset_name = cfg.TRAIN.DATASET
44 |         batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
45 |         shuffle = False
46 |         drop_last = False
47 |     elif split in ["test"]:
48 |         dataset_name = cfg.TEST.DATASET
49 |         batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
50 |         shuffle = False
51 |         drop_last = False
52 | 
53 |     # Construct the dataset
54 |     dataset = build_dataset(dataset_name, cfg, split)
55 | 
56 |     if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn:
57 |         # Create a sampler for multi-process training
58 |         sampler = utils.create_sampler(dataset, shuffle, cfg)
59 |         batch_sampler = ShortCycleBatchSampler(
60 |             sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
61 |         )
62 |         # Create a loader
63 |         loader = torch.utils.data.DataLoader(
64 |             dataset,
65 |             batch_sampler=batch_sampler,
66 |             num_workers=cfg.DATA_LOADER.NUM_WORKERS,
67 |             pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
68 |             worker_init_fn=utils.loader_worker_init_fn(dataset),
69 |         )
70 |     else:
71 |         # Create a sampler for multi-process training
72 |         sampler = utils.create_sampler(dataset, shuffle, cfg)
73 |         # Create a loader
74 |         loader = torch.utils.data.DataLoader(
75 |             dataset,
76 |             batch_size=batch_size,
77 |             shuffle=(False if sampler else shuffle),
78 |             sampler=sampler,
79 |             num_workers=cfg.DATA_LOADER.NUM_WORKERS,
80 |             pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
81 |             drop_last=drop_last,
82 |             collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
83 |             worker_init_fn=utils.loader_worker_init_fn(dataset),
84 |         )
85 |     return loader
86 | 
87 | 
88 | def shuffle_dataset(loader, cur_epoch):
89 |     """
90 |     Shuffles the data.
91 |     Args:
92 |         loader (loader): data loader to perform shuffle.
93 |         cur_epoch (int): number of the current epoch.
94 |     """
95 |     sampler = (
96 |         loader.batch_sampler.sampler
97 |         if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
98 |         else loader.sampler
99 |     )
100 |     assert isinstance(
101 |         sampler, (RandomSampler, DistributedSampler)
102 |     ), "Sampler type '{}' not supported".format(type(sampler))
103 |     # RandomSampler handles shuffling automatically
104 |     if isinstance(sampler, DistributedSampler):
105 |         # DistributedSampler shuffles data based on epoch
106 |         sampler.set_epoch(cur_epoch)
107 | 
--------------------------------------------------------------------------------
/applications/iih_enhancement/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 | 
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 |     -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 |     -- <set_input>: unpack data from dataset and apply preprocessing.
7 |     -- <forward>: produce intermediate results.
8 |     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
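
A minimal skeleton of such a subclass (an illustrative sketch only; the exact
abstract interface is defined in models/base_model.py):

    from .base_model import BaseModel

    class DummyModel(BaseModel):
        @staticmethod
        def modify_commandline_options(parser, is_train=True):
            return parser                 # add '--lambda_*'-style flags here

        def __init__(self, opt):
            BaseModel.__init__(self, opt)
            self.loss_names = []          # losses to log
            self.model_names = []         # networks to save/load
            self.visual_names = []        # images to display
            self.optimizers = []          # one optimizer per network

        def set_input(self, input):
            self.data = input

        def forward(self):
            pass

        def optimize_parameters(self):
            self.forward()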
10 | 
11 | In the function <__init__>, you need to define four lists:
12 |     -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 |     -- self.model_names (str list): define networks used in our training.
14 |     -- self.visual_names (str list): specify the images that you want to display and save.
15 |     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
16 | 
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 | 
21 | import importlib
22 | from models.base_model import BaseModel
23 | 
24 | 
25 | def find_model_using_name(model_name):
26 |     """Import the module "models/[model_name]_model.py".
27 | 
28 |     In the file, the class called DatasetNameModel() will
29 |     be instantiated. It has to be a subclass of BaseModel,
30 |     and it is case-insensitive.
31 |     """
32 |     model_filename = "models." + model_name + "_model"
33 |     modellib = importlib.import_module(model_filename)
34 |     model = None
35 |     target_model_name = model_name.replace('_', '') + 'model'
36 |     for name, cls in modellib.__dict__.items():
37 |         if name.lower() == target_model_name.lower() \
38 |            and issubclass(cls, BaseModel):
39 |             model = cls
40 | 
41 |     if model is None:
42 |         print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 |         exit(0)
44 | 
45 |     return model
46 | 
47 | 
48 | def get_option_setter(model_name):
49 |     """Return the static method <modify_commandline_options> of the model class."""
50 |     model_class = find_model_using_name(model_name)
51 |     return model_class.modify_commandline_options
52 | 
53 | 
54 | def create_model(opt):
55 |     """Create a model given the option.
56 | 
57 |     This function wraps the model class.
58 |     This is the main interface between this package and 'train.py'/'test.py'
59 | 
60 |     Example:
61 |         >>> from models import create_model
62 |         >>> model = create_model(opt)
63 |     """
64 |     model = find_model_using_name(opt.model)
65 |     instance = model(opt)
66 |     print("model [%s] was created" % type(instance).__name__)
67 |     return instance
68 | 
--------------------------------------------------------------------------------
/applications/iih_enhancement/models/iih_base_gd_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import itertools
4 | import torch.nn.functional as F
5 | from util import distributed as du
6 | # import pytorch_colors as colors
7 | from .base_model import BaseModel
8 | from util import util
9 | from . import harmony_networks as networks
10 | import util.ssim as ssim
11 | 
12 | 
13 | class IIHBaseGDModel(BaseModel):
14 |     @staticmethod
15 |     def modify_commandline_options(parser, is_train=True):
16 |         parser.set_defaults(norm='instance', netG='base_gd', dataset_mode='adobe5k')
17 |         if is_train:
18 |             parser.add_argument('--lambda_L1', type=float, default=50.0, help='weight for L1 reconstruction loss')
19 |             parser.add_argument('--lambda_R', type=float, default=100., help='weight for L2 loss on the enhanced image')
20 |             parser.add_argument('--lambda_ssim', type=float, default=50., help='weight for SSIM loss on the enhanced image')
21 |             parser.add_argument('--lambda_ifm', type=float, default=100, help='weight for the illumination feature map (ifm) loss')
22 | 
23 |         return parser
24 | 
25 |     def __init__(self, opt):
26 |         BaseModel.__init__(self, opt)
27 |         self.opt = opt
28 |         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
29 |         self.loss_names = ['G','G_L1','G_R','G_R_SSIM',"IF"]
30 | 
31 |         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
32 |         self.visual_names = ['enhanced','real','fake','reconstruct','illumination']
33 |         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
34 |         self.model_names = ['G']
35 |         self.opt.device = self.device
36 |         self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
37 |         self.cur_device = torch.cuda.current_device()
38 |         self.ismaster = du.is_master_proc(opt.NUM_GPUS)
39 |         print(self.netG)
40 | 
41 |         if self.isTrain:
42 |             # if self.ismaster == 0:
43 |             util.saveprint(self.opt, 'netG', str(self.netG))
44 |             # define loss functions
45 |             self.criterionL1 = torch.nn.L1Loss()
46 |             self.criterionL2 = torch.nn.MSELoss()
47 |             self.criterionSSIM = ssim.SSIM()
48 |             self.criterionDSSIM_CS = ssim.DSSIM(mode='c_s').to(self.device)
49 |             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
50 |             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
51 |             self.optimizers.append(self.optimizer_G)
52 | 
53 |     def set_input(self, input):
54 |         """Unpack input data from the dataloader and perform necessary pre-processing steps.
55 | 
56 |         Parameters:
57 |             input (dict): include the data itself and its metadata information.
58 | 
59 | 
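        Example of the expected dictionary (illustrative, as produced by
        Adobe5kDataset):
            {'fake': input_batch, 'real': target_batch, 'img_path': list_of_paths}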
60 |         """
61 |         self.fake = input['fake'].to(self.device)
62 |         self.real = input['real'].to(self.device)
63 |         self.image_paths = input['img_path']
64 |         self.real_r = F.interpolate(self.real, size=[32,32])  # 32x32 thumbnail of the target
65 |         self.real_gray = util.rgbtogray(self.real_r)  # grayscale version, compared against ifm_mean in backward_G
66 |     def forward(self):
67 |         self.reconstruct, self.enhanced, self.illumination, self.ifm_mean = self.netG(self.fake)
68 |     def backward_G(self):
69 |         self.loss_IF = self.criterionDSSIM_CS(self.ifm_mean, self.real_gray)*self.opt.lambda_ifm  # contrast/structure DSSIM between predicted ifm_mean and grayscale target
70 | 
71 |         self.loss_G_L1 = self.criterionL1(self.reconstruct, self.fake)*self.opt.lambda_L1  # L1 reconstruction of the input
72 |         self.loss_G_R = self.criterionL2(self.enhanced, self.real)*self.opt.lambda_R  # L2 between enhanced result and ground truth
73 |         self.loss_G_R_SSIM = (1-self.criterionSSIM(self.enhanced, self.real))*self.opt.lambda_ssim  # SSIM term on the enhanced result
74 |         self.loss_G = self.loss_G_L1 + self.loss_G_R + self.loss_G_R_SSIM + self.loss_IF
75 |         self.loss_G.backward()
76 | 
77 |     def optimize_parameters(self):
78 |         self.forward()                   # run the generator: reconstruct and enhance the input
79 |         # update G
80 |         self.optimizer_G.zero_grad()     # set G's gradients to zero
81 |         self.backward_G()                # calculate gradients for G
82 |         self.optimizer_G.step()          # update G's weights
--------------------------------------------------------------------------------
/applications/iih_enhancement/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes option modules: training options, test options, and basic options (used in both training and test)."""
2 | 
--------------------------------------------------------------------------------
/applications/iih_enhancement/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TestOptions(BaseOptions):
5 |     """This class includes test options.
6 | 
7 |     It also includes shared options defined in BaseOptions.
8 |     """
9 | 
10 |     def initialize(self, parser):
11 |         parser = BaseOptions.initialize(self, parser)  # define shared options
12 |         parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 |         parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 |         parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 |         parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 |         # Dropout and BatchNorm have different behaviors during training and test.
17 |         parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 |         parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 |         parser.add_argument('--test_epoch', type=str, default="0", help='which epoch to load for testing')
20 |         # rewrite default values
21 |         parser.set_defaults(model='test')
22 |         # To avoid cropping, the load_size should be the same as crop_size
23 |         parser.set_defaults(load_size=parser.get_default('crop_size'))
24 |         self.isTrain = False
25 |         return parser
26 | 
--------------------------------------------------------------------------------
/applications/iih_enhancement/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TrainOptions(BaseOptions):
5 |     """This class includes training options.
6 | 
7 |     It also includes shared options defined in BaseOptions.
8 |     """
9 | 
10 |     def initialize(self, parser):
11 |         parser = BaseOptions.initialize(self, parser)
12 |         # visdom and HTML visualization parameters
13 |         parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
14 |         parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
15 |         parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
16 |         parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
17 |         parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
18 |         parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
19 |         parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
20 |         parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
21 |         parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
22 |         # network saving and loading parameters
23 |         parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
24 |         parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
25 |         parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
26 |         parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
27 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
28 |         parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
29 |         # training parameters
30 |         parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
31 |         parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
32 |         parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
33 |         parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
34 |         parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
35 |         parser.add_argument('--g_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of generator')
36 |         parser.add_argument('--d_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of discriminator')
37 |         parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
38 |         parser.add_argument('--pool_size', type=int, default=40, help='the size of image buffer that stores previously generated images')
39 |         parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
40 |         parser.add_argument('--lr_decay_iters', type=int, default=40, help='multiply by a gamma every lr_decay_iters iterations')
41 |         parser.add_argument('--save_iter_model', action='store_true', help='whether to save the model by iteration')
42 | 
43 | 
44 | 
45 | 
46 |         self.isTrain = True
47 |         return parser
48 | 
--------------------------------------------------------------------------------
/applications/iih_enhancement/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import test
6 | 
7 | from options.test_options import TestOptions
8 | 
9 | def main():
10 |     cfg = TestOptions().parse()  # get test options
11 |     cfg.NUM_GPUS = torch.cuda.device_count()
12 |     cfg.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
13 |     cfg.no_flip = True  # no flip; comment this line if results on flipped images are needed.
14 |     cfg.display_id = -1  # no visdom display; the test code saves the results to an HTML file.
15 | 
16 |     cfg.phase = 'test'
17 |     cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
18 |     launch_job(cfg=cfg, init_method=cfg.init_method, func=test)
19 | 
20 | 
21 | if __name__=="__main__":
22 |     main()
--------------------------------------------------------------------------------
/applications/iih_enhancement/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import train
6 | 
7 | from options.train_options import TrainOptions
8 | 
9 | def main():
10 |     cfg = TrainOptions().parse()  # get training options
11 |     cfg.NUM_GPUS = torch.cuda.device_count()
12 |     cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
13 |     cfg.phase = 'train'
14 |     launch_job(cfg=cfg, init_method=cfg.init_method, func=train)
15 | 
16 | 
17 | if __name__=="__main__":
18 |     main()
--------------------------------------------------------------------------------
/applications/iih_enhancement/util/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as f
3 | 
4 | 
5 | def evaluation(name, fake, real, mask):
6 |     b,c,w,h = real.size()
7 |     mse_score = f.mse_loss(fake, real)
8 |     fore_area = torch.sum(mask)
9 |     fmse_score = f.mse_loss(fake*mask,real*mask)*w*h/fore_area  # fMSE: masked MSE re-normalized by foreground area (w*h / sum(mask)) instead of the full image area
10 |     mse_score = mse_score.item()
11 |     fmse_score = fmse_score.item()
12 |     # score_str = "%s MSE %0.2f | fMSE %0.2f" % (name, mse_score,fmse_score)
13 |     image_fmse_info = (name, round(fmse_score,2), round(mse_score, 2))
14 |     return mse_score, fmse_score, image_fmse_info
--------------------------------------------------------------------------------
/applications/iih_enhancement/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 | 
5 | 
6 | class HTML:
7 |     """This HTML class allows us to save images and write texts into a single HTML file.
8 | 
9 |     It consists of functions such as <add_header> (add a text header to the HTML file),
10 |     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 |     It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 |     """
13 | 
14 |     def __init__(self, web_dir, title, refresh=0):
15 |         """Initialize the HTML classes
16 | 
17 |         Parameters:
18 |             web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str)   -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 | 
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 |             with self.doc.head:
33 |                 meta(http_equiv="refresh", content=str(refresh))
34 | 
35 |     def get_image_dir(self):
36 |         """Return the directory that stores images"""
37 |         return self.img_dir
38 | 
39 |     def add_header(self, text):
40 |         """Insert a header to the HTML file
41 | 
42 |         Parameters:
43 |             text (str) -- the header text
44 |         """
45 |         with self.doc:
46 |             h3(text)
47 | 
48 |     def add_images(self, ims, txts, links, width=400):
49 |         """add images to the HTML file
50 | 
51 |         Parameters:
52 |             ims (str list)   -- a list of image paths
53 |             txts (str list)  -- a list of image names shown on the website
54 |             links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 |         """
56 |         self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
57 |         self.doc.add(self.t)
58 |         with self.t:
59 |             with tr():
60 |                 for im, txt, link in zip(ims, txts, links):
61 |                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 |                         with p():
63 |                             with a(href=os.path.join('images', link)):
64 |                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
65 |                             br()
66 |                             p(txt)
67 | 
68 |     def save(self):
69 |         """save the current content to the HTML file"""
70 |         html_file = '%s/index.html' % self.web_dir
71 |         f = open(html_file, 'wt')
72 |         f.write(self.doc.render())
73 |         f.close()
74 | 
75 | 
76 | if __name__ == '__main__':  # we show an example usage here.
77 |     html = HTML('web/', 'test_html')
78 |     html.add_header('hello world')
79 | 
80 |     ims, txts, links = [], [], []
81 |     for n in range(4):
82 |         ims.append('image_%d.png' % n)
83 |         txts.append('text_%d' % n)
84 |         links.append('image_%d.png' % n)
85 |     html.add_images(ims, txts, links)
86 |     html.save()
87 | 
--------------------------------------------------------------------------------
/applications/iih_enhancement/util/multiprocessing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 | 
4 | """Multiprocessing helpers."""
5 | 
6 | import torch
7 | 
8 | 
9 | def run(
10 |     local_rank,
11 |     num_proc,
12 |     func,
13 |     init_method,
14 |     shard_id,
15 |     num_shards,
16 |     backend,
17 |     cfg,
18 |     output_queue=None,
19 | ):
20 |     """
21 |     Runs a function from a child process.
22 |     Args:
23 |         local_rank (int): rank of the current process on the current machine.
24 |         num_proc (int): number of processes per machine.
25 |         func (function): function to execute on each of the process.
26 |         init_method (string): method to initialize the distributed training.
27 |             TCP initialization: requiring a network address reachable from all
28 |             processes followed by the port.
29 |             Shared file-system initialization: makes use of a file system that
30 |             is shared and visible from all machines. The URL should start with
31 |             file:// and contain a path to a non-existent file on a shared file
32 |             system.
33 |         shard_id (int): the rank of the current machine.
34 |         num_shards (int): number of overall machines for the distributed
35 |             training job.
36 |         backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
37 |             supported, each with different capabilities.
Details can be found 38 | here: 39 | https://pytorch.org/docs/stable/distributed.html 40 | cfg (CfgNode): configs. Details can be found in 41 | slowfast/config/defaults.py 42 | output_queue (queue): can optionally be used to return values from the 43 | master process. 44 | """ 45 | # Initialize the process group. 46 | world_size = num_proc * num_shards 47 | rank = shard_id * num_proc + local_rank 48 | try: 49 | torch.distributed.init_process_group( 50 | backend=backend, 51 | init_method=init_method, 52 | world_size=world_size, 53 | rank=rank, 54 | ) 55 | 56 | except Exception as e: 57 | raise e 58 | 59 | torch.cuda.set_device(local_rank) 60 | ret = func(cfg) 61 | if output_queue is not None and local_rank == 0: 62 | output_queue.put(ret) 63 | -------------------------------------------------------------------------------- /applications/iih_enhancement/util/ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | def _ssim_c_s(img1, img2, window, window_size, channel, size_average = True): 40 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 41 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 42 | 43 | mu1_sq = mu1.pow(2) 44 | mu2_sq = mu2.pow(2) 45 | mu1_mu2 = mu1*mu2 46 | 47 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 48 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 49 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 50 | 51 | C1 = 0.01**2 52 | C2 = 0.03**2 53 | 54 | ssim_map = ((2*sigma12 + C2))/((sigma1_sq + sigma2_sq + C2)) 55 | 56 | if size_average: 57 | return ssim_map.mean() 58 | else: 59 | return ssim_map.mean(1).mean(1).mean(1) 60 | def _ssim_l(img1, img2, window, window_size, channel, size_average = True): 61 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 62 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 63 | 64 | mu1_sq 
= mu1.pow(2) 65 | mu2_sq = mu2.pow(2) 66 | mu1_mu2 = mu1*mu2 67 | 68 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 69 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 70 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 71 | 72 | C1 = 0.01**2 73 | C2 = 0.03**2 74 | 75 | ssim_map = (2*mu1_mu2 + C1)/(mu1_sq + mu2_sq + C1) 76 | 77 | if size_average: 78 | return ssim_map.mean() 79 | else: 80 | return ssim_map.mean(1).mean(1).mean(1) 81 | 82 | 83 | def ssim(img1, img2, window_size = 11, size_average = True): 84 | (_, channel, _, _) = img1.size() 85 | window = create_window(window_size, channel) 86 | 87 | if img1.is_cuda: 88 | window = window.cuda(img1.get_device()) 89 | window = window.type_as(img1) 90 | 91 | return _ssim(img1, img2, window, window_size, channel, size_average) 92 | 93 | class SSIM(torch.nn.Module): 94 | def __init__(self, window_size = 11, size_average = True, mode='all'): 95 | super(SSIM, self).__init__() 96 | self.window_size = window_size 97 | self.size_average = size_average 98 | self.channel = 1 99 | self.window = create_window(window_size, self.channel) 100 | self.mode = mode 101 | def forward(self, img1, img2): 102 | (_, channel, _, _) = img1.size() 103 | 104 | if channel == self.channel and self.window.data.type() == img1.data.type(): 105 | window = self.window 106 | else: 107 | window = create_window(self.window_size, channel) 108 | 109 | # if img1.is_cuda: 110 | # window = window.cuda(img1.get_device()) 111 | window = window.type_as(img1) 112 | 113 | self.window = window 114 | self.channel = channel 115 | if self.mode == 'all': 116 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 117 | elif self.mode == 'c_s': 118 | return _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average) 119 | else: 120 | return _ssim_l(img1, img2, window, self.window_size, channel, self.size_average) 121 | 122 | class DSSIM(torch.nn.Module): 123 | def __init__(self, window_size = 11, size_average = True, mode='all'): 124 | super(DSSIM, self).__init__() 125 | self.window_size = window_size 126 | self.size_average = size_average 127 | self.channel = 1 128 | self.window = create_window(window_size, self.channel) 129 | self.mode = mode 130 | def forward(self, img1, img2): 131 | (_, channel, _, _) = img1.size() 132 | 133 | if channel == self.channel and self.window.data.type() == img1.data.type(): 134 | window = self.window 135 | else: 136 | window = create_window(self.window_size, channel) 137 | 138 | # if img1.is_cuda: 139 | # window = window.cuda(img1.get_device()) 140 | window = window.type_as(img1) 141 | 142 | self.window = window 143 | self.channel = channel 144 | if self.mode == 'all': 145 | ssim_v = _ssim(img1, img2, window, self.window_size, channel, self.size_average) 146 | elif self.mode == 'c_s': 147 | ssim_v = _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average) 148 | else: 149 | ssim_v = _ssim_l(img1, img2, window, self.window_size, channel, self.size_average) 150 | return (1-ssim_v)/2 -------------------------------------------------------------------------------- /applications/iih_mef/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | checkpoints/* 3 | results/* 4 | */__pycache__/* 5 | tmp/* 6 | __pycache__/distribute.cpython-37.pyc 7 | __pycache__/options.cpython-37.pyc 8 | __pycache__/train_net.cpython-37.pyc 9 
| __pycache__/train.cpython-37.pyc -------------------------------------------------------------------------------- /applications/iih_mef/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # Multi-Exposure Image Fusion
5 | 
6 | Here we provide the PyTorch implementation and the pre-trained model of our latest version.
7 | 
8 | ## Prerequisites
9 | 
10 | - Linux
11 | - Python 3
12 | - CPU or NVIDIA GPU + CUDA CuDNN
13 | 
14 | ## Base Model with Guiding
15 | - Download the SICE dataset.
16 | 
17 | - Train
18 | ```bash
19 | CUDA_VISIBLE_DEVICES=0 python train.py --model iih_base_gd --name base_gd_sice_test --dataset_root <dataset_dir> --dataset_name mef --batch_size xx --init_port xxxx
20 | ```
21 | - Test
22 | ```bash
23 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_gd --name base_gd_sice_test --dataset_root <dataset_dir> --dataset_name mef --batch_size xx --init_port xxxx
24 | ```
25 | - Apply pre-trained model
26 | 
27 | Download the pre-trained model from [Google Drive](https://drive.google.com/file/d/17SIkVhRFW5LTuX2PXDPkVw2IwKWDpO-B/view?usp=sharing) or [BaiduCloud](https://pan.baidu.com/s/1V4ulhcC1eqM6EfVbxRIz1g) (access code: 15vn), put `latest_net_G.pth` in the directory `checkpoints/base_gd_mef`, and run:
28 | ```bash
29 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_gd --name base_gd_mef --dataset_root <dataset_dir> --dataset_name mef --batch_size xx --init_port xxxx
30 | ```
--------------------------------------------------------------------------------
/applications/iih_mef/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 | 
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 |     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 |     -- <__len__>: return the size of dataset.
7 |     -- <__getitem__>: get a data point from data loader.
8 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 | 
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 | from torch.utils.data.distributed import DistributedSampler
17 | from torch.utils.data.sampler import RandomSampler
18 | 
19 | 
20 | def find_dataset_using_name(dataset_name):
21 |     """Import the module "data/[dataset_name]_dataset.py".
22 | 
23 |     In the file, the class called DatasetNameDataset() will
24 |     be instantiated. It has to be a subclass of BaseDataset,
25 |     and it is case-insensitive.
26 |     """
27 |     dataset_filename = "data." + dataset_name + "_dataset"
28 |     datasetlib = importlib.import_module(dataset_filename)
29 | 
30 |     dataset = None
31 |     target_dataset_name = dataset_name.replace('_', '') + 'dataset'
32 |     for name, cls in datasetlib.__dict__.items():
33 |         if name.lower() == target_dataset_name.lower() \
34 |            and issubclass(cls, BaseDataset):
35 |             dataset = cls
36 | 
37 |     if dataset is None:
38 |         raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
39 | 
40 |     return dataset
41 | 
42 | 
43 | def get_option_setter(dataset_name):
44 |     """Return the static method <modify_commandline_options> of the dataset class."""
45 |     dataset_class = find_dataset_using_name(dataset_name)
46 |     return dataset_class.modify_commandline_options
47 | 
48 | 
49 | def create_dataset(opt):
50 |     """Create a dataset given the option.
51 | 52 | This function wraps the class CustomDatasetDataLoader. 53 | This is the main interface between this package and 'train.py'/'test.py' 54 | 55 | Example: 56 | >>> from data import create_dataset 57 | >>> dataset = create_dataset(opt) 58 | """ 59 | # data_loader = CustomDatasetDataLoader(opt) 60 | # dataset = data_loader.load_data() 61 | 62 | dataset_class = find_dataset_using_name(opt.dataset_mode) 63 | dataset = dataset_class(opt) 64 | print("dataset [%s] was created" % type(dataset).__name__) 65 | 66 | # batch_size = int(opt.batch_size / max(1, opt.NUM_GPUS)) 67 | if opt.isTrain==True: 68 | shuffle = True 69 | drop_last = True 70 | elif opt.isTrain==False: 71 | shuffle = False 72 | drop_last = False 73 | 74 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) if opt.NUM_GPUS > 1 else None 75 | 76 | # Create a loader 77 | dataloader = torch.utils.data.DataLoader( 78 | dataset, 79 | batch_size=opt.batch_size, 80 | shuffle=(False if sampler else shuffle), 81 | sampler=sampler, 82 | num_workers=int(opt.num_threads), 83 | drop_last=drop_last, 84 | pin_memory=True, 85 | ) 86 | return dataloader 87 | 88 | 89 | class CustomDatasetDataLoader(): 90 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 91 | 92 | def __init__(self, opt): 93 | """Initialize this class 94 | 95 | Step 1: create a dataset instance given the name [dataset_mode] 96 | Step 2: create a multi-threaded data loader. 97 | """ 98 | self.opt = opt 99 | dataset_class = find_dataset_using_name(opt.dataset_mode) 100 | self.dataset = dataset_class(opt) 101 | print("dataset [%s] was created" % type(self.dataset).__name__) 102 | 103 | batch_size = int(opt.batch_size / max(1, opt.NUM_GPUS)) 104 | if opt.isTrain==True: 105 | shuffle = True 106 | drop_last = True 107 | elif opt.isTrain==False: 108 | shuffle = False 109 | drop_last = False 110 | 111 | self.sampler = torch.utils.data.distributed.DistributedSampler(self.dataset) if opt.NUM_GPUS > 1 else None 112 | 113 | # Create a loader 114 | self.dataloader = torch.utils.data.DataLoader( 115 | self.dataset, 116 | batch_size=batch_size, 117 | shuffle=(False if self.sampler else shuffle), 118 | sampler=self.sampler, 119 | num_workers=int(opt.num_threads), 120 | drop_last=drop_last, 121 | ) 122 | 123 | # self.dataloader = torch.utils.data.DataLoader( 124 | # self.dataset, 125 | # batch_size=opt.batch_size, 126 | # shuffle=not opt.serial_batches, 127 | # num_workers=int(opt.num_threads)) 128 | 129 | def load_data(self): 130 | return self 131 | 132 | def __len__(self): 133 | """Return the number of data in the dataset""" 134 | return min(len(self.dataset), self.opt.max_dataset_size) 135 | 136 | # def __iter__(self): 137 | # """Return a batch of data""" 138 | # for i, data in enumerate(self.dataloader): 139 | # if i * self.opt.batch_size >= self.opt.max_dataset_size: 140 | # break 141 | # yield data 142 | 143 | def shuffle_dataset(loader, cur_epoch): 144 | """ " 145 | Shuffles the data. 146 | Args: 147 | loader (loader): data loader to perform shuffle. 148 | cur_epoch (int): number of the current epoch. 
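        Example (illustrative): call once at the start of every epoch so a
        DistributedSampler reshuffles with an epoch-dependent seed:
            >>> for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
            ...     shuffle_dataset(dataloader, epoch)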
149 | """ 150 | # sampler = ( 151 | # loader.batch_sampler.sampler 152 | # if isinstance(loader.batch_sampler, ShortCycleBatchSampler) 153 | # else loader.sampler 154 | # ) 155 | sampler = loader.sampler 156 | assert isinstance( 157 | sampler, (RandomSampler, DistributedSampler) 158 | ), "Sampler type '{}' not supported".format(type(sampler)) 159 | # RandomSampler handles shuffling automatically 160 | if isinstance(sampler, DistributedSampler): 161 | # DistributedSampler shuffles data based on epoch 162 | sampler.set_epoch(cur_epoch) 163 | 164 | 165 | 166 | 167 | 168 | -------------------------------------------------------------------------------- /applications/iih_mef/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- : (optionally) add dataset-specific options and set default options. 21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | # self.root = opt.dataroot 31 | self.root = opt.dataset_root 32 | 33 | @staticmethod 34 | def modify_commandline_options(parser, is_train): 35 | """Add new dataset-specific options, and rewrite default values for existing options. 36 | 37 | Parameters: 38 | parser -- original option parser 39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 40 | 41 | Returns: 42 | the modified parser. 43 | """ 44 | return parser 45 | 46 | @abstractmethod 47 | def __len__(self): 48 | """Return the total number of images in the dataset.""" 49 | return 0 50 | 51 | @abstractmethod 52 | def __getitem__(self, index): 53 | """Return a data point and its metadata information. 54 | 55 | Parameters: 56 | index - - a random integer for data indexing 57 | 58 | Returns: 59 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
60 | """ 61 | pass 62 | 63 | 64 | def get_params(opt, size): 65 | w, h = size 66 | new_h = h 67 | new_w = w 68 | if opt.preprocess == 'resize_and_crop': 69 | new_h = new_w = opt.load_size 70 | elif opt.preprocess == 'scale_width_and_crop': 71 | new_w = opt.load_size 72 | new_h = opt.load_size * h // w 73 | 74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 76 | 77 | flip = random.random() > 0.5 78 | 79 | return {'crop_pos': (x, y), 'flip': flip} 80 | 81 | 82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 83 | transform_list = [] 84 | if grayscale: 85 | transform_list.append(transforms.Grayscale(1)) 86 | if 'resize' in opt.preprocess: 87 | osize = [opt.load_size, opt.load_size] 88 | transform_list.append(transforms.Resize(osize, method)) 89 | elif 'scale_width' in opt.preprocess: 90 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 91 | 92 | if 'crop' in opt.preprocess: 93 | if params is None: 94 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 95 | else: 96 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 97 | 98 | if opt.preprocess == 'none': 99 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 100 | 101 | if not opt.no_flip: 102 | if params is None: 103 | transform_list.append(transforms.RandomHorizontalFlip()) 104 | elif params['flip']: 105 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 106 | 107 | if convert: 108 | transform_list += [transforms.ToTensor()] 109 | if grayscale: 110 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 111 | else: 112 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 113 | transform_list += [transforms.Normalize((0, 0, 0), (1, 1, 1))] 114 | return transforms.Compose(transform_list) 115 | 116 | 117 | def __make_power_2(img, base, method=Image.BICUBIC): 118 | ow, oh = img.size 119 | h = int(round(oh / base) * base) 120 | w = int(round(ow / base) * base) 121 | if (h == oh) and (w == ow): 122 | return img 123 | 124 | __print_size_warning(ow, oh, w, h) 125 | return img.resize((w, h), method) 126 | 127 | 128 | def __scale_width(img, target_width, method=Image.BICUBIC): 129 | ow, oh = img.size 130 | if (ow == target_width): 131 | return img 132 | w = target_width 133 | h = int(target_width * oh / ow) 134 | return img.resize((w, h), method) 135 | 136 | 137 | def __crop(img, pos, size): 138 | ow, oh = img.size 139 | x1, y1 = pos 140 | tw = th = size 141 | if (ow > tw or oh > th): 142 | return img.crop((x1, y1, x1 + tw, y1 + th)) 143 | return img 144 | 145 | 146 | def __flip(img, flip): 147 | if flip: 148 | return img.transpose(Image.FLIP_LEFT_RIGHT) 149 | return img 150 | 151 | 152 | def __print_size_warning(ow, oh, w, h): 153 | """Print warning information about image size(only print once)""" 154 | if not hasattr(__print_size_warning, 'has_printed'): 155 | print("The image size needs to be a multiple of 4. " 156 | "The loaded image size was (%d, %d), so it was adjusted to " 157 | "(%d, %d). 
This adjustment will be done to all images " 158 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 159 | __print_size_warning.has_printed = True 160 | -------------------------------------------------------------------------------- /applications/iih_mef/data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 5 | """ 6 | 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_dataset(dir, max_dataset_size=float("inf")): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | 27 | for root, _, fnames in sorted(os.walk(dir)): 28 | for fname in fnames: 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | images.append(path) 32 | return images[:min(max_dataset_size, len(images))] 33 | 34 | 35 | def default_loader(path): 36 | return Image.open(path).convert('RGB') 37 | 38 | 39 | class ImageFolder(data.Dataset): 40 | 41 | def __init__(self, root, transform=None, return_paths=False, 42 | loader=default_loader): 43 | imgs = make_dataset(root) 44 | if len(imgs) == 0: 45 | raise(RuntimeError("Found 0 images in: " + root + "\n" 46 | "Supported image extensions are: " + 47 | ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /applications/iih_mef/data/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Data loader.""" 5 | 6 | import itertools 7 | import numpy as np 8 | import torch 9 | from torch.utils.data._utils.collate import default_collate 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torch.utils.data.sampler import RandomSampler 12 | 13 | from slowfast.datasets.multigrid_helper import ShortCycleBatchSampler 14 | 15 | from . import utils as utils 16 | 17 | def build_dataset(cfg): 18 | image_paths = [] 19 | if cfg.phase == 'train': 20 | print('loading training file') 21 | file = cfg.dataset_root+cfg.dataset_name+'_train.txt' 22 | with open(file,'r') as f: 23 | for line in f.readlines(): 24 | image_paths.append(os.path.join(cfg.dataset_root,'composite_images',line.rstrip())) 25 | 26 | 27 | def construct_loader(cfg, split, is_precise_bn=False): 28 | """ 29 | Constructs the data loader for the given dataset. 30 | Args: 31 | cfg (CfgNode): configs. Details can be found in 32 | slowfast/config/defaults.py 33 | split (str): the split of the data loader. 
Options include `train`, 34 | `val`, and `test`. 35 | """ 36 | assert split in ["train", "val", "test"] 37 | if split in ["train"]: 38 | dataset_name = cfg.TRAIN.DATASET 39 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 40 | shuffle = True 41 | drop_last = True 42 | elif split in ["val"]: 43 | dataset_name = cfg.TRAIN.DATASET 44 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 45 | shuffle = False 46 | drop_last = False 47 | elif split in ["test"]: 48 | dataset_name = cfg.TEST.DATASET 49 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 50 | shuffle = False 51 | drop_last = False 52 | 53 | # Construct the dataset 54 | dataset = build_dataset(dataset_name, cfg, split) 55 | 56 | if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn: 57 | # Create a sampler for multi-process training 58 | sampler = utils.create_sampler(dataset, shuffle, cfg) 59 | batch_sampler = ShortCycleBatchSampler( 60 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg 61 | ) 62 | # Create a loader 63 | loader = torch.utils.data.DataLoader( 64 | dataset, 65 | batch_sampler=batch_sampler, 66 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 67 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 68 | worker_init_fn=utils.loader_worker_init_fn(dataset), 69 | ) 70 | else: 71 | # Create a sampler for multi-process training 72 | sampler = utils.create_sampler(dataset, shuffle, cfg) 73 | # Create a loader 74 | loader = torch.utils.data.DataLoader( 75 | dataset, 76 | batch_size=batch_size, 77 | shuffle=(False if sampler else shuffle), 78 | sampler=sampler, 79 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 80 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 81 | drop_last=drop_last, 82 | collate_fn=detection_collate if cfg.DETECTION.ENABLE else None, 83 | worker_init_fn=utils.loader_worker_init_fn(dataset), 84 | ) 85 | return loader 86 | 87 | 88 | def shuffle_dataset(loader, cur_epoch): 89 | """ " 90 | Shuffles the data. 91 | Args: 92 | loader (loader): data loader to perform shuffle. 93 | cur_epoch (int): number of the current epoch. 
94 | """ 95 | sampler = ( 96 | loader.batch_sampler.sampler 97 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler) 98 | else loader.sampler 99 | ) 100 | assert isinstance( 101 | sampler, (RandomSampler, DistributedSampler) 102 | ), "Sampler type '{}' not supported".format(type(sampler)) 103 | # RandomSampler handles shuffling automatically 104 | if isinstance(sampler, DistributedSampler): 105 | # DistributedSampler shuffles data based on epoch 106 | sampler.set_epoch(cur_epoch) 107 | -------------------------------------------------------------------------------- /applications/iih_mef/data/mef_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import os.path 3 | import torch 4 | import torchvision.transforms.functional as tf 5 | import torch.nn.functional as F 6 | import random 7 | from torchvision.transforms.transforms import RandomCrop, RandomResizedCrop 8 | from data.base_dataset import BaseDataset, get_transform 9 | from data.image_folder import make_dataset 10 | from PIL import Image 11 | import numpy as np 12 | import torchvision.transforms as transforms 13 | from util import util 14 | 15 | class MefDataset(BaseDataset): 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train): 18 | parser.add_argument('--is_train', type=bool, default=True, help='whether in the training phase') 19 | parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0) # specify dataset-specific default values 20 | return parser 21 | 22 | def __init__(self, opt): 23 | # save the option and dataset root 24 | BaseDataset.__init__(self, opt) 25 | self.fake_image_paths = [] 26 | self.image_paths = [] 27 | self.isTrain = opt.isTrain 28 | self.image_size = opt.crop_size 29 | 30 | if opt.isTrain==True: 31 | print('loading training file') 32 | self.trainfile = opt.dataset_root+'Dataset_Part1_resize/'+'part1_train.txt' 33 | with open(self.trainfile,'r') as f: 34 | for line in f.readlines(): 35 | name = line.rstrip().split('.') 36 | self.image_paths.append(os.path.join(opt.dataset_root,'Dataset_Part1_resize/',name[0])) 37 | self.trainfile = opt.dataset_root+'Dataset_Part2_resize/'+'part2_train.txt' 38 | with open(self.trainfile,'r') as f: 39 | for line in f.readlines(): 40 | name = line.rstrip().split('.') 41 | self.image_paths.append(os.path.join(opt.dataset_root,'Dataset_Part2_resize/',name[0])) 42 | elif opt.isTrain==False: 43 | print('loading test file') 44 | self.trainfile = opt.dataset_root+'Dataset_Part1_resize/'+'part1_test.txt' 45 | with open(self.trainfile,'r') as f: 46 | for line in f.readlines(): 47 | name = line.rstrip().split('.') 48 | self.image_paths.append(os.path.join(opt.dataset_root,'Dataset_Part1_resize/',name[0])) 49 | self.trainfile = opt.dataset_root+'Dataset_Part2_resize/'+'part2_test.txt' 50 | with open(self.trainfile,'r') as f: 51 | for line in f.readlines(): 52 | name = line.rstrip().split('.') 53 | self.image_paths.append(os.path.join(opt.dataset_root,'Dataset_Part2_resize/',name[0])) 54 | transform_list = [ 55 | # transforms.RandomCrop(self.image_size), 56 | transforms.ToTensor(), 57 | transforms.Normalize((0, 0, 0), (1, 1, 1)) 58 | ] 59 | self.transforms = transforms.Compose(transform_list) 60 | def __getitem__(self, index): 61 | path = self.image_paths[index] 62 | files = os.listdir(path) 63 | 64 | files.sort(key= lambda x:int(x[:-4])) 65 | if self.isTrain: 66 | max_file = files[-1] 67 | min_file = files[0] 68 | else: 69 | max_file = files[-1] 70 | min_file = files[0] 71 | 72 | u_path = 
os.path.join(path,min_file) 73 | o_path = os.path.join(path,max_file) 74 | file_name_path = path+".JPG" 75 | name_parts=file_name_path.split('/') 76 | target_path = file_name_path.replace(name_parts[-1],'Label/'+name_parts[-1]) 77 | if not os.path.exists(target_path): 78 | target_path = target_path.replace(".JPG",".PNG") 79 | fake_u = Image.open(u_path).convert('RGB') 80 | fake_o = Image.open(o_path).convert('RGB') 81 | real = Image.open(target_path).convert('RGB') 82 | if np.random.rand() > 0.5 and self.isTrain: 83 | fake_u, fake_o, real = tf.hflip(fake_u), tf.hflip(fake_o), tf.hflip(real) 84 | fake_u = self.transforms(fake_u) 85 | fake_o = self.transforms(fake_o) 86 | real = self.transforms(real) 87 | 88 | return {'fake_u': fake_u, 'fake_o': fake_o, 'real': real,'img_path':path} 89 | 90 | def __len__(self): 91 | """Return the total number of images.""" 92 | return len(self.image_paths) -------------------------------------------------------------------------------- /applications/iih_mef/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." 
% (model_filename, target_model_name))
43 |         exit(0)
44 | 
45 |     return model
46 | 
47 | 
48 | def get_option_setter(model_name):
49 |     """Return the static method <modify_commandline_options> of the model class."""
50 |     model_class = find_model_using_name(model_name)
51 |     return model_class.modify_commandline_options
52 | 
53 | 
54 | def create_model(opt):
55 |     """Create a model given the option.
56 | 
57 |     This function wraps the class CustomDatasetDataLoader.
58 |     This is the main interface between this package and 'train.py'/'test.py'
59 | 
60 |     Example:
61 |         >>> from models import create_model
62 |         >>> model = create_model(opt)
63 |     """
64 |     model = find_model_using_name(opt.model)
65 |     instance = model(opt)
66 |     print("model [%s] was created" % type(instance).__name__)
67 |     return instance
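For reference, a minimal `models/dummy_model.py` following the naming convention described in the package docstring above might look like the sketch below; the file name, class, and field values are illustrative assumptions, not code shipped in this repository:

```python
# models/dummy_model.py -- hypothetical sketch of the '--model dummy' convention
from .base_model import BaseModel


class DummyModel(BaseModel):
    """'--model dummy' -> module models/dummy_model.py -> class DummyModel."""

    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # (optionally) add model-specific options and defaults here
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = []           # losses reported by get_current_losses
        self.visual_names = ['fake']   # images collected for display/saving
        self.model_names = []          # networks saved/loaded at checkpoints

    def set_input(self, input):
        self.fake = input['fake'].to(self.device)  # unpack one data batch

    def forward(self):
        pass  # a real model would run its networks here

    def optimize_parameters(self):
        self.forward()  # placeholder: no weights to update
```

With such a file in place, `find_model_using_name('dummy')` resolves the class through the case-insensitive name match in the loop above.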
68 | 
--------------------------------------------------------------------------------
/applications/iih_mef/models/iih_base_gd_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import itertools
4 | import torch.nn.functional as F
5 | from util import distributed as du
6 | # import pytorch_colors as colors
7 | from .base_model import BaseModel
8 | from util import util
9 | from . import harmony_networks as networks
10 | import util.ssim as ssim
11 | 
12 | 
13 | class IIHBaseGDModel(BaseModel):
14 |     @staticmethod
15 |     def modify_commandline_options(parser, is_train=True):
16 |         parser.set_defaults(norm='instance', netG='base_gd', dataset_mode='mef')
17 |         if is_train:
18 |             parser.add_argument('--lambda_L1', type=float, default=50.0, help='weight for L1 loss')
19 |             parser.add_argument('--lambda_R', type=float, default=100., help='weight for HDR reconstruction L2 loss')
20 |             parser.add_argument('--lambda_ssim', type=float, default=50., help='weight for SSIM loss')
21 |             parser.add_argument('--lambda_ifm', type=float, default=100, help='weight for fusion map (IF) loss')
22 | 
23 |         return parser
24 | 
25 |     def __init__(self, opt):
26 |         BaseModel.__init__(self, opt)
27 |         self.opt = opt
28 |         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>.
29 |         self.loss_names = ['G','G_L1','G_R','G_R_SSIM',"IF"]
30 | 
31 |         self.visual_names = ['hdr','real','fake_u','fake_o']
32 |         self.model_names = ['G']
33 |         self.opt.device = self.device
34 |         self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
35 |         self.cur_device = torch.cuda.current_device()
36 |         self.ismaster = du.is_master_proc(opt.NUM_GPUS)
37 |         print(self.netG)
38 | 
39 |         if self.isTrain:
40 |             util.saveprint(self.opt, 'netG', str(self.netG))
41 |             # define loss functions
42 |             self.criterionL1 = torch.nn.L1Loss()
43 |             self.criterionL2 = torch.nn.MSELoss()
44 |             self.criterionSSIM = ssim.SSIM()
45 |             self.criterionDSSIM_CS = ssim.DSSIM(mode='c_s').to(self.device)
46 |             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
47 |             self.optimizers.append(self.optimizer_G)
48 | 
49 |     def set_input(self, input):
50 |         self.fake_u = input['fake_u'].to(self.device)
51 |         self.fake_o = input['fake_o'].to(self.device)
52 |         self.real = input['real'].to(self.device)
53 |         self.image_paths = input['img_path']
54 |         self.real_r = F.interpolate(self.real, size=[32,32])
55 |         self.real_gray = util.rgbtogray(self.real_r)
56 |     def forward(self):
57 |         self.reconstruct_u, self.reconstruct_o, self.hdr, self.ifm_mean = self.netG(self.fake_u, self.fake_o)
58 |     def backward_G(self):
59 |         self.loss_IF = (self.criterionDSSIM_CS(self.ifm_mean, self.real_gray))*self.opt.lambda_ifm
60 | 
61 |         self.loss_G_L1 = (self.criterionL1(self.reconstruct_u, self.fake_u)+self.criterionL1(self.reconstruct_o, self.fake_o))*self.opt.lambda_L1
62 |         self.loss_G_R = self.criterionL2(self.hdr, self.real)*self.opt.lambda_R
63 |         self.loss_G_R_SSIM = (1-self.criterionSSIM(self.hdr, self.real))*self.opt.lambda_ssim
64 |         self.loss_G = self.loss_G_L1 + self.loss_G_R + self.loss_G_R_SSIM + self.loss_IF
65 |         self.loss_G.backward()
66 | 
67 |     def optimize_parameters(self):
68 |         self.forward()                   # compute fake images: G(A)
69 |         # update G
70 |         self.optimizer_G.zero_grad()     # set G's gradients to zero
71 |         self.backward_G()                # calculate gradients for G
72 |         self.optimizer_G.step()          # update G's weights
73 | 
--------------------------------------------------------------------------------
/applications/iih_mef/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes option modules: training options, test options, and basic options (used in both training and test)."""
2 | 
--------------------------------------------------------------------------------
/applications/iih_mef/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TestOptions(BaseOptions):
5 |     """This class includes test options.
6 | 
7 |     It also includes shared options defined in BaseOptions.
8 |     """
9 | 
10 |     def initialize(self, parser):
11 |         parser = BaseOptions.initialize(self, parser)  # define shared options
12 |         parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 |         parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 |         parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 |         parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 |         # Dropout and Batchnorm have different behavior during training and test.
17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') 18 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') 19 | parser.add_argument('--test_epoch', type=str, default="0", help='how many test images to run') 20 | # rewrite devalue values 21 | parser.set_defaults(model='test') 22 | # To avoid cropping, the load_size should be the same as crop_size 23 | parser.set_defaults(load_size=parser.get_default('crop_size')) 24 | self.isTrain = False 25 | return parser 26 | -------------------------------------------------------------------------------- /applications/iih_mef/options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | """This class includes training options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | # visdom and HTML visualization parameters 13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 15 | parser.add_argument('--display_id', type=int, default=0, help='window id of the web display') 16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 22 | # network saving and loading parameters 23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 24 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 25 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 28 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 29 | # training parameters 30 | parser.add_argument('--niter', type=int, default=300, help='# of iter at starting learning rate') 31 | parser.add_argument('--niter_decay', type=int, default=300, help='# of iter to linearly decay learning rate to zero') 32 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 33 | parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam') 34 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 35 | 
parser.add_argument('--g_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of generator') 36 | parser.add_argument('--d_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of discriminator') 37 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') 38 | parser.add_argument('--pool_size', type=int, default=40, help='the size of image buffer that stores previously generated images') 39 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') 40 | parser.add_argument('--lr_decay_iters', type=int, default=40, help='multiply by a gamma every lr_decay_iters iterations') 41 | parser.add_argument('--save_iter_model', action='store_true', help='whether saves model by iteration') 42 | 43 | 44 | 45 | 46 | self.isTrain = True 47 | return parser 48 | -------------------------------------------------------------------------------- /applications/iih_mef/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from util.misc import launch_job 5 | from train_net import test 6 | 7 | from options.test_options import TestOptions 8 | 9 | def main(): 10 | cfg = TestOptions().parse() # get training options 11 | cfg.NUM_GPUS = torch.cuda.device_count() 12 | cfg.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 13 | cfg.no_flip = True # no flip; comment this line if results on flipped images are needed. 14 | cfg.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 
15 | 16 | cfg.phase = 'test' 17 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS)) 18 | launch_job(cfg=cfg, init_method=cfg.init_method, func=test) 19 | 20 | 21 | if __name__=="__main__": 22 | main() -------------------------------------------------------------------------------- /applications/iih_mef/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.distributed as dist 4 | from util.misc import launch_job 5 | from train_net import train 6 | 7 | from options.train_options import TrainOptions 8 | 9 | def main(): 10 | cfg = TrainOptions().parse() # get training options 11 | cfg.NUM_GPUS = torch.cuda.device_count() 12 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS)) 13 | cfg.phase = 'train' 14 | launch_job(cfg=cfg, init_method=cfg.init_method, func=train) 15 | 16 | 17 | if __name__=="__main__": 18 | main() -------------------------------------------------------------------------------- /applications/iih_mef/util/evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as f 3 | 4 | 5 | def evaluation(name, fake, real, mask): 6 | b,c,w,h = real.size() 7 | mse_score = f.mse_loss(fake, real) 8 | fore_area = torch.sum(mask) 9 | fmse_score = f.mse_loss(fake*mask,real*mask)*w*h/fore_area 10 | mse_score = mse_score.item() 11 | fmse_score = fmse_score.item() 12 | # score_str = "%s MSE %0.2f | fMSE %0.2f" % (name, mse_score,fmse_score) 13 | image_fmse_info = (name, round(fmse_score,2), round(mse_score, 2)) 14 | return mse_score, fmse_score, image_fmse_info -------------------------------------------------------------------------------- /applications/iih_mef/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as (add a text header to the HTML file), 10 | (add a row of images to the HTML file), and (save the HTML to the disk). 11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. 
HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str)   -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 | 
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 |             with self.doc.head:
33 |                 meta(http_equiv="refresh", content=str(refresh))
34 | 
35 |     def get_image_dir(self):
36 |         """Return the directory that stores images"""
37 |         return self.img_dir
38 | 
39 |     def add_header(self, text):
40 |         """Insert a header to the HTML file
41 | 
42 |         Parameters:
43 |             text (str) -- the header text
44 |         """
45 |         with self.doc:
46 |             h3(text)
47 | 
48 |     def add_images(self, ims, txts, links, width=400):
49 |         """add images to the HTML file
50 | 
51 |         Parameters:
52 |             ims (str list)   -- a list of image paths
53 |             txts (str list)  -- a list of image names shown on the website
54 |             links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 |         """
56 |         self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
57 |         self.doc.add(self.t)
58 |         with self.t:
59 |             with tr():
60 |                 for im, txt, link in zip(ims, txts, links):
61 |                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 |                         with p():
63 |                             with a(href=os.path.join('images', link)):
64 |                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
65 |                             br()
66 |                             p(txt)
67 | 
68 |     def save(self):
69 |         """save the current content to the HTML file"""
70 |         html_file = '%s/index.html' % self.web_dir
71 |         f = open(html_file, 'wt')
72 |         f.write(self.doc.render())
73 |         f.close()
74 | 
75 | 
76 | if __name__ == '__main__':  # we show an example usage here.
77 |     html = HTML('web/', 'test_html')
78 |     html.add_header('hello world')
79 | 
80 |     ims, txts, links = [], [], []
81 |     for n in range(4):
82 |         ims.append('image_%d.png' % n)
83 |         txts.append('text_%d' % n)
84 |         links.append('image_%d.png' % n)
85 |     html.add_images(ims, txts, links)
86 |     html.save()
87 | 
--------------------------------------------------------------------------------
/applications/iih_mef/util/multiprocessing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 | 
4 | """Multiprocessing helpers."""
5 | 
6 | import torch
7 | 
8 | 
9 | def run(
10 |     local_rank,
11 |     num_proc,
12 |     func,
13 |     init_method,
14 |     shard_id,
15 |     num_shards,
16 |     backend,
17 |     cfg,
18 |     output_queue=None,
19 | ):
20 |     """
21 |     Runs a function from a child process.
22 |     Args:
23 |         local_rank (int): rank of the current process on the current machine.
24 |         num_proc (int): number of processes per machine.
25 |         func (function): function to execute on each of the processes.
26 |         init_method (string): method to initialize the distributed training.
27 |             TCP initialization: requiring a network address reachable from all
28 |             processes followed by the port.
29 |             Shared file-system initialization: makes use of a file system that
30 |             is shared and visible from all machines. The URL should start with
31 |             file:// and contain a path to a non-existent file on a shared file
32 |             system.
33 |         shard_id (int): the rank of the current machine.
34 |         num_shards (int): number of overall machines for the distributed
35 |             training job.
36 |         backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
37 |             supported, each with different capabilities. Details can be found
38 |             here:
39 |             https://pytorch.org/docs/stable/distributed.html
40 |         cfg (CfgNode): configs. Details can be found in
41 |             slowfast/config/defaults.py
42 |         output_queue (queue): can optionally be used to return values from the
43 |             master process.
44 | """ 45 | # Initialize the process group. 46 | world_size = num_proc * num_shards 47 | rank = shard_id * num_proc + local_rank 48 | try: 49 | torch.distributed.init_process_group( 50 | backend=backend, 51 | init_method=init_method, 52 | world_size=world_size, 53 | rank=rank, 54 | ) 55 | 56 | except Exception as e: 57 | raise e 58 | 59 | torch.cuda.set_device(local_rank) 60 | ret = func(cfg) 61 | if output_queue is not None and local_rank == 0: 62 | output_queue.put(ret) 63 | -------------------------------------------------------------------------------- /applications/iih_mef/util/ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | def _ssim_c_s(img1, img2, window, window_size, channel, size_average = True): 40 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 41 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 42 | 43 | mu1_sq = mu1.pow(2) 44 | mu2_sq = mu2.pow(2) 45 | mu1_mu2 = mu1*mu2 46 | 47 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 48 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 49 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 50 | 51 | C1 = 0.01**2 52 | C2 = 0.03**2 53 | 54 | ssim_map = ((2*sigma12 + C2))/((sigma1_sq + sigma2_sq + C2)) 55 | 56 | if size_average: 57 | return ssim_map.mean() 58 | else: 59 | return ssim_map.mean(1).mean(1).mean(1) 60 | def _ssim_l(img1, img2, window, window_size, channel, size_average = True): 61 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 62 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 63 | 64 | mu1_sq = mu1.pow(2) 65 | mu2_sq = mu2.pow(2) 66 | mu1_mu2 = mu1*mu2 67 | 68 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 69 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 70 | sigma12 = 
F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 71 | 72 | C1 = 0.01**2 73 | C2 = 0.03**2 74 | 75 | ssim_map = (2*mu1_mu2 + C1)/(mu1_sq + mu2_sq + C1) 76 | 77 | if size_average: 78 | return ssim_map.mean() 79 | else: 80 | return ssim_map.mean(1).mean(1).mean(1) 81 | 82 | 83 | def ssim(img1, img2, window_size = 11, size_average = True): 84 | (_, channel, _, _) = img1.size() 85 | window = create_window(window_size, channel) 86 | 87 | if img1.is_cuda: 88 | window = window.cuda(img1.get_device()) 89 | window = window.type_as(img1) 90 | 91 | return _ssim(img1, img2, window, window_size, channel, size_average) 92 | 93 | class SSIM(torch.nn.Module): 94 | def __init__(self, window_size = 11, size_average = True, mode='all'): 95 | super(SSIM, self).__init__() 96 | self.window_size = window_size 97 | self.size_average = size_average 98 | self.channel = 1 99 | self.window = create_window(window_size, self.channel) 100 | self.mode = mode 101 | def forward(self, img1, img2): 102 | (_, channel, _, _) = img1.size() 103 | 104 | if channel == self.channel and self.window.data.type() == img1.data.type(): 105 | window = self.window 106 | else: 107 | window = create_window(self.window_size, channel) 108 | 109 | # if img1.is_cuda: 110 | # window = window.cuda(img1.get_device()) 111 | window = window.type_as(img1) 112 | 113 | self.window = window 114 | self.channel = channel 115 | if self.mode == 'all': 116 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 117 | elif self.mode == 'c_s': 118 | return _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average) 119 | else: 120 | return _ssim_l(img1, img2, window, self.window_size, channel, self.size_average) 121 | 122 | class DSSIM(torch.nn.Module): 123 | def __init__(self, window_size = 11, size_average = True, mode='all'): 124 | super(DSSIM, self).__init__() 125 | self.window_size = window_size 126 | self.size_average = size_average 127 | self.channel = 1 128 | self.window = create_window(window_size, self.channel) 129 | self.mode = mode 130 | def forward(self, img1, img2): 131 | (_, channel, _, _) = img1.size() 132 | 133 | if channel == self.channel and self.window.data.type() == img1.data.type(): 134 | window = self.window 135 | else: 136 | window = create_window(self.window_size, channel) 137 | 138 | # if img1.is_cuda: 139 | # window = window.cuda(img1.get_device()) 140 | window = window.type_as(img1) 141 | 142 | self.window = window 143 | self.channel = channel 144 | if self.mode == 'all': 145 | ssim_v = _ssim(img1, img2, window, self.window_size, channel, self.size_average) 146 | elif self.mode == 'c_s': 147 | ssim_v = _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average) 148 | else: 149 | ssim_v = _ssim_l(img1, img2, window, self.window_size, channel, self.size_average) 150 | return (1-ssim_v)/2 -------------------------------------------------------------------------------- /applications/iih_relighting/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/* 2 | checkpoints/* 3 | results/* 4 | */__pycache__/* 5 | tmp/* 6 | __pycache__/distribute.cpython-37.pyc 7 | __pycache__/options.cpython-37.pyc 8 | __pycache__/train_net.cpython-37.pyc 9 | __pycache__/train.cpython-37.pyc -------------------------------------------------------------------------------- /applications/iih_relighting/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | # Portrait 
Relighting
5 | 
6 | Here we provide the PyTorch implementation and the pre-trained model of our latest version.
7 | 
8 | ## Prerequisites
9 | 
10 | - Linux
11 | - Python 3
12 | - CPU or NVIDIA GPU + CUDA CuDNN
13 | 
14 | ## Base Model with Lighting
15 | - Download the DPR dataset.
16 | 
17 | - Train
18 | ```bash
19 | CUDA_VISIBLE_DEVICES=0 python train.py --model iih_base_lt --name base_lt_relighting_test --dataset_root <dataset_dir> --dataset_name DPR --batch_size xx --init_port xxxx
20 | ```
21 | - Test
22 | ```bash
23 | # SH-based relighting
24 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_lt --name base_lt_relighting_test --dataset_root <dataset_dir> --dataset_name DPR --batch_size xx --init_port xxxx
25 | # Image-based relighting
26 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_lt --name base_lt_relighting_test --relighting_action transfer --dataset_root <dataset_dir> --dataset_name DPR --dataset_mode dprtransfer --batch_size xx --init_port xxxx
27 | ```
28 | 
29 | - Apply pre-trained model
30 | 
31 | Download the pre-trained model from [Google Drive](https://drive.google.com/file/d/11yGZvo-gLDRyfnO0A6xuqPmDaPcMB1en/view?usp=sharing) or [BaiduCloud](https://pan.baidu.com/s/1yrUZ2YkT2bY9ThfYn_gJAg) (access code: bjqb), put `latest_net_G.pth` in the directory `checkpoints/base_lt_relighting`, and run:
32 | 
33 | ```bash
34 | # SH-based relighting
35 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_lt --name base_lt_relighting --dataset_root <dataset_dir> --dataset_name DPR --batch_size xx --init_port xxxx
36 | # Image-based relighting
37 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_lt --name base_lt_relighting --relighting_action transfer --dataset_root <dataset_dir> --dataset_name DPR --dataset_mode dprtransfer --batch_size xx --init_port xxxx
38 | ```
39 | 
--------------------------------------------------------------------------------
/applications/iih_relighting/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 | 
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 | 
12 | 
13 | class BaseDataset(data.Dataset, ABC):
14 |     """This class is an abstract base class (ABC) for datasets.
15 | 
16 |     To create a subclass, you need to implement the following four functions:
17 |     -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 |     -- <__len__>: return the size of dataset.
19 |     -- <__getitem__>: get a data point.
20 |     -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 |     """
22 | 
23 |     def __init__(self, opt):
24 |         """Initialize the class; save the options in the class
25 | 
26 |         Parameters:
27 |             opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 |         """
29 |         self.opt = opt
30 |         # self.root = opt.dataroot
31 |         self.root = opt.dataset_root
32 | 
33 |     @staticmethod
34 |     def modify_commandline_options(parser, is_train):
35 |         """Add new dataset-specific options, and rewrite default values for existing options.
36 | 
37 |         Parameters:
38 |             parser          -- original option parser
39 |             is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40 | 
41 |         Returns:
42 |             the modified parser.
43 |         """
44 |         return parser
45 | 
46 |     @abstractmethod
47 |     def __len__(self):
48 |         """Return the total number of images in the dataset."""
49 |         return 0
50 | 
51 |     @abstractmethod
52 |     def __getitem__(self, index):
53 |         """Return a data point and its metadata information.
54 | 
55 |         Parameters:
56 |             index - - a random integer for data indexing
57 | 
58 |         Returns:
59 |             a dictionary of data with their names. It usually contains the data itself and its metadata information.
60 |         """
61 |         pass
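# --- Illustrative sketch (hypothetical names, not code from this file): the
# --- helpers get_params/get_transform defined below are meant to be paired so
# --- that several aligned images receive the SAME random crop and flip, e.g.:
#
#     class PairedSketchDataset(BaseDataset):
#         def __getitem__(self, index):
#             fake = Image.open(self.fake_paths[index]).convert('RGB')
#             real = Image.open(self.real_paths[index]).convert('RGB')
#             params = get_params(self.opt, fake.size)     # draw crop pos + flip once
#             transform = get_transform(self.opt, params)  # reuse => identical crop/flip
#             return {'fake': transform(fake), 'real': transform(real)}
#
#         def __len__(self):
#             return len(self.fake_paths)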
62 | 
63 | 
64 | def get_params(opt, size):
65 |     w, h = size
66 |     new_h = h
67 |     new_w = w
68 |     if opt.preprocess == 'resize_and_crop':
69 |         new_h = new_w = opt.load_size
70 |     elif opt.preprocess == 'scale_width_and_crop':
71 |         new_w = opt.load_size
72 |         new_h = opt.load_size * h // w
73 | 
74 |     x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
75 |     y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
76 | 
77 |     flip = random.random() > 0.5
78 | 
79 |     return {'crop_pos': (x, y), 'flip': flip}
80 | 
81 | 
82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
83 |     transform_list = []
84 |     if grayscale:
85 |         transform_list.append(transforms.Grayscale(1))
86 |     if 'resize' in opt.preprocess:
87 |         osize = [opt.load_size, opt.load_size]
88 |         transform_list.append(transforms.Resize(osize, method))
89 |     elif 'scale_width' in opt.preprocess:
90 |         transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
91 | 
92 |     if 'crop' in opt.preprocess:
93 |         if params is None:
94 |             transform_list.append(transforms.RandomCrop(opt.crop_size))
95 |         else:
96 |             transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
97 | 
98 |     if opt.preprocess == 'none':
99 |         transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
100 | 
101 |     if not opt.no_flip:
102 |         if params is None:
103 |             transform_list.append(transforms.RandomHorizontalFlip())
104 |         elif params['flip']:
105 |             transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
106 | 
107 |     if convert:
108 |         transform_list += [transforms.ToTensor()]
109 |         if grayscale:
110 |             transform_list += [transforms.Normalize((0.5,), (0.5,))]
111 |         else:
112 |             # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
113 |             transform_list += [transforms.Normalize((0, 0, 0), (1, 1, 1))]
114 |     return transforms.Compose(transform_list)
115 | 
116 | 
117 | def __make_power_2(img, base, method=Image.BICUBIC):
118 |     ow, oh = img.size
119 |     h = int(round(oh / base) * base)
120 |     w = int(round(ow / base) * base)
121 |     if (h == oh) and (w == ow):
122 |         return img
123 | 
124 |     __print_size_warning(ow, oh, w, h)
125 |     return img.resize((w, h), method)
126 | 
127 | 
128 | def __scale_width(img, target_width, method=Image.BICUBIC):
129 |     ow, oh = img.size
130 |     if (ow == target_width):
131 |         return img
132 |     w = target_width
133 |     h = int(target_width * oh / ow)
134 |     return img.resize((w, h), method)
135 | 
136 | 
137 | def __crop(img, pos, size):
138 |     ow, oh = img.size
139 |     x1, y1 = pos
140 |     tw = th = size
141 |     if (ow > tw or oh > th):
142 |         return img.crop((x1, y1, x1 + tw, y1 + th))
143 |     return img
144 | 
145 | 
146 | def __flip(img, flip):
147 |     if flip:
148 |         return img.transpose(Image.FLIP_LEFT_RIGHT)
149 |     return img
150 | 
151 | 
152 | def __print_size_warning(ow, oh, w, h):
153 |     """Print warning information about image size(only print once)"""
154 |     if not hasattr(__print_size_warning,
'has_printed'): 155 | print("The image size needs to be a multiple of 4. " 156 | "The loaded image size was (%d, %d), so it was adjusted to " 157 | "(%d, %d). This adjustment will be done to all images " 158 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 159 | __print_size_warning.has_printed = True 160 | -------------------------------------------------------------------------------- /applications/iih_relighting/data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 5 | """ 6 | 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_dataset(dir, max_dataset_size=float("inf")): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | 27 | for root, _, fnames in sorted(os.walk(dir)): 28 | for fname in fnames: 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | images.append(path) 32 | return images[:min(max_dataset_size, len(images))] 33 | 34 | 35 | def default_loader(path): 36 | return Image.open(path).convert('RGB') 37 | 38 | 39 | class ImageFolder(data.Dataset): 40 | 41 | def __init__(self, root, transform=None, return_paths=False, 42 | loader=default_loader): 43 | imgs = make_dataset(root) 44 | if len(imgs) == 0: 45 | raise(RuntimeError("Found 0 images in: " + root + "\n" 46 | "Supported image extensions are: " + 47 | ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /applications/iih_relighting/data/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 3 | 4 | """Data loader.""" 5 | 6 | import itertools 7 | import numpy as np 8 | import torch 9 | from torch.utils.data._utils.collate import default_collate 10 | from torch.utils.data.distributed import DistributedSampler 11 | from torch.utils.data.sampler import RandomSampler 12 | 13 | from slowfast.datasets.multigrid_helper import ShortCycleBatchSampler 14 | 15 | from . 
import utils as utils 16 | 17 | def build_dataset(cfg): 18 | image_paths = [] 19 | if cfg.phase == 'train': 20 | print('loading training file') 21 | file = cfg.dataset_root+cfg.dataset_name+'_train.txt' 22 | with open(file,'r') as f: 23 | for line in f.readlines(): 24 | image_paths.append(os.path.join(cfg.dataset_root,'composite_images',line.rstrip())) 25 | 26 | 27 | def construct_loader(cfg, split, is_precise_bn=False): 28 | """ 29 | Constructs the data loader for the given dataset. 30 | Args: 31 | cfg (CfgNode): configs. Details can be found in 32 | slowfast/config/defaults.py 33 | split (str): the split of the data loader. Options include `train`, 34 | `val`, and `test`. 35 | """ 36 | assert split in ["train", "val", "test"] 37 | if split in ["train"]: 38 | dataset_name = cfg.TRAIN.DATASET 39 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 40 | shuffle = True 41 | drop_last = True 42 | elif split in ["val"]: 43 | dataset_name = cfg.TRAIN.DATASET 44 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 45 | shuffle = False 46 | drop_last = False 47 | elif split in ["test"]: 48 | dataset_name = cfg.TEST.DATASET 49 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS)) 50 | shuffle = False 51 | drop_last = False 52 | 53 | # Construct the dataset 54 | dataset = build_dataset(dataset_name, cfg, split) 55 | 56 | if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn: 57 | # Create a sampler for multi-process training 58 | sampler = utils.create_sampler(dataset, shuffle, cfg) 59 | batch_sampler = ShortCycleBatchSampler( 60 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg 61 | ) 62 | # Create a loader 63 | loader = torch.utils.data.DataLoader( 64 | dataset, 65 | batch_sampler=batch_sampler, 66 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 67 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 68 | worker_init_fn=utils.loader_worker_init_fn(dataset), 69 | ) 70 | else: 71 | # Create a sampler for multi-process training 72 | sampler = utils.create_sampler(dataset, shuffle, cfg) 73 | # Create a loader 74 | loader = torch.utils.data.DataLoader( 75 | dataset, 76 | batch_size=batch_size, 77 | shuffle=(False if sampler else shuffle), 78 | sampler=sampler, 79 | num_workers=cfg.DATA_LOADER.NUM_WORKERS, 80 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY, 81 | drop_last=drop_last, 82 | collate_fn=detection_collate if cfg.DETECTION.ENABLE else None, 83 | worker_init_fn=utils.loader_worker_init_fn(dataset), 84 | ) 85 | return loader 86 | 87 | 88 | def shuffle_dataset(loader, cur_epoch): 89 | """ " 90 | Shuffles the data. 91 | Args: 92 | loader (loader): data loader to perform shuffle. 93 | cur_epoch (int): number of the current epoch. 94 | """ 95 | sampler = ( 96 | loader.batch_sampler.sampler 97 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler) 98 | else loader.sampler 99 | ) 100 | assert isinstance( 101 | sampler, (RandomSampler, DistributedSampler) 102 | ), "Sampler type '{}' not supported".format(type(sampler)) 103 | # RandomSampler handles shuffling automatically 104 | if isinstance(sampler, DistributedSampler): 105 | # DistributedSampler shuffles data based on epoch 106 | sampler.set_epoch(cur_epoch) 107 | -------------------------------------------------------------------------------- /applications/iih_relighting/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 
2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function warps the class CustomDatasetDataLoader. 58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /applications/iih_relighting/models/iih_base_lt_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import itertools 3 | import torch.nn.functional as F 4 | from util import distributed as du 5 | from .base_model import BaseModel 6 | from util import util 7 | from . import relighting_networks as networks 8 | from . 
import networks as network_init
9 | import util.ssim as ssim
10 | 
11 | 
12 | class IIHBaseLTModel(BaseModel):
13 | 
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train=True):
16 | 
17 | parser.set_defaults(norm='instance', netG='base_lt', dataset_mode='dpr')
18 | parser.add_argument('--action', type=str, default='relighting', help='task to run: relighting or light transfer')
19 | if is_train:
20 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
21 | parser.add_argument('--lambda_R_gradient', type=float, default=10., help='weight for reflectance gradient loss')
22 | parser.add_argument('--lambda_ssim', type=float, default=50., help='weight for SSIM loss')
23 | parser.add_argument('--lambda_I_smooth', type=float, default=1., help='weight for illumination smoothness loss')
24 | parser.add_argument('--lambda_I_L2', type=float, default=10., help='weight for illumination L2 loss')
25 | parser.add_argument('--lambda_ifm', type=float, default=100, help='weight for IFM loss')
26 | parser.add_argument('--lambda_L', type=float, default=100.0, help='weight for light regression loss')
27 | 
28 | return parser
29 | 
30 | def __init__(self, opt):
31 | """Initialize the IIH base light-transfer model.
32 | 
33 | Parameters:
34 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
35 | """
36 | BaseModel.__init__(self, opt)
37 | self.opt = opt
38 | # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
39 | self.loss_names = ['G','G_L1','G_R','G_I_L2','G_I_smooth',"G_L"]
40 | 
41 | # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
42 | self.visual_names = ['harmonized','real','fake','reflectance','illumination']
43 | # specify the models you want to save to the disk. The training/test scripts will call <save_networks> and <load_networks>
44 | self.model_names = ['G']
45 | self.opt.device = self.device
46 | self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
47 | self.cur_device = torch.cuda.current_device()
48 | self.ismaster = du.is_master_proc(opt.NUM_GPUS)
49 | if self.ismaster:
50 | print(self.netG)
51 | 
52 | if self.isTrain:
53 | util.saveprint(self.opt, 'netG', str(self.netG))
54 | # define loss functions
55 | self.criterionL1 = torch.nn.L1Loss()
56 | self.criterionL2 = torch.nn.MSELoss()
57 | self.criterionSSIM = ssim.SSIM()
58 | # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
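# (For reference, scheduler creation in <BaseModel.setup> in pix2pix-style codebases,
#  which this repo follows, looks roughly like the sketch below; the helper name is an
#  assumption here -- check models/networks.py for the one this repo actually ships:
#      self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers])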
59 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
60 | self.optimizers.append(self.optimizer_G)
61 | 
62 | def set_input(self, input):
63 | 
64 | self.fake = input['fake'].to(self.device)
65 | self.real = input['real'].to(self.device)
66 | if input['target'] is not None:
67 | self.target = input['target'].to(self.device)
68 | if input['fake_light'] is not None:
69 | self.light_fake = input['fake_light'].to(self.device)
70 | if input['real_light'] is not None:
71 | self.light_real = input['real_light'].to(self.device)
72 | self.image_paths = input['img_path']
73 | 
74 | def forward(self):
75 | """Run forward pass; called by both functions <optimize_parameters> and <test>."""
76 | if self.isTrain:
77 | self.harmonized, self.reflectance, self.illumination, self.light_gen_fake = self.netG(self.fake, self.light_real)
78 | else:
79 | if self.opt.action == "relighting":
80 | self.harmonized, self.reflectance, self.illumination, self.light_gen_fake = self.netG(self.fake, isTest=True, light=self.light_real)
81 | else:
82 | self.harmonized, self.reflectance, self.illumination, self.light_gen_fake = self.netG(self.fake, isTest=True, target=self.target)
83 | def backward_G(self):
84 | """Calculate the reconstruction, decomposition, and light losses for the generator"""
85 | self.loss_G_L1 = self.criterionL1(self.harmonized, self.real)*self.opt.lambda_L1
86 | self.loss_G_R = (self.gradient_loss(self.reflectance, self.fake)+self.gradient_loss(self.reflectance, self.real))*self.opt.lambda_R_gradient
87 | self.loss_G_I_smooth = util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
88 | self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
89 | self.loss_G_L = self.criterionL2(self.light_gen_fake, self.light_fake)*self.opt.lambda_L
90 | 
91 | self.loss_G = self.loss_G_L1 + self.loss_G_R + self.loss_G_I_smooth + self.loss_G_I_L2 + self.loss_G_L
92 | self.loss_G.backward()
93 | 
94 | def optimize_parameters(self):
95 | self.forward() # compute the network outputs
96 | # update G
97 | self.optimizer_G.zero_grad() # set G's gradients to zero
98 | self.backward_G() # calculate gradients for G
99 | self.optimizer_G.step() # update G's weights
100 | 
101 | def gradient_loss(self, input_1, input_2):
102 | g_x = self.criterionL1(util.gradient(input_1, 'x'), util.gradient(input_2, 'x'))
103 | g_y = self.criterionL1(util.gradient(input_1, 'y'), util.gradient(input_2, 'y'))
104 | return g_x+g_y
105 | 
106 | def __compute_kl(self, mu):
107 | mu_2 = torch.pow(mu, 2)
108 | encoding_loss = torch.mean(mu_2)
109 | return encoding_loss
110 | 
111 | 
-------------------------------------------------------------------------------- /applications/iih_relighting/options/__init__.py: --------------------------------------------------------------------------------
1 | """This package includes option modules: training options, test options, and basic options (used in both training and test)."""
2 | 
-------------------------------------------------------------------------------- /applications/iih_relighting/options/test_options.py: --------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TestOptions(BaseOptions):
5 | """This class includes test options.
6 | 
7 | It also includes shared options defined in BaseOptions.
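
    A hypothetical test invocation (the flag values below are placeholders,
    not the authors' published settings):
        python test.py --model iih_base_lt --dataset_mode dpr --eval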
8 | """
9 | 
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 | # Dropout and BatchNorm behave differently during training and test.
17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 | parser.add_argument('--test_epoch', type=str, default="0", help='which saved epoch to load for testing')
20 | # rewrite default values
21 | parser.set_defaults(model='test')
22 | # To avoid cropping, the load_size should be the same as crop_size
23 | parser.set_defaults(load_size=parser.get_default('crop_size'))
24 | self.isTrain = False
25 | return parser
26 | 
-------------------------------------------------------------------------------- /applications/iih_relighting/options/train_options.py: --------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TrainOptions(BaseOptions):
5 | """This class includes training options.
6 | 
7 | It also includes shared options defined in BaseOptions.
8 | """
9 | 
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser)
12 | # visdom and HTML visualization parameters
13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
15 | parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
22 | # network saving and loading parameters
23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
24 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
25 | parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
28 | parser.add_argument('--phase', type=str, 
default='train', help='train, val, test, etc')
29 | # training parameters
30 | parser.add_argument('--niter', type=int, default=5, help='# of iter at starting learning rate')
31 | parser.add_argument('--niter_decay', type=int, default=5, help='# of iter to linearly decay learning rate to zero')
32 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
33 | parser.add_argument('--beta2', type=float, default=0.999, help='second momentum term of adam')
34 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
35 | parser.add_argument('--g_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of generator')
36 | parser.add_argument('--d_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of discriminator')
37 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
38 | parser.add_argument('--pool_size', type=int, default=40, help='the size of image buffer that stores previously generated images')
39 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
40 | parser.add_argument('--lr_decay_iters', type=int, default=40, help='multiply by a gamma every lr_decay_iters iterations')
41 | parser.add_argument('--save_iter_model', action='store_true', help='whether to save the model by iteration')
42 | 
43 | 
44 | 
45 | 
46 | self.isTrain = True
47 | return parser
48 | 
-------------------------------------------------------------------------------- /applications/iih_relighting/test.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import test
6 | 
7 | from options.test_options import TestOptions
8 | 
9 | def main():
10 | cfg = TestOptions().parse() # get test options
11 | cfg.NUM_GPUS = torch.cuda.device_count()
12 | cfg.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
13 | cfg.no_flip = True # no flip; comment this line if results on flipped images are needed.
14 | cfg.display_id = -1 # no visdom display; the test code saves the results to an HTML file.
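# A hypothetical command line for this entry point (the paths and the epoch
# label are placeholders, not values shipped with the repo):
#   python test.py --model iih_base_lt --dataset_mode dpr --dataset_root ./datasets/DPR/ --test_epoch 30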
15 | 
16 | cfg.phase = 'test'
17 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
18 | launch_job(cfg=cfg, init_method=cfg.init_method, func=test)
19 | 
20 | 
21 | if __name__=="__main__":
22 | main()
-------------------------------------------------------------------------------- /applications/iih_relighting/train.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import train
6 | 
7 | from options.train_options import TrainOptions
8 | 
9 | def main():
10 | cfg = TrainOptions().parse() # get training options
11 | cfg.NUM_GPUS = torch.cuda.device_count()
12 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
13 | cfg.phase = 'train'
14 | launch_job(cfg=cfg, init_method=cfg.init_method, func=train)
15 | 
16 | 
17 | if __name__=="__main__":
18 | main()
-------------------------------------------------------------------------------- /applications/iih_relighting/util/evaluation.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as f
3 | 
4 | 
5 | def evaluation(name, fake, real, mask):
6 | b,c,h,w = real.size()  # tensors are (batch, channel, height, width)
7 | mse_score = f.mse_loss(fake, real)
8 | fore_area = torch.sum(mask)
9 | fmse_score = f.mse_loss(fake*mask,real*mask)*w*h/fore_area  # fMSE: MSE rescaled to count only foreground pixels
10 | mse_score = mse_score.item()
11 | fmse_score = fmse_score.item()
12 | # score_str = "%s MSE %0.2f | fMSE %0.2f" % (name, mse_score,fmse_score)
13 | image_fmse_info = (name, round(fmse_score,2), round(mse_score, 2))
14 | return mse_score, fmse_score, image_fmse_info
-------------------------------------------------------------------------------- /applications/iih_relighting/util/html.py: --------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 | 
5 | 
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 | 
9 | It consists of functions such as <add_header> (add a text header to the HTML file),
10 | <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 | It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 | 
14 | def __init__(self, web_dir, title, refresh=0):
15 | """Initialize the HTML class
16 | 
17 | Parameters:
18 | web_dir (str) -- a directory that stores the webpage. 
HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 | title (str) -- the webpage name
20 | refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 | """
22 | self.title = title
23 | self.web_dir = web_dir
24 | self.img_dir = os.path.join(self.web_dir, 'images')
25 | if not os.path.exists(self.web_dir):
26 | os.makedirs(self.web_dir)
27 | if not os.path.exists(self.img_dir):
28 | os.makedirs(self.img_dir)
29 | 
30 | self.doc = dominate.document(title=title)
31 | if refresh > 0:
32 | with self.doc.head:
33 | meta(http_equiv="refresh", content=str(refresh))
34 | 
35 | def get_image_dir(self):
36 | """Return the directory that stores images"""
37 | return self.img_dir
38 | 
39 | def add_header(self, text):
40 | """Insert a header to the HTML file
41 | 
42 | Parameters:
43 | text (str) -- the header text
44 | """
45 | with self.doc:
46 | h3(text)
47 | 
48 | def add_images(self, ims, txts, links, width=400):
49 | """add images to the HTML file
50 | 
51 | Parameters:
52 | ims (str list) -- a list of image paths
53 | txts (str list) -- a list of image names shown on the website
54 | links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page
55 | """
56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57 | self.doc.add(self.t)
58 | with self.t:
59 | with tr():
60 | for im, txt, link in zip(ims, txts, links):
61 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 | with p():
63 | with a(href=os.path.join('images', link)):
64 | img(style="width:%dpx" % width, src=os.path.join('images', im))
65 | br()
66 | p(txt)
67 | 
68 | def save(self):
69 | """save the current content to the HTML file"""
70 | html_file = '%s/index.html' % self.web_dir
71 | f = open(html_file, 'wt')
72 | f.write(self.doc.render())
73 | f.close()
74 | 
75 | 
76 | if __name__ == '__main__': # we show an example usage here.
77 | html = HTML('web/', 'test_html')
78 | html.add_header('hello world')
79 | 
80 | ims, txts, links = [], [], []
81 | for n in range(4):
82 | ims.append('image_%d.png' % n)
83 | txts.append('text_%d' % n)
84 | links.append('image_%d.png' % n)
85 | html.add_images(ims, txts, links)
86 | html.save()
87 | 
-------------------------------------------------------------------------------- /applications/iih_relighting/util/multiprocessing.py: --------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 | 
4 | """Multiprocessing helpers."""
5 | 
6 | import torch
7 | 
8 | 
9 | def run(
10 | local_rank,
11 | num_proc,
12 | func,
13 | init_method,
14 | shard_id,
15 | num_shards,
16 | backend,
17 | cfg,
18 | output_queue=None,
19 | ):
20 | """
21 | Runs a function from a child process.
22 | Args:
23 | local_rank (int): rank of the current process on the current machine.
24 | num_proc (int): number of processes per machine.
25 | func (function): function to execute on each of the processes.
26 | init_method (string): method to initialize the distributed training.
27 | TCP initialization: requiring a network address reachable from all
28 | processes followed by the port.
29 | Shared file-system initialization: makes use of a file system that
30 | is shared and visible from all machines. The URL should start with
31 | file:// and contain a path to a non-existent file on a shared file
32 | system.
33 | shard_id (int): the rank of the current machine.
34 | num_shards (int): number of overall machines for the distributed
35 | training job.
36 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
37 | supported, each with different capabilities. Details can be found
38 | here:
39 | https://pytorch.org/docs/stable/distributed.html
40 | cfg (CfgNode): configs. Details can be found in
41 | slowfast/config/defaults.py
42 | output_queue (queue): can optionally be used to return values from the
43 | master process.
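Example (a hedged sketch of how a launcher could spawn this function; in this
repo the actual launch path goes through util.misc.launch_job):
    torch.multiprocessing.spawn(
        run,
        nprocs=num_proc,
        args=(num_proc, func, init_method, shard_id, num_shards, backend, cfg),
    )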
44 | """ 45 | # Initialize the process group. 46 | world_size = num_proc * num_shards 47 | rank = shard_id * num_proc + local_rank 48 | try: 49 | torch.distributed.init_process_group( 50 | backend=backend, 51 | init_method=init_method, 52 | world_size=world_size, 53 | rank=rank, 54 | ) 55 | 56 | except Exception as e: 57 | raise e 58 | 59 | torch.cuda.set_device(local_rank) 60 | ret = func(cfg) 61 | if output_queue is not None and local_rank == 0: 62 | output_queue.put(ret) 63 | -------------------------------------------------------------------------------- /applications/iih_relighting/util/ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | def _ssim_c_s(img1, img2, window, window_size, channel, size_average = True): 40 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 41 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 42 | 43 | mu1_sq = mu1.pow(2) 44 | mu2_sq = mu2.pow(2) 45 | mu1_mu2 = mu1*mu2 46 | 47 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 48 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 49 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 50 | 51 | C1 = 0.01**2 52 | C2 = 0.03**2 53 | 54 | ssim_map = ((2*sigma12 + C2))/((sigma1_sq + sigma2_sq + C2)) 55 | 56 | if size_average: 57 | return ssim_map.mean() 58 | else: 59 | return ssim_map.mean(1).mean(1).mean(1) 60 | def _ssim_l(img1, img2, window, window_size, channel, size_average = True): 61 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 62 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 63 | 64 | mu1_sq = mu1.pow(2) 65 | mu2_sq = mu2.pow(2) 66 | mu1_mu2 = mu1*mu2 67 | 68 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 69 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 70 | 
sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 71 | 72 | C1 = 0.01**2 73 | C2 = 0.03**2 74 | 75 | ssim_map = (2*mu1_mu2 + C1)/(mu1_sq + mu2_sq + C1) 76 | 77 | if size_average: 78 | return ssim_map.mean() 79 | else: 80 | return ssim_map.mean(1).mean(1).mean(1) 81 | 82 | 83 | def ssim(img1, img2, window_size = 11, size_average = True): 84 | (_, channel, _, _) = img1.size() 85 | window = create_window(window_size, channel) 86 | 87 | if img1.is_cuda: 88 | window = window.cuda(img1.get_device()) 89 | window = window.type_as(img1) 90 | 91 | return _ssim(img1, img2, window, window_size, channel, size_average) 92 | 93 | class SSIM(torch.nn.Module): 94 | def __init__(self, window_size = 11, size_average = True, mode='all'): 95 | super(SSIM, self).__init__() 96 | self.window_size = window_size 97 | self.size_average = size_average 98 | self.channel = 1 99 | self.window = create_window(window_size, self.channel) 100 | self.mode = mode 101 | def forward(self, img1, img2): 102 | (_, channel, _, _) = img1.size() 103 | 104 | if channel == self.channel and self.window.data.type() == img1.data.type(): 105 | window = self.window 106 | else: 107 | window = create_window(self.window_size, channel) 108 | 109 | # if img1.is_cuda: 110 | # window = window.cuda(img1.get_device()) 111 | window = window.type_as(img1) 112 | 113 | self.window = window 114 | self.channel = channel 115 | if self.mode == 'all': 116 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 117 | elif self.mode == 'c_s': 118 | return _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average) 119 | else: 120 | return _ssim_l(img1, img2, window, self.window_size, channel, self.size_average) 121 | 122 | class DSSIM(torch.nn.Module): 123 | def __init__(self, window_size = 11, size_average = True, mode='all'): 124 | super(DSSIM, self).__init__() 125 | self.window_size = window_size 126 | self.size_average = size_average 127 | self.channel = 1 128 | self.window = create_window(window_size, self.channel) 129 | self.mode = mode 130 | def forward(self, img1, img2): 131 | (_, channel, _, _) = img1.size() 132 | 133 | if channel == self.channel and self.window.data.type() == img1.data.type(): 134 | window = self.window 135 | else: 136 | window = create_window(self.window_size, channel) 137 | 138 | # if img1.is_cuda: 139 | # window = window.cuda(img1.get_device()) 140 | window = window.type_as(img1) 141 | 142 | self.window = window 143 | self.channel = channel 144 | if self.mode == 'all': 145 | ssim_v = _ssim(img1, img2, window, self.window_size, channel, self.size_average) 146 | elif self.mode == 'c_s': 147 | ssim_v = _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average) 148 | else: 149 | ssim_v = _ssim_l(img1, img2, window, self.window_size, channel, self.size_average) 150 | return (1-ssim_v)/2 -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 | 
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 | from torch.utils.data.distributed import DistributedSampler
17 | from torch.utils.data.sampler import RandomSampler
18 | 
19 | 
20 | def find_dataset_using_name(dataset_name):
21 | """Import the module "data/[dataset_name]_dataset.py".
22 | 
23 | In the file, the class called DatasetNameDataset() will
24 | be instantiated. It has to be a subclass of BaseDataset,
25 | and it is case-insensitive.
26 | """
27 | dataset_filename = "data." + dataset_name + "_dataset"
28 | datasetlib = importlib.import_module(dataset_filename)
29 | 
30 | dataset = None
31 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
32 | for name, cls in datasetlib.__dict__.items():
33 | if name.lower() == target_dataset_name.lower() \
34 | and issubclass(cls, BaseDataset):
35 | dataset = cls
36 | 
37 | if dataset is None:
38 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
39 | 
40 | return dataset
41 | 
42 | 
43 | def get_option_setter(dataset_name):
44 | """Return the static method <modify_commandline_options> of the dataset class."""
45 | dataset_class = find_dataset_using_name(dataset_name)
46 | return dataset_class.modify_commandline_options
47 | 
48 | 
49 | def create_dataset(opt):
50 | """Create a dataset given the option.
51 | 
52 | This function builds the requested dataset and wraps it in a torch DataLoader.
53 | This is the main interface between this package and 'train.py'/'test.py'
54 | 
55 | Example:
56 | >>> from data import create_dataset
57 | >>> dataset = create_dataset(opt)
58 | """
59 | # data_loader = CustomDatasetDataLoader(opt)
60 | # dataset = data_loader.load_data()
61 | 
62 | dataset_class = find_dataset_using_name(opt.dataset_mode)
63 | dataset = dataset_class(opt)
64 | print("dataset [%s] was created" % type(dataset).__name__)
65 | 
66 | # batch_size = int(opt.batch_size / max(1, opt.NUM_GPUS))
67 | if opt.isTrain==True:
68 | shuffle = True
69 | drop_last = True
70 | elif opt.isTrain==False:
71 | shuffle = False
72 | drop_last = False
73 | 
74 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) if opt.NUM_GPUS > 1 else None
75 | 
76 | # Create a loader
77 | dataloader = torch.utils.data.DataLoader(
78 | dataset,
79 | batch_size=opt.batch_size,
80 | shuffle=(False if sampler else shuffle),
81 | sampler=sampler,
82 | num_workers=int(opt.num_threads),
83 | drop_last=drop_last,
84 | pin_memory=True,
85 | )
86 | return dataloader
87 | 
88 | 
89 | class CustomDatasetDataLoader():
90 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
91 | 
92 | def __init__(self, opt):
93 | """Initialize this class
94 | 
95 | Step 1: create a dataset instance given the name [dataset_mode]
96 | Step 2: create a multi-threaded data loader.
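
Usage sketch (in this repo create_dataset above is what train.py/test.py
actually call; this wrapper class is kept for reference):
    loader = CustomDatasetDataLoader(opt).load_data()
    for data in loader.dataloader:
        pass  # each `data` is the dict produced by the dataset's __getitem__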
97 | """
98 | self.opt = opt
99 | dataset_class = find_dataset_using_name(opt.dataset_mode)
100 | self.dataset = dataset_class(opt)
101 | print("dataset [%s] was created" % type(self.dataset).__name__)
102 | 
103 | batch_size = int(opt.batch_size / max(1, opt.NUM_GPUS))
104 | if opt.isTrain==True:
105 | shuffle = True
106 | drop_last = True
107 | elif opt.isTrain==False:
108 | shuffle = False
109 | drop_last = False
110 | 
111 | self.sampler = torch.utils.data.distributed.DistributedSampler(self.dataset) if opt.NUM_GPUS > 1 else None
112 | 
113 | # Create a loader
114 | self.dataloader = torch.utils.data.DataLoader(
115 | self.dataset,
116 | batch_size=batch_size,
117 | shuffle=(False if self.sampler else shuffle),
118 | sampler=self.sampler,
119 | num_workers=int(opt.num_threads),
120 | drop_last=drop_last,
121 | )
122 | 
123 | # self.dataloader = torch.utils.data.DataLoader(
124 | # self.dataset,
125 | # batch_size=opt.batch_size,
126 | # shuffle=not opt.serial_batches,
127 | # num_workers=int(opt.num_threads))
128 | 
129 | def load_data(self):
130 | return self
131 | 
132 | def __len__(self):
133 | """Return the number of data in the dataset"""
134 | return min(len(self.dataset), self.opt.max_dataset_size)
135 | 
136 | # def __iter__(self):
137 | # """Return a batch of data"""
138 | # for i, data in enumerate(self.dataloader):
139 | # if i * self.opt.batch_size >= self.opt.max_dataset_size:
140 | # break
141 | # yield data
142 | 
143 | def shuffle_dataset(loader, cur_epoch):
144 | """
145 | Shuffles the data.
146 | Args:
147 | loader (loader): data loader to perform shuffle.
148 | cur_epoch (int): number of the current epoch.
149 | """
150 | # sampler = (
151 | # loader.batch_sampler.sampler
152 | # if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
153 | # else loader.sampler
154 | # )
155 | sampler = loader.sampler
156 | assert isinstance(
157 | sampler, (RandomSampler, DistributedSampler)
158 | ), "Sampler type '{}' not supported".format(type(sampler))
159 | # RandomSampler handles shuffling automatically
160 | if isinstance(sampler, DistributedSampler):
161 | # DistributedSampler shuffles data based on epoch
162 | sampler.set_epoch(cur_epoch)
163 | 
164 | 
165 | 
166 | 
167 | 
168 | 
-------------------------------------------------------------------------------- /data/base_dataset.py: --------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 | 
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 | 
12 | 
13 | class BaseDataset(data.Dataset, ABC):
14 | """This class is an abstract base class (ABC) for datasets.
15 | 
16 | To create a subclass, you need to implement the following four functions:
17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 | -- <__len__>: return the size of dataset.
19 | -- <__getitem__>: get a data point.
20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 | """
22 | 
23 | def __init__(self, opt):
24 | """Initialize the class; save the options in the class
25 | 
26 | Parameters:
27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 | """
29 | self.opt = opt
30 | # self.root = opt.dataroot
31 | self.root = opt.dataset_root
32 | 
33 | @staticmethod
34 | def modify_commandline_options(parser, is_train):
35 | """Add new dataset-specific options, and rewrite default values for existing options.
36 | 
37 | Parameters:
38 | parser -- original option parser
39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40 | 
41 | Returns:
42 | the modified parser.
43 | """
44 | return parser
45 | 
46 | @abstractmethod
47 | def __len__(self):
48 | """Return the total number of images in the dataset."""
49 | return 0
50 | 
51 | @abstractmethod
52 | def __getitem__(self, index):
53 | """Return a data point and its metadata information.
54 | 
55 | Parameters:
56 | index -- a random integer for data indexing
57 | 
58 | Returns:
59 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
60 | """
61 | pass
62 | 
63 | 
64 | def get_params(opt, size):
65 | w, h = size
66 | new_h = h
67 | new_w = w
68 | if opt.preprocess == 'resize_and_crop':
69 | new_h = new_w = opt.load_size
70 | elif opt.preprocess == 'scale_width_and_crop':
71 | new_w = opt.load_size
72 | new_h = opt.load_size * h // w
73 | 
74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
76 | 
77 | flip = random.random() > 0.5
78 | 
79 | return {'crop_pos': (x, y), 'flip': flip}
80 | 
81 | 
82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
83 | transform_list = []
84 | if grayscale:
85 | transform_list.append(transforms.Grayscale(1))
86 | if 'resize' in opt.preprocess:
87 | osize = [opt.load_size, opt.load_size]
88 | transform_list.append(transforms.Resize(osize, method))
89 | elif 'scale_width' in opt.preprocess:
90 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
91 | 
92 | if 'crop' in opt.preprocess:
93 | if params is None:
94 | transform_list.append(transforms.RandomCrop(opt.crop_size))
95 | else:
96 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
97 | 
98 | if opt.preprocess == 'none':
99 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
100 | 
101 | if not opt.no_flip:
102 | if params is None:
103 | transform_list.append(transforms.RandomHorizontalFlip())
104 | elif params['flip']:
105 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
106 | 
107 | if convert:
108 | transform_list += [transforms.ToTensor()]
109 | if grayscale:
110 | transform_list += [transforms.Normalize((0.5,), (0.5,))]
111 | else:
112 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
113 | transform_list += [transforms.Normalize((0, 0, 0), (1, 1, 1))]
114 | return transforms.Compose(transform_list)
115 | 
116 | 
117 | def __make_power_2(img, base, method=Image.BICUBIC):
118 | ow, oh = img.size
119 | h = int(round(oh / base) * base)
120 | w = int(round(ow / base) * base)
121 | if (h == oh) and (w == ow):
122 | return img
123 | 
124 | __print_size_warning(ow, oh, w, h)
125 | return img.resize((w, h), method)
126 | 
127 | 
128 | def 
__scale_width(img, target_width, method=Image.BICUBIC):
129 | ow, oh = img.size
130 | if (ow == target_width):
131 | return img
132 | w = target_width
133 | h = int(target_width * oh / ow)
134 | return img.resize((w, h), method)
135 | 
136 | 
137 | def __crop(img, pos, size):
138 | ow, oh = img.size
139 | x1, y1 = pos
140 | tw = th = size
141 | if (ow > tw or oh > th):
142 | return img.crop((x1, y1, x1 + tw, y1 + th))
143 | return img
144 | 
145 | 
146 | def __flip(img, flip):
147 | if flip:
148 | return img.transpose(Image.FLIP_LEFT_RIGHT)
149 | return img
150 | 
151 | 
152 | def __print_size_warning(ow, oh, w, h):
153 | """Print warning information about image size (only print once)"""
154 | if not hasattr(__print_size_warning, 'has_printed'):
155 | print("The image size needs to be a multiple of 4. "
156 | "The loaded image size was (%d, %d), so it was adjusted to "
157 | "(%d, %d). This adjustment will be done to all images "
158 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
159 | __print_size_warning.has_printed = True
160 | 
-------------------------------------------------------------------------------- /data/ihd_dataset.py: --------------------------------------------------------------------------------
1 | """Dataset class template
2 | 
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <Dataset_mode>Dataset.py
8 | You need to implement the following functions:
9 | -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 | -- <__init__>: Initialize this dataset class.
11 | -- <__getitem__>: Return a data point and its metadata information.
12 | -- <__len__>: Return the number of images.
13 | """
14 | import os.path
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | import torch.nn.functional as F
18 | from data.base_dataset import BaseDataset, get_transform
19 | from data.image_folder import make_dataset
20 | from PIL import Image
21 | import numpy as np
22 | import torchvision.transforms as transforms
23 | from util import util
24 | 
25 | class IhdDataset(BaseDataset):
26 | """A template dataset class for you to implement custom datasets."""
27 | @staticmethod
28 | def modify_commandline_options(parser, is_train):
29 | """Add new dataset-specific options, and rewrite default values for existing options.
30 | 
31 | Parameters:
32 | parser -- original option parser
33 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
34 | 
35 | Returns:
36 | the modified parser.
37 | """
38 | parser.add_argument('--is_train', type=bool, default=True, help='whether in the training phase')
39 | parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0) # specify dataset-specific default values
40 | return parser
41 | 
42 | def __init__(self, opt):
43 | """Initialize this dataset class.
44 | 
45 | Parameters:
46 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
47 | 
48 | A few things can be done here.
49 | - save the options (have been done in BaseDataset)
50 | - get image paths and meta information of the dataset.
51 | - define the image transformation.
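
Expected iHarmony4-style layout (inferred from the path handling in
__getitem__ below; the file names are only illustrative):
    <dataset_root>/<dataset_name>_train.txt      # one composite image name per line
    <dataset_root>/composite_images/xxx_1_2.jpg
    <dataset_root>/masks/xxx_1.png
    <dataset_root>/real_images/xxx.jpg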
52 | """
53 | # save the option and dataset root
54 | BaseDataset.__init__(self, opt)
55 | self.image_paths = []
56 | self.isTrain = opt.isTrain
57 | self.image_size = opt.crop_size
58 | 
59 | if opt.isTrain==True:
60 | print('loading training file')
61 | self.trainfile = opt.dataset_root+opt.dataset_name+'_train.txt'
62 | with open(self.trainfile,'r') as f:
63 | for line in f.readlines():
64 | self.image_paths.append(os.path.join(opt.dataset_root,'composite_images',line.rstrip()))
65 | elif opt.isTrain==False:
66 | #self.real_ext='.jpg'
67 | print('loading test file')
68 | self.trainfile = opt.dataset_root+opt.dataset_name+'_test.txt'
69 | with open(self.trainfile,'r') as f:
70 | for line in f.readlines():
71 | self.image_paths.append(os.path.join(opt.dataset_root,'composite_images',line.rstrip()))
72 | # get the image paths of your dataset;
73 | # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
74 | # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
75 | transform_list = [
76 | transforms.ToTensor(),
77 | transforms.Normalize((0, 0, 0), (1, 1, 1))
78 | ]
79 | self.transforms = transforms.Compose(transform_list)
80 | 
81 | def __getitem__(self, index):
82 | """Return a data point and its metadata information.
83 | 
84 | Parameters:
85 | index -- a random integer for data indexing
86 | 
87 | Returns:
88 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
89 | 
90 | Step 1: get a random image path: e.g., path = self.image_paths[index]
91 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
92 | Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
93 | Step 4: return a data point as a dictionary.
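
Sketch of the dictionary returned below (H = W = crop_size after resizing):
    {'inputs': 4xHxW tensor (composite concatenated with its mask),
     'comp': 3xHxW, 'real': 3xHxW, 'mask': 1xHxW, 'img_path': str}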
94 | """ 95 | path = self.image_paths[index] 96 | name_parts=path.split('_') 97 | mask_path = self.image_paths[index].replace('composite_images','masks') 98 | mask_path = mask_path.replace(('_'+name_parts[-1]),'.png') 99 | target_path = self.image_paths[index].replace('composite_images','real_images') 100 | target_path = target_path.replace(('_'+name_parts[-2]+'_'+name_parts[-1]),'.jpg') 101 | 102 | comp = Image.open(path).convert('RGB') 103 | real = Image.open(target_path).convert('RGB') 104 | mask = Image.open(mask_path).convert('1') 105 | 106 | if np.random.rand() > 0.5 and self.isTrain: 107 | comp, mask, real = tf.hflip(comp), tf.hflip(mask), tf.hflip(real) 108 | 109 | if comp.size[0] != self.image_size: 110 | comp = tf.resize(comp, [self.image_size, self.image_size]) 111 | mask = tf.resize(mask, [self.image_size, self.image_size]) 112 | real = tf.resize(real, [self.image_size,self.image_size]) 113 | 114 | comp = self.transforms(comp) 115 | mask = tf.to_tensor(mask) 116 | real = self.transforms(real) 117 | 118 | inputs=torch.cat([comp,mask],0) 119 | 120 | return {'inputs': inputs, 'comp': comp, 'real': real,'img_path':path,'mask':mask} 121 | 122 | def __len__(self): 123 | """Return the total number of images.""" 124 | return len(self.image_paths) 125 | 126 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 5 | """ 6 | 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_dataset(dir, max_dataset_size=float("inf")): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | 27 | for root, _, fnames in sorted(os.walk(dir)): 28 | for fname in fnames: 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | images.append(path) 32 | return images[:min(max_dataset_size, len(images))] 33 | 34 | 35 | def default_loader(path): 36 | return Image.open(path).convert('RGB') 37 | 38 | 39 | class ImageFolder(data.Dataset): 40 | 41 | def __init__(self, root, transform=None, return_paths=False, 42 | loader=default_loader): 43 | imgs = make_dataset(root) 44 | if len(imgs) == 0: 45 | raise(RuntimeError("Found 0 images in: " + root + "\n" 46 | "Supported image extensions are: " + 47 | ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /data/loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # 
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 | 
4 | """Data loader."""
5 | 
6 | import itertools
7 | import numpy as np
8 | import torch
9 | from torch.utils.data._utils.collate import default_collate
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torch.utils.data.sampler import RandomSampler
12 | 
13 | from slowfast.datasets.multigrid_helper import ShortCycleBatchSampler
14 | 
15 | from . import utils as utils
16 | import os  # build_dataset below joins dataset paths with os.path, so os must be imported
17 | def build_dataset(dataset_name, cfg, split):  # signature aligned with the call in construct_loader below
18 | image_paths = []
19 | if split == 'train':  # only the training list is wired up in this stub
20 | print('loading training file')
21 | file = cfg.dataset_root+cfg.dataset_name+'_train.txt'
22 | with open(file,'r') as f:
23 | for line in f.readlines():
24 | image_paths.append(os.path.join(cfg.dataset_root,'composite_images',line.rstrip()))
25 | return image_paths
26 | 
27 | def construct_loader(cfg, split, is_precise_bn=False):
28 | """
29 | Constructs the data loader for the given dataset.
30 | Args:
31 | cfg (CfgNode): configs. Details can be found in
32 | slowfast/config/defaults.py
33 | split (str): the split of the data loader. Options include `train`,
34 | `val`, and `test`.
35 | """
36 | assert split in ["train", "val", "test"]
37 | if split in ["train"]:
38 | dataset_name = cfg.TRAIN.DATASET
39 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
40 | shuffle = True
41 | drop_last = True
42 | elif split in ["val"]:
43 | dataset_name = cfg.TRAIN.DATASET
44 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
45 | shuffle = False
46 | drop_last = False
47 | elif split in ["test"]:
48 | dataset_name = cfg.TEST.DATASET
49 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
50 | shuffle = False
51 | drop_last = False
52 | 
53 | # Construct the dataset
54 | dataset = build_dataset(dataset_name, cfg, split)
55 | 
56 | if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn:
57 | # Create a sampler for multi-process training
58 | sampler = utils.create_sampler(dataset, shuffle, cfg)
59 | batch_sampler = ShortCycleBatchSampler(
60 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
61 | )
62 | # Create a loader
63 | loader = torch.utils.data.DataLoader(
64 | dataset,
65 | batch_sampler=batch_sampler,
66 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
67 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
68 | worker_init_fn=utils.loader_worker_init_fn(dataset),
69 | )
70 | else:
71 | # Create a sampler for multi-process training
72 | sampler = utils.create_sampler(dataset, shuffle, cfg)
73 | # Create a loader
74 | loader = torch.utils.data.DataLoader(
75 | dataset,
76 | batch_size=batch_size,
77 | shuffle=(False if sampler else shuffle),
78 | sampler=sampler,
79 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
80 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
81 | drop_last=drop_last,
82 | collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,  # detection_collate is defined in slowfast's own loader module
83 | worker_init_fn=utils.loader_worker_init_fn(dataset),
84 | )
85 | return loader
86 | 
87 | 
88 | def shuffle_dataset(loader, cur_epoch):
89 | """
90 | Shuffles the data.
91 | Args:
92 | loader (loader): data loader to perform shuffle.
93 | cur_epoch (int): number of the current epoch.
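Example (sketch of the intended per-epoch use with construct_loader above):
    loader = construct_loader(cfg, "train")
    for epoch in range(num_epochs):
        shuffle_dataset(loader, epoch)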
94 | """
95 | sampler = (
96 | loader.batch_sampler.sampler
97 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
98 | else loader.sampler
99 | )
100 | assert isinstance(
101 | sampler, (RandomSampler, DistributedSampler)
102 | ), "Sampler type '{}' not supported".format(type(sampler))
103 | # RandomSampler handles shuffling automatically
104 | if isinstance(sampler, DistributedSampler):
105 | # DistributedSampler shuffles data based on epoch
106 | sampler.set_epoch(cur_epoch)
107 | 
-------------------------------------------------------------------------------- /data/real_dataset.py: --------------------------------------------------------------------------------
1 | """Dataset class template
2 | 
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <Dataset_mode>Dataset.py
8 | You need to implement the following functions:
9 | -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 | -- <__init__>: Initialize this dataset class.
11 | -- <__getitem__>: Return a data point and its metadata information.
12 | -- <__len__>: Return the number of images.
13 | """
14 | import os.path
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | from data.base_dataset import BaseDataset, get_transform
18 | from data.image_folder import make_dataset
19 | from PIL import Image
20 | import numpy as np
21 | import torchvision.transforms as transforms
22 | from scipy import sparse
23 | from util import util
24 | 
25 | class RealDataset(BaseDataset):
26 | """A template dataset class for you to implement custom datasets."""
27 | @staticmethod
28 | def modify_commandline_options(parser, is_train):
29 | """Add new dataset-specific options, and rewrite default values for existing options.
30 | 
31 | Parameters:
32 | parser -- original option parser
33 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
34 | 
35 | Returns:
36 | the modified parser.
37 | """
38 | parser.add_argument('--is_train', type=bool, default=True, help='whether in the training phase')
39 | parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0) # specify dataset-specific default values
40 | return parser
41 | 
42 | def __init__(self, opt):
43 | """Initialize this dataset class.
44 | 
45 | Parameters:
46 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
47 | 
48 | A few things can be done here.
49 | - save the options (have been done in BaseDataset)
50 | - get image paths and meta information of the dataset.
51 | - define the image transformation.
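
Expected layout (inferred from the path handling below; these real composites
have no ground truth, so __getitem__ returns the composite as 'real' too):
    <dataset_root>/<dataset_name>_train.txt (or _test.txt)
    <dataset_root>/composite_images/xxx.jpg
    <dataset_root>/masks/xxx.jpg     # same file name as the composite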
52 | """ 53 | # save the option and dataset root 54 | BaseDataset.__init__(self, opt) 55 | self.image_paths = [] 56 | self.isTrain = opt.isTrain 57 | self.image_size = opt.crop_size 58 | if opt.isTrain==True: 59 | print('loading training file') 60 | self.trainfile = opt.dataset_root+opt.dataset_name+'_train.txt' 61 | with open(self.trainfile,'r') as f: 62 | for line in f.readlines(): 63 | self.image_paths.append(os.path.join(opt.dataset_root,'composite_images',line.rstrip())) 64 | elif opt.isTrain==False: 65 | print('loading test file') 66 | self.trainfile = opt.dataset_root+opt.dataset_name+'_test.txt' 67 | with open(self.trainfile,'r') as f: 68 | for line in f.readlines(): 69 | self.image_paths.append(os.path.join(opt.dataset_root,'composite_images',line.rstrip())) 70 | # get the image paths of your dataset; 71 | # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root 72 | # define the default transform function. You can use ; You can also define your custom transform function 73 | self.transform = get_transform(opt) 74 | 75 | def __getitem__(self, index): 76 | path = self.image_paths[index] 77 | mask_path = self.image_paths[index].replace('composite_images','masks') 78 | 79 | comp = Image.open(path).convert('RGB') 80 | mask = Image.open(mask_path).convert('1') 81 | 82 | if np.random.rand() > 0.5 and self.isTrain: 83 | comp, mask = tf.hflip(comp), tf.hflip(mask) 84 | 85 | comp = tf.resize(comp, [self.image_size, self.image_size]) 86 | mask = tf.resize(mask, [self.image_size, self.image_size]) 87 | comp = self.transform(comp) 88 | mask = tf.to_tensor(mask) 89 | inputs=torch.cat([comp,mask],0) 90 | 91 | return {'inputs': inputs, 'comp': comp, 'real': comp,'img_path':path,'mask':mask} 92 | 93 | def __len__(self): 94 | """Return the total number of images.""" 95 | return len(self.image_paths) 96 | -------------------------------------------------------------------------------- /evaluation/pytorch_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | import os 7 | 8 | 9 | def gaussian(window_size, sigma): 10 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 11 | return gauss/gauss.sum() 12 | 13 | def create_window(window_size, channel): 14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 16 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 17 | return window 18 | 19 | def _ssim(img1, img2, window, window_size, channel, size_average = True, mask=None): 20 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 21 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 22 | 23 | mu1_sq = mu1.pow(2) 24 | mu2_sq = mu2.pow(2) 25 | mu1_mu2 = mu1*mu2 26 | 27 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 28 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 29 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 30 | 31 | C1 = 0.01**2 32 | C2 = 0.03**2 33 | 34 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 35 | 36 | if mask is not None: 37 | mask_sum = 
mask.sum()
38 | fg_ssim_map = ssim_map*mask
39 | fg_ssim_map_sum = fg_ssim_map.sum(3).sum(2)
40 | fg_ssim = fg_ssim_map_sum/mask_sum
41 | fg_ssim_mu = fg_ssim.mean()
42 | ssim_mu = ssim_map.mean()
43 | return ssim_mu.item(), fg_ssim_mu.item()
44 | 
45 | if size_average:  # without a mask, return the standard SSIM (restored; this return was commented out, making the mask-free path return None)
46 | return ssim_map.mean()
47 | else:
48 | return ssim_map.mean(1).mean(1).mean(1)
49 | 
50 | class SSIM(torch.nn.Module):
51 | def __init__(self, window_size = 11, size_average = True):
52 | super(SSIM, self).__init__()
53 | self.window_size = window_size
54 | self.size_average = size_average
55 | self.channel = 1
56 | self.window = create_window(window_size, self.channel)
57 | 
58 | def forward(self, img1, img2):
59 | (_, channel, _, _) = img1.size()
60 | 
61 | if channel == self.channel and self.window.data.type() == img1.data.type():
62 | window = self.window
63 | else:
64 | window = create_window(self.window_size, channel)
65 | 
66 | if img1.is_cuda:
67 | window = window.cuda(img1.get_device())
68 | window = window.type_as(img1)
69 | 
70 | self.window = window
71 | self.channel = channel
72 | 
73 | 
74 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
75 | 
76 | def ssim(img1, img2, window_size = 11, size_average = True, mask=None):
77 | (_, channel, _, _) = img1.size()
78 | window = create_window(window_size, channel)
79 | 
80 | if img1.is_cuda:
81 | window = window.cuda(img1.get_device())
82 | window = window.type_as(img1)
83 | 
84 | return _ssim(img1, img2, window, window_size, channel, size_average, mask)
-------------------------------------------------------------------------------- /models/__init__.py: --------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 | 
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 | -- <set_input>: unpack data from dataset and apply preprocessing.
7 | -- <forward>: produce intermediate results.
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 | 
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for example usage.
16 | 
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 | 
21 | import importlib
22 | from models.base_model import BaseModel
23 | 
24 | 
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 | 
28 | In the file, the class called DatasetNameModel() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." 
+ model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 | 
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(0)
44 | 
45 | return model
46 | 
47 | 
48 | def get_option_setter(model_name):
49 | """Return the static method <modify_commandline_options> of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 | 
53 | 
54 | def create_model(opt):
55 | """Create a model given the option.
56 | 
57 | This function wraps the model class chosen by opt.model.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 | 
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 | 
-------------------------------------------------------------------------------- /models/iih_base_gd_model.py: --------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import itertools
4 | import torch.nn.functional as F
5 | from util import distributed as du
6 | from .base_model import BaseModel
7 | from util import util
8 | from . import harmony_networks as networks
9 | import util.ssim as ssim
10 | 
11 | 
12 | class IIHBaseGDModel(BaseModel):
13 | @staticmethod
14 | def modify_commandline_options(parser, is_train=True):
15 | parser.set_defaults(norm='instance', netG='base_gd', dataset_mode='ihd')
16 | if is_train:
17 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
18 | parser.add_argument('--lambda_R_gradient', type=float, default=20., help='weight for reflectance gradient loss')
19 | parser.add_argument('--lambda_I_L2', type=float, default=10., help='weight for illumination L2 loss')
20 | parser.add_argument('--lambda_I_smooth', type=float, default=1, help='weight for illumination smooth loss')
21 | parser.add_argument('--lambda_ifm', type=float, default=100, help='weight for IFM loss')
22 | return parser
23 | 
24 | def __init__(self, opt):
25 | BaseModel.__init__(self, opt)
26 | self.opt = opt
27 | # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
28 | self.loss_names = ['G','G_L1',"IF"]
29 | if opt.loss_RH:
30 | self.loss_names.append("G_R_grident")
31 | if opt.loss_IH:
32 | self.loss_names.append("G_I_L2")
33 | if opt.loss_IS:
34 | self.loss_names.append("G_I_smooth")
35 | 
36 | # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
37 | self.visual_names = ['mask', 'harmonized','comp','real','reflectance','illumination','ifm_mean']
38 | # specify the models you want to save to the disk. 
The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
39 |         self.model_names = ['G']
40 |         self.opt.device = self.device
41 |         self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
42 |         self.cur_device = torch.cuda.current_device()
43 |         self.ismaster = du.is_master_proc(opt.NUM_GPUS)
44 |         if self.ismaster:
45 |             print(self.netG)
46 | 
47 |         if self.isTrain:
48 |             util.saveprint(self.opt, 'netG', str(self.netG))
49 |             # define loss functions
50 |             self.criterionL1 = torch.nn.L1Loss().cuda(self.cur_device)
51 |             self.criterionL2 = torch.nn.MSELoss().cuda(self.cur_device)
52 |             self.criterionDSSIM_CS = ssim.DSSIM(mode='c_s').to(self.device)
53 |             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
54 |             self.optimizers.append(self.optimizer_G)
55 | 
56 |     def set_input(self, input):
57 |         self.comp = input['comp'].to(self.device)
58 |         self.real = input['real'].to(self.device)
59 |         self.inputs = input['inputs'].to(self.device)
60 |         self.mask = input['mask'].to(self.device)
61 |         self.image_paths = input['img_path']
62 |         self.mask_r = F.interpolate(self.mask, size=[64,64])
63 |         self.mask_r_32 = F.interpolate(self.mask, size=[32,32])
64 |         self.real_r = F.interpolate(self.real, size=[32,32])
65 |         self.real_gray = util.rgbtogray(self.real_r)
66 | 
67 |     def forward(self):
68 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
69 |         self.harmonized, self.reflectance, self.illumination, self.ifm_mean = self.netG(self.inputs, self.mask_r, self.mask_r_32)
70 |         if not self.isTrain:
71 |             self.harmonized = self.comp*(1-self.mask) + self.harmonized*self.mask
72 |     def backward_G(self):
73 |         """Calculate the L1, IFM, and optional decomposition losses for the generator"""
74 |         self.loss_IF = self.criterionDSSIM_CS(self.ifm_mean, self.real_gray)*self.opt.lambda_ifm
75 | 
76 |         self.loss_G_L1 = self.criterionL1(self.harmonized, self.real)*self.opt.lambda_L1
77 |         self.loss_G = self.loss_G_L1+self.loss_IF
78 |         if self.opt.loss_RH:
79 |             self.loss_G_R_gradient = self.gradient_loss(self.reflectance, self.real)*self.opt.lambda_R_gradient
80 |             self.loss_G = self.loss_G + self.loss_G_R_gradient
81 |         if self.opt.loss_IH:
82 |             self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
83 |             self.loss_G = self.loss_G + self.loss_G_I_L2
84 |         if self.opt.loss_IS:
85 |             self.loss_G_I_smooth = util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
86 |             self.loss_G = self.loss_G + self.loss_G_I_smooth
87 |         self.loss_G.backward()
88 | 
89 |     def optimize_parameters(self):
90 |         self.forward() # compute fake images: G(A)
91 |         # update G
92 |         self.optimizer_G.zero_grad() # set G's gradients to zero
93 |         self.backward_G() # calculate gradients for G
94 |         self.optimizer_G.step() # update G's weights
95 | 
96 |     def gradient_loss(self, input_1, input_2):
97 |         g_x = self.criterionL1(util.gradient(input_1, 'x'), util.gradient(input_2, 'x'))
98 |         g_y = self.criterionL1(util.gradient(input_1, 'y'), util.gradient(input_2, 'y'))
99 |         return g_x+g_y
100 | 
101 | 
102 | 
--------------------------------------------------------------------------------
/models/iih_base_lt_gd_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import itertools
4 | import torch.nn.functional as F
5 | from util import distributed as du
6 | from .base_model import BaseModel
7 | from util import util
8 | from .
import harmony_networks as networks
9 | import util.ssim as ssim
10 | 
11 | class IIHBaseLTGDModel(BaseModel):
12 |     @staticmethod
13 |     def modify_commandline_options(parser, is_train=True):
14 |         parser.set_defaults(norm='instance', netG='base_lt_gd', dataset_mode='ihd')
15 |         if is_train:
16 |             parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
17 |             parser.add_argument('--lambda_R_gradient', type=float, default=20., help='weight for reflectance gradient loss')
18 |             parser.add_argument('--lambda_I_L2', type=float, default=10, help='weight for illumination L2 loss')
19 |             parser.add_argument('--lambda_I_smooth', type=float, default=1, help='weight for illumination smooth loss')
20 |             parser.add_argument('--lambda_ifm', type=float, default=100, help='weight for ifm loss')
21 | 
22 |         return parser
23 | 
24 |     def __init__(self, opt):
25 |         """Initialize the IIHBaseLTGDModel class.
26 | 
27 |         Parameters:
28 |             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
29 |         """
30 |         BaseModel.__init__(self, opt)
31 |         self.opt = opt
32 |         self.loss_names = ['G','G_L1','G_R_gradient','G_I_L2','G_I_smooth',"IF"]
33 | 
34 |         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
35 |         self.visual_names = ['mask', 'harmonized','comp','real','reflectance','illumination','ifm_mean']
36 |         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
37 |         self.model_names = ['G']
38 |         self.opt.device = self.device
39 |         self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
40 |         self.cur_device = torch.cuda.current_device()
41 |         self.ismaster = du.is_master_proc(opt.NUM_GPUS)
42 |         if self.ismaster:
43 |             print(self.netG)
44 | 
45 |         if self.isTrain:
46 |             if self.ismaster:
47 |                 util.saveprint(self.opt, 'netG', str(self.netG))
48 |             # define loss functions
49 |             self.criterionL1 = torch.nn.L1Loss().cuda(self.cur_device)
50 |             self.criterionL2 = torch.nn.MSELoss().cuda(self.cur_device)
51 |             self.criterionDSSIM_CS = ssim.DSSIM(mode='c_s').to(self.device)
52 |             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
53 |             self.optimizers.append(self.optimizer_G)
54 | 
55 |     def set_input(self, input):
56 |         self.comp = input['comp'].to(self.device)
57 |         self.real = input['real'].to(self.device)
58 |         self.inputs = input['inputs'].to(self.device)
59 |         self.mask = input['mask'].to(self.device)
60 |         self.image_paths = input['img_path']
61 | 
62 |         self.mask_r = F.interpolate(self.mask, size=[64,64])
63 |         self.mask_r_32 = F.interpolate(self.mask, size=[32,32])
64 |         self.real_r = F.interpolate(self.real, size=[32,32])
65 |         self.real_gray = util.rgbtogray(self.real_r)
66 | 
67 |     def forward(self):
68 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
69 |         self.harmonized, self.reflectance, self.illumination, self.ifm_mean = self.netG(self.inputs, self.mask_r, self.mask_r_32)
70 |         if not self.isTrain:
71 |             self.harmonized = self.comp*(1-self.mask) + self.harmonized*self.mask
72 |     def backward_G(self):
73 |         """Calculate the L1, IFM, and decomposition losses for the generator"""
74 |         self.loss_IF = self.criterionDSSIM_CS(self.ifm_mean, self.real_gray)*self.opt.lambda_ifm
75 | 
76 |         self.loss_G_L1 = self.criterionL1(self.harmonized, self.real)*self.opt.lambda_L1
77 |         self.loss_G_R_gradient = self.gradient_loss(self.reflectance, self.real)*self.opt.lambda_R_gradient
78 |         self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
79 |         self.loss_G_I_smooth =
util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
80 | 
81 |         self.loss_G = self.loss_G_L1 + self.loss_G_R_gradient + self.loss_G_I_L2 + self.loss_G_I_smooth + self.loss_IF
82 |         self.loss_G.backward()
83 | 
84 |     def optimize_parameters(self):
85 |         self.forward() # compute fake images: G(A)
86 |         # update G
87 |         self.optimizer_G.zero_grad() # set G's gradients to zero
88 |         self.backward_G() # calculate gradients for G
89 |         self.optimizer_G.step() # update G's weights
90 | 
91 |     def gradient_loss(self, input_1, input_2):
92 |         g_x = self.criterionL1(util.gradient(input_1, 'x'), util.gradient(input_2, 'x'))
93 |         g_y = self.criterionL1(util.gradient(input_1, 'y'), util.gradient(input_2, 'y'))
94 |         return g_x+g_y
95 | 
96 | 
--------------------------------------------------------------------------------
/models/iih_base_lt_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import itertools
3 | import torch.nn.functional as F
4 | from util import distributed as du
5 | from .base_model import BaseModel
6 | from util import util
7 | from . import harmony_networks as networks
8 | from . import networks as network_init
9 | 
10 | 
11 | class IIHBaseLTModel(BaseModel):
12 |     @staticmethod
13 |     def modify_commandline_options(parser, is_train=True):
14 |         parser.set_defaults(norm='instance', netG='base_lt', dataset_mode='ihd')
15 |         if is_train:
16 |             parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
17 |             parser.add_argument('--lambda_R_gradient', type=float, default=20., help='weight for reflectance gradient loss')
18 |             parser.add_argument('--lambda_I_L2', type=float, default=10., help='weight for illumination L2 loss')
19 |             parser.add_argument('--lambda_I_smooth', type=float, default=1, help='weight for illumination smooth loss')
20 | 
21 |         return parser
22 | 
23 |     def __init__(self, opt):
24 |         BaseModel.__init__(self, opt)
25 |         self.opt = opt
26 |         self.loss_names = ['G','G_L1']
27 |         if opt.loss_RH:
28 |             self.loss_names.append("G_R_gradient")
29 |         if opt.loss_IH:
30 |             self.loss_names.append("G_I_L2")
31 |         if opt.loss_IS:
32 |             self.loss_names.append("G_I_smooth")
33 | 
34 |         self.visual_names = ['mask', 'harmonized','comp','real','reflectance','illumination']
35 |         self.model_names = ['G']
36 |         self.opt.device = self.device
37 |         self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
38 |         self.cur_device = torch.cuda.current_device()
39 |         self.ismaster = du.is_master_proc(opt.NUM_GPUS)
40 |         if self.ismaster:
41 |             print(self.netG)
42 | 
43 |         if self.isTrain:
44 |             if self.ismaster:
45 |                 util.saveprint(self.opt, 'netG', str(self.netG))
46 |             self.criterionL1 = torch.nn.L1Loss().cuda(self.cur_device)
47 |             self.criterionL2 = torch.nn.MSELoss().cuda(self.cur_device)
48 |             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
49 |             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
50 |             self.optimizers.append(self.optimizer_G)
51 | 
52 |     def set_input(self, input):
53 |         self.comp = input['comp'].to(self.device)
54 |         self.real = input['real'].to(self.device)
55 |         self.inputs = input['inputs'].to(self.device)
56 |         self.mask = input['mask'].to(self.device)
57 |         self.image_paths = input['img_path']
58 |         self.mask_r = F.interpolate(self.mask, size=[64,64])
59 | 
60 |     def forward(self):
61 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
62 |         self.harmonized, self.reflectance, self.illumination = self.netG(self.inputs, self.mask, self.mask_r)
63 |         if not self.isTrain:
64 |             self.harmonized = self.comp*(1-self.mask) + self.harmonized*self.mask
65 | 
66 |     def backward_G(self):
67 |         """Calculate the L1 and optional decomposition losses for the generator"""
68 |         self.loss_G_L1 = self.criterionL1(self.harmonized, self.real)*self.opt.lambda_L1
69 |         self.loss_G = self.loss_G_L1
70 |         if self.opt.loss_RH:
71 |             self.loss_G_R_gradient = self.gradient_loss(self.reflectance, self.real)*self.opt.lambda_R_gradient
72 |             self.loss_G = self.loss_G + self.loss_G_R_gradient
73 |         if self.opt.loss_IH:
74 |             self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
75 |             self.loss_G = self.loss_G + self.loss_G_I_L2
76 |         if self.opt.loss_IS:
77 |             self.loss_G_I_smooth = util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
78 |             self.loss_G = self.loss_G + self.loss_G_I_smooth
79 |         # self.loss_G_R_gradient = self.gradient_loss(self.reflectance, self.real)*self.opt.lambda_R_gradient
80 |         # self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
81 |         # self.loss_G_I_smooth = util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
82 |         # # assert 0
83 |         # self.loss_G = self.loss_G_L1 + self.loss_G_R_gradient + self.loss_G_I_L2 + self.loss_G_I_smooth
84 |         self.loss_G.backward()
85 | 
86 |     def optimize_parameters(self):
87 |         self.forward() # compute fake images: G(A)
88 |         # update G
89 |         self.optimizer_G.zero_grad() # set G's gradients to zero
90 |         self.backward_G() # calculate gradients for G
91 |         self.optimizer_G.step() # update G's weights
92 | 
93 |     def gradient_loss(self, input_1, input_2):
94 |         g_x = self.criterionL1(util.gradient(input_1, 'x'), util.gradient(input_2, 'x'))
95 |         g_y = self.criterionL1(util.gradient(input_1, 'y'), util.gradient(input_2, 'y'))
96 |         return g_x+g_y
97 | 
98 | 
--------------------------------------------------------------------------------
/models/iih_base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import itertools
4 | import torch.nn.functional as F
5 | from util import distributed as du
6 | from .base_model import BaseModel
7 | from util import util
8 | from .
import harmony_networks as networks
9 | 
10 | 
11 | class IIHBaseModel(BaseModel):
12 |     @staticmethod
13 |     def modify_commandline_options(parser, is_train=True):
14 |         parser.set_defaults(norm='instance', netG='base', dataset_mode='ihd')
15 |         if is_train:
16 |             parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
17 |             parser.add_argument('--lambda_R_gradient', type=float, default=50., help='weight for reflectance gradient loss')
18 |             parser.add_argument('--lambda_I_L2', type=float, default=10., help='weight for illumination L2 loss')
19 |             parser.add_argument('--lambda_I_smooth', type=float, default=1, help='weight for illumination smooth loss')
20 |         return parser
21 | 
22 |     def __init__(self, opt):
23 |         BaseModel.__init__(self, opt)
24 |         self.opt = opt
25 |         self.loss_names = ['G','G_L1']
26 |         if opt.loss_RH:
27 |             self.loss_names.append("G_R_gradient")
28 |         if opt.loss_IH:
29 |             self.loss_names.append("G_I_L2")
30 |         if opt.loss_IS:
31 |             self.loss_names.append("G_I_smooth")
32 | 
33 |         self.visual_names = ['mask', 'harmonized','comp','real','reflectance','illumination']
34 |         self.model_names = ['G']
35 |         self.opt.device = self.device
36 |         self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
37 |         self.cur_device = torch.cuda.current_device()
38 |         self.ismaster = du.is_master_proc(opt.NUM_GPUS)
39 |         if self.ismaster:
40 |             print(self.netG)
41 |         if self.isTrain:
42 |             util.saveprint(self.opt, 'netG', str(self.netG))
43 |             self.criterionL1 = torch.nn.L1Loss().cuda(self.cur_device)
44 |             self.criterionL2 = torch.nn.MSELoss().cuda(self.cur_device)
45 |             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
46 |             self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
47 |             self.optimizers.append(self.optimizer_G)
48 | 
49 |     def set_input(self, input):
50 |         self.comp = input['comp'].to(self.device)
51 |         self.real = input['real'].to(self.device)
52 |         self.inputs = input['inputs'].to(self.device)
53 |         self.mask = input['mask'].to(self.device)
54 |         self.image_paths = input['img_path']
55 | 
56 |     def forward(self):
57 |         self.harmonized, self.reflectance, self.illumination = self.netG(self.inputs, self.mask)
58 |         if not self.isTrain:
59 |             self.harmonized = self.comp*(1-self.mask) + self.harmonized*self.mask
60 | 
61 |     def backward_G(self):
62 |         self.loss_G_L1 = self.criterionL1(self.harmonized, self.real)*self.opt.lambda_L1
63 |         self.loss_G = self.loss_G_L1
64 |         if self.opt.loss_RH:
65 |             self.loss_G_R_gradient = self.gradient_loss(self.reflectance, self.real)*self.opt.lambda_R_gradient
66 |             self.loss_G = self.loss_G + self.loss_G_R_gradient
67 |         if self.opt.loss_IH:
68 |             self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
69 |             self.loss_G = self.loss_G + self.loss_G_I_L2
70 |         if self.opt.loss_IS:
71 |             self.loss_G_I_smooth = util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
72 |             self.loss_G = self.loss_G + self.loss_G_I_smooth
73 |         self.loss_G.backward()
74 | 
75 |     def optimize_parameters(self):
76 |         self.forward() # compute fake images: G(A)
77 |         # update G
78 |         self.optimizer_G.zero_grad() # set G's gradients to zero
79 |         self.backward_G() # calculate gradients for G
80 |         self.optimizer_G.step() # update G's weights
81 | 
82 | 
83 |     def gradient_loss(self, input_1, input_2):
84 |         g_x = self.criterionL1(util.gradient(input_1, 'x'), util.gradient(input_2, 'x'))
85 |         g_y = self.criterionL1(util.gradient(input_1, 'y'), util.gradient(input_2, 'y'))
86 |         return g_x+g_y
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes option modules: training options, test options, and basic options (used in both training and test)."""
2 | 
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TestOptions(BaseOptions):
5 |     """This class includes test options.
6 | 
7 |     It also includes shared options defined in BaseOptions.
8 |     """
9 | 
10 |     def initialize(self, parser):
11 |         parser = BaseOptions.initialize(self, parser)  # define shared options
12 |         parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 |         parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 |         parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 |         parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 |         # Dropout and Batchnorm have different behavior during training and test.
17 |         parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 |         parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 |         parser.add_argument('--test_epoch', type=str, default="0", help='which epoch to load for testing')
20 |         # rewrite default values
21 |         parser.set_defaults(model='test')
22 |         # To avoid cropping, the load_size should be the same as crop_size
23 |         parser.set_defaults(load_size=parser.get_default('crop_size'))
24 |         self.isTrain = False
25 |         return parser
26 | 
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 | 
3 | 
4 | class TrainOptions(BaseOptions):
5 |     """This class includes training options.
6 | 
7 |     It also includes shared options defined in BaseOptions.
8 |     """
9 | 
10 |     def initialize(self, parser):
11 |         parser = BaseOptions.initialize(self, parser)
12 |         # visdom and HTML visualization parameters
13 |         parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
14 |         parser.add_argument('--display_ncols', type=int, default=6, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
15 |         parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
16 |         parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
17 |         parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
18 |         parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
19 |         parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
20 |         parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
21 |         parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
22 |         # network saving and loading parameters
23 |         parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
24 |         parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
25 |         parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
26 |         parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
27 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
28 |         parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
29 |         # training parameters
30 |         parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate')
31 |         parser.add_argument('--niter_decay', type=int, default=50, help='# of iter to linearly decay learning rate to zero')
32 |         parser.add_argument('--beta1', type=float, default=0.5, help='first momentum term of adam')
33 |         parser.add_argument('--beta2', type=float, default=0.999, help='second momentum term of adam')
34 |         parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
35 |         parser.add_argument('--g_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of generator')
36 |         parser.add_argument('--d_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of discriminator')
37 |         parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
38 |         parser.add_argument('--pool_size', type=int, default=40, help='the size of image buffer that stores previously generated images')
39 |         parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy.
[linear | step | plateau | cosine]')
40 |         parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
41 |         parser.add_argument('--save_iter_model', action='store_true', help='whether to save the model by iteration')
42 | 
43 |         self.isTrain = True
44 |         return parser
45 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.4.0
2 | torchvision>=0.5.0
3 | dominate>=2.4.0
4 | visdom>=0.1.8.8
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import test
6 | 
7 | from options.test_options import TestOptions
8 | 
9 | def main():
10 |     cfg = TestOptions().parse()  # get test options
11 |     cfg.NUM_GPUS = torch.cuda.device_count()
12 |     cfg.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
13 |     cfg.no_flip = True  # no flip; comment this line if results on flipped images are needed.
14 |     cfg.display_id = -1  # no visdom display; the test code saves the results to an HTML file.
15 | 
16 |     cfg.phase = 'test'
17 |     cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
18 |     launch_job(cfg=cfg, init_method=cfg.init_method, func=test)
19 | 
20 | 
21 | if __name__=="__main__":
22 |     main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import train
6 | 
7 | from options.train_options import TrainOptions
8 | 
9 | def main():
10 |     cfg = TrainOptions().parse()  # get training options
11 |     cfg.NUM_GPUS = torch.cuda.device_count()
12 |     cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
13 |     cfg.phase = 'train'
14 |     launch_job(cfg=cfg, init_method=cfg.init_method, func=train)
15 | 
16 | 
17 | if __name__=="__main__":
18 |     main()
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 | 
5 | 
6 | class HTML:
7 |     """This HTML class allows us to save images and write texts into a single HTML file.
8 | 
9 |     It consists of functions such as <add_header> (add a text header to the HTML file),
10 |     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 |     It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 |     """
13 | 
14 |     def __init__(self, web_dir, title, refresh=0):
15 |         """Initialize the HTML classes
16 | 
17 |         Parameters:
18 |             web_dir (str) -- a directory that stores the webpage.
HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str)   -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 | 
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 |             with self.doc.head:
33 |                 meta(http_equiv="refresh", content=str(refresh))
34 | 
35 |     def get_image_dir(self):
36 |         """Return the directory that stores images"""
37 |         return self.img_dir
38 | 
39 |     def add_header(self, text):
40 |         """Insert a header to the HTML file
41 | 
42 |         Parameters:
43 |             text (str) -- the header text
44 |         """
45 |         with self.doc:
46 |             h3(text)
47 | 
48 |     def add_images(self, ims, txts, links, width=400):
49 |         """Add images to the HTML file
50 | 
51 |         Parameters:
52 |             ims (str list)   -- a list of image paths
53 |             txts (str list)  -- a list of image names shown on the website
54 |             links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page
55 |         """
56 |         self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
57 |         self.doc.add(self.t)
58 |         with self.t:
59 |             with tr():
60 |                 for im, txt, link in zip(ims, txts, links):
61 |                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 |                         with p():
63 |                             with a(href=os.path.join('images', link)):
64 |                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
65 |                             br()
66 |                             p(txt)
67 | 
68 |     def save(self):
69 |         """Save the current content to the HTML file"""
70 |         html_file = '%s/index.html' % self.web_dir
71 |         f = open(html_file, 'wt')
72 |         f.write(self.doc.render())
73 |         f.close()
74 | 
75 | 
76 | if __name__ == '__main__':  # we show an example usage here.
77 |     html = HTML('web/', 'test_html')
78 |     html.add_header('hello world')
79 | 
80 |     ims, txts, links = [], [], []
81 |     for n in range(4):
82 |         ims.append('image_%d.png' % n)
83 |         txts.append('text_%d' % n)
84 |         links.append('image_%d.png' % n)
85 |     html.add_images(ims, txts, links)
86 |     html.save()
87 | 
--------------------------------------------------------------------------------
/util/multiprocessing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 | 
4 | """Multiprocessing helpers."""
5 | 
6 | import torch
7 | 
8 | 
9 | def run(
10 |     local_rank,
11 |     num_proc,
12 |     func,
13 |     init_method,
14 |     shard_id,
15 |     num_shards,
16 |     backend,
17 |     cfg,
18 |     output_queue=None,
19 | ):
20 |     """
21 |     Runs a function from a child process.
22 |     Args:
23 |         local_rank (int): rank of the current process on the current machine.
24 |         num_proc (int): number of processes per machine.
25 |         func (function): function to execute on each of the processes.
26 |         init_method (string): method to initialize the distributed training.
27 |             TCP initialization: requiring a network address reachable from all
28 |             processes followed by the port.
29 |             Shared file-system initialization: makes use of a file system that
30 |             is shared and visible from all machines. The URL should start with
31 |             file:// and contain a path to a non-existent file on a shared file
32 |             system.
33 |         shard_id (int): the rank of the current machine.
34 |         num_shards (int): number of overall machines for the distributed
35 |             training job.
36 |         backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
37 |             supported, each with different capabilities. Details can be found
38 |             here:
39 |             https://pytorch.org/docs/stable/distributed.html
40 |         cfg (CfgNode): configs. Details can be found in
41 |             slowfast/config/defaults.py
42 |         output_queue (queue): can optionally be used to return values from the
43 |             master process.
44 |     """
45 |     # Initialize the process group.
46 |     world_size = num_proc * num_shards
47 |     rank = shard_id * num_proc + local_rank
48 |     try:
49 |         torch.distributed.init_process_group(
50 |             backend=backend,
51 |             init_method=init_method,
52 |             world_size=world_size,
53 |             rank=rank,
54 |         )
55 | 
56 |     except Exception as e:
57 |         raise e
58 | 
59 |     torch.cuda.set_device(local_rank)
60 |     ret = func(cfg)
61 |     if output_queue is not None and local_rank == 0:
62 |         output_queue.put(ret)
63 | 
--------------------------------------------------------------------------------
/util/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 | 
7 | def gaussian(window_size, sigma):
8 |     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 |     return gauss/gauss.sum()
10 | 
11 | def create_window(window_size, channel):
12 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 |     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 |     return window
16 | 
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 |     mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 |     mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 | 
21 |     mu1_sq = mu1.pow(2)
22 |     mu2_sq = mu2.pow(2)
23 |     mu1_mu2 = mu1*mu2
24 | 
25 |     sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 |     sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 |     sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 | 
29 |     C1 = 0.01**2
30 |     C2 = 0.03**2
31 | 
32 |     ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 | 
34 |     if size_average:
35 |         return ssim_map.mean()
36 |     else:
37 |         return ssim_map.mean(1).mean(1).mean(1)
38 | 
39 | def _ssim_c_s(img1, img2, window, window_size, channel, size_average = True):
40 |     mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
41 |     mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
42 | 
43 |     mu1_sq = mu1.pow(2)
44 |     mu2_sq = mu2.pow(2)
45 |     mu1_mu2 = mu1*mu2
46 | 
47 |     sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
48 |     sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
49 |     sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
50 | 
51 |     C1 = 0.01**2
52 |     C2 = 0.03**2
53 | 
54 |     ssim_map = ((2*sigma12 + C2))/((sigma1_sq + sigma2_sq + C2))
55 | 
56 |     if size_average:
57 |         return ssim_map.mean()
58 |     else:
59 |         return ssim_map.mean(1).mean(1).mean(1)
60 | def _ssim_l(img1, img2, window, window_size, channel, size_average = True):
61 |     mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
62 |     mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
63 | 
64 |     mu1_sq = mu1.pow(2)
65 |     mu2_sq = mu2.pow(2)
66 |     mu1_mu2 = mu1*mu2
67 | 
68 |     sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
69 |     sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
70 |     sigma12 = F.conv2d(img1*img2,
window, padding = window_size//2, groups = channel) - mu1_mu2
71 | 
72 |     C1 = 0.01**2
73 |     C2 = 0.03**2
74 | 
75 |     ssim_map = (2*mu1_mu2 + C1)/(mu1_sq + mu2_sq + C1)
76 | 
77 |     if size_average:
78 |         return ssim_map.mean()
79 |     else:
80 |         return ssim_map.mean(1).mean(1).mean(1)
81 | 
82 | 
83 | def ssim(img1, img2, window_size = 11, size_average = True):
84 |     (_, channel, _, _) = img1.size()
85 |     window = create_window(window_size, channel)
86 | 
87 |     if img1.is_cuda:
88 |         window = window.cuda(img1.get_device())
89 |     window = window.type_as(img1)
90 | 
91 |     return _ssim(img1, img2, window, window_size, channel, size_average)
92 | 
93 | class SSIM(torch.nn.Module):
94 |     def __init__(self, window_size = 11, size_average = True, mode='all'):
95 |         super(SSIM, self).__init__()
96 |         self.window_size = window_size
97 |         self.size_average = size_average
98 |         self.channel = 1
99 |         self.window = create_window(window_size, self.channel)
100 |         self.mode = mode
101 |     def forward(self, img1, img2):
102 |         (_, channel, _, _) = img1.size()
103 | 
104 |         if channel == self.channel and self.window.data.type() == img1.data.type():
105 |             window = self.window
106 |         else:
107 |             window = create_window(self.window_size, channel)
108 | 
109 |             # if img1.is_cuda:
110 |             #     window = window.cuda(img1.get_device())
111 |             window = window.type_as(img1)
112 | 
113 |             self.window = window
114 |             self.channel = channel
115 |         if self.mode == 'all':
116 |             return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
117 |         elif self.mode == 'c_s':
118 |             return _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average)
119 |         else:
120 |             return _ssim_l(img1, img2, window, self.window_size, channel, self.size_average)
121 | 
122 | class DSSIM(torch.nn.Module):
123 |     def __init__(self, window_size = 11, size_average = True, mode='all'):
124 |         super(DSSIM, self).__init__()
125 |         self.window_size = window_size
126 |         self.size_average = size_average
127 |         self.channel = 1
128 |         self.window = create_window(window_size, self.channel)
129 |         self.mode = mode
130 |     def forward(self, img1, img2):
131 |         (_, channel, _, _) = img1.size()
132 | 
133 |         if channel == self.channel and self.window.data.type() == img1.data.type():
134 |             window = self.window
135 |         else:
136 |             window = create_window(self.window_size, channel)
137 | 
138 |             # if img1.is_cuda:
139 |             #     window = window.cuda(img1.get_device())
140 |             window = window.type_as(img1)
141 | 
142 |             self.window = window
143 |             self.channel = channel
144 |         if self.mode == 'all':
145 |             ssim_v = _ssim(img1, img2, window, self.window_size, channel, self.size_average)
146 |         elif self.mode == 'c_s':
147 |             ssim_v = _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average)
148 |         else:
149 |             ssim_v = _ssim_l(img1, img2, window, self.window_size, channel, self.size_average)
150 |         return (1-ssim_v)/2
--------------------------------------------------------------------------------
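A usage sketch for the mask-aware SSIM in evaluation/pytorch_ssim.py: when a mask is supplied, `_ssim` averages the SSIM map twice, once over the whole image and once over the masked foreground only, and returns both as Python floats. The tensors below are hypothetical stand-ins for a ground-truth image, a harmonized result, and a binary foreground mask; the script assumes it runs from the repository root with the dependencies in requirements.txt installed.

import torch
from evaluation.pytorch_ssim import ssim

real = torch.rand(1, 3, 256, 256)        # stand-in for the ground-truth image
harmonized = torch.rand(1, 3, 256, 256)  # stand-in for the network output
mask = (torch.rand(1, 1, 256, 256) > 0.5).float()  # binary foreground mask

# With a mask, the helper returns (whole-image SSIM, foreground-only SSIM).
whole, foreground = ssim(real, harmonized, window_size=11, mask=mask)
print(whole, foreground)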
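The lookup performed by models/__init__.py is purely name-based: '--model iih_base' imports models/iih_base_model.py and selects the BaseModel subclass whose lowercased name equals 'iihbasemodel'. A minimal sketch of that resolution (assuming the repository root is on PYTHONPATH):

from models import find_model_using_name

# 'iih_base' -> module models.iih_base_model -> class IIHBaseModel
model_class = find_model_using_name('iih_base')
print(model_class.__name__)  # IIHBaseModel; create_model(opt) would then instantiate it with the parsed options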
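Written out as a single formula, the generator objective that backward_G in models/iih_base_gd_model.py assembles piecewise is the weighted sum below; the last three terms are only active when the corresponding --loss_RH/--loss_IH/--loss_IS flags are set.

# loss_G = lambda_L1         * L1(harmonized, real)
#        + lambda_ifm        * DSSIM_c_s(ifm_mean, real_gray)
#        + lambda_R_gradient * [L1(dx(reflectance), dx(real)) + L1(dy(reflectance), dy(real))]   (--loss_RH)
#        + lambda_I_L2       * MSE(illumination, real)                                           (--loss_IH)
#        + lambda_I_smooth   * smoothness(illumination)                                          (--loss_IS)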
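Every model's forward() applies the same test-time compositing, comp*(1-mask) + harmonized*mask, so the background is guaranteed to pass through unchanged. A self-contained sketch with random stand-in tensors:

import torch

comp = torch.rand(1, 3, 256, 256)        # composite input image
harmonized = torch.rand(1, 3, 256, 256)  # network output
mask = (torch.rand(1, 1, 256, 256) > 0.5).float()  # binary foreground mask

# Background pixels come from comp, foreground pixels from the network output.
out = comp * (1 - mask) + harmonized * mask
assert torch.equal(out * (1 - mask), comp * (1 - mask))  # background untouched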
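The rank arithmetic in util/multiprocessing.py lays global ranks out machine-major: rank = shard_id * num_proc + local_rank, with world_size = num_proc * num_shards. A toy illustration (the 2-machine, 4-GPU-per-machine setup is hypothetical):

num_proc, num_shards = 4, 2            # 4 processes per machine, 2 machines
world_size = num_proc * num_shards     # 8 processes overall
for shard_id in range(num_shards):
    for local_rank in range(num_proc):
        rank = shard_id * num_proc + local_rank
        print(shard_id, local_rank, rank)
# machine 0 owns ranks 0-3, machine 1 owns ranks 4-7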
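Finally, the DSSIM module in util/ssim.py converts similarity into a minimizable loss via (1 - SSIM)/2, and mode='c_s' keeps only the contrast/structure part of SSIM (dropping the luminance term); this is the criterion the *_gd models apply to ifm_mean against the grayscale real image. A self-contained sanity check with hypothetical single-channel 32x32 inputs:

import torch
from util.ssim import DSSIM

criterion = DSSIM(mode='c_s')
a = torch.rand(1, 1, 32, 32)
b = torch.rand(1, 1, 32, 32)
print(criterion(a, b).item())  # a value in [0, 1]; lower means more similar structure
print(criterion(a, a).item())  # ~0.0: identical inputs incur no loss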