├── .gitignore
├── LICENSE.md
├── README.md
├── applications
│   ├── iih_enhancement
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── adobe5k_dataset.py
│   │   │   ├── base_dataset.py
│   │   │   ├── image_folder.py
│   │   │   ├── loader.py
│   │   │   └── test_dataset.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── base_model.py
│   │   │   ├── harmony_networks.py
│   │   │   ├── iih_base_gd_model.py
│   │   │   └── networks.py
│   │   ├── options
│   │   │   ├── __init__.py
│   │   │   ├── base_options.py
│   │   │   ├── test_options.py
│   │   │   └── train_options.py
│   │   ├── test.py
│   │   ├── train.py
│   │   ├── train_net.py
│   │   └── util
│   │       ├── distributed.py
│   │       ├── evaluation.py
│   │       ├── html.py
│   │       ├── misc.py
│   │       ├── multiprocessing.py
│   │       ├── ssim.py
│   │       ├── tools.py
│   │       ├── util.py
│   │       └── visualizer.py
│   ├── iih_mef
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── base_dataset.py
│   │   │   ├── image_folder.py
│   │   │   ├── loader.py
│   │   │   ├── mef_dataset.py
│   │   │   └── test_dataset.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── base_model.py
│   │   │   ├── harmony_networks.py
│   │   │   ├── iih_base_gd_model.py
│   │   │   └── networks.py
│   │   ├── options
│   │   │   ├── __init__.py
│   │   │   ├── base_options.py
│   │   │   ├── test_options.py
│   │   │   └── train_options.py
│   │   ├── test.py
│   │   ├── train.py
│   │   ├── train_net.py
│   │   └── util
│   │       ├── distributed.py
│   │       ├── evaluation.py
│   │       ├── html.py
│   │       ├── misc.py
│   │       ├── multiprocessing.py
│   │       ├── ssim.py
│   │       ├── tools.py
│   │       ├── util.py
│   │       └── visualizer.py
│   └── iih_relighting
│       ├── .gitignore
│       ├── README.md
│       ├── data
│       │   ├── __init__.py
│       │   ├── base_dataset.py
│       │   ├── dpr_dataset.py
│       │   ├── dprtransfer_dataset.py
│       │   ├── image_folder.py
│       │   ├── loader.py
│       │   └── test_dataset.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── base_model.py
│       │   ├── iih_base_lt_model.py
│       │   ├── networks.py
│       │   └── relighting_networks.py
│       ├── options
│       │   ├── __init__.py
│       │   ├── base_options.py
│       │   ├── test_options.py
│       │   └── train_options.py
│       ├── test.py
│       ├── train.py
│       ├── train_net.py
│       └── util
│           ├── distributed.py
│           ├── evaluation.py
│           ├── html.py
│           ├── misc.py
│           ├── multiprocessing.py
│           ├── ssim.py
│           ├── tools.py
│           ├── util.py
│           └── visualizer.py
├── data
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── ihd_dataset.py
│   ├── image_folder.py
│   ├── loader.py
│   └── real_dataset.py
├── evaluation
│   ├── ih_evaluation.py
│   └── pytorch_ssim.py
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── harmony_networks.py
│   ├── iih_base_gd_model.py
│   ├── iih_base_lt_gd_model.py
│   ├── iih_base_lt_model.py
│   ├── iih_base_model.py
│   └── networks.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── requirements.txt
├── test.py
├── train.py
├── train_net.py
└── util
    ├── distributed.py
    ├── html.py
    ├── misc.py
    ├── multiprocessing.py
    ├── ssim.py
    ├── tools.py
    ├── util.py
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | .vscode/*
132 | checkpoints/*
133 | results/*
134 | applications/iih_enhancement/evaluation/enhance_evaluation.py
135 | applications/iih_enhancement/evaluation/pytorch_ssim.py
136 | applications/iih_mef/evaluation.py
137 | applications/iih_relighting/evaluation/relighting.py
138 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 AI @ OUC
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/applications/iih_enhancement/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/*
2 | checkpoints/*
3 | results/*
4 | */__pycache__/*
5 | tmp/*
6 | __pycache__/distribute.cpython-37.pyc
7 | __pycache__/options.cpython-37.pyc
8 | __pycache__/train_net.cpython-37.pyc
9 | __pycache__/train.cpython-37.pyc
--------------------------------------------------------------------------------
/applications/iih_enhancement/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Image Enhancement
5 |
6 | Here we provide the PyTorch implementation and pre-trained model of our latest version.
7 |
8 | ## Prerequisites
9 |
10 | - Linux
11 | - Python 3
12 | - CPU or NVIDIA GPU + CUDA CuDNN
13 |
14 | ## Base Model with Guiding
15 | - Download MIT-Adobe-5K-UPE dataset.
16 |
17 | - Train
18 | ```bash
19 | CUDA_VISIBLE_DEVICES=0 python train.py --model iih_base_gd --name base_gd_adobe5k_test --dataset_root <dataset_dir> --dataset_name Adobe5k --batch_size xx --init_port xxxx
20 | ```
21 | - Test
22 | ```bash
23 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_gd --name base_gd_adobe5k_test --dataset_root <dataset_dir> --dataset_name Adobe5k --batch_size xx --init_port xxxx
24 | ```
25 | - Apply pre-trained model
26 |
27 | Download pre-trained model from [Google Drive](https://drive.google.com/file/d/1h9EG2kZnYi3GI4nAsqnJb1HHBv8GeNf7/view?usp=sharing) or [BaiduCloud](https://pan.baidu.com/s/1mhAxHjetfIvZv-O-kqeHTA) (access code: 0r0k), and put `latest_net_G.pth` in the directory `checkpoints/base_gd_enhancement`. Run:
28 | ```bash
29 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_gd --name base_gd_enhancement --dataset_root <dataset_dir> --dataset_name Adobe5k --batch_size xx --init_port xxxx
30 | ```
31 |
--------------------------------------------------------------------------------
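A filled-in version of the training command above, for orientation only; the dataset path, batch size, and port below are illustrative assumptions, not values from the original README:

```bash
# All values after the flags are placeholders; adjust to your setup.
# Note: adobe5k_dataset.py concatenates dataset_root + dataset_name + '_train.txt',
# so dataset_root needs a trailing slash.
CUDA_VISIBLE_DEVICES=0 python train.py --model iih_base_gd --name base_gd_adobe5k_test \
    --dataset_root ./datasets/Adobe5k/ --dataset_name Adobe5k --batch_size 8 --init_port 50000
```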
/applications/iih_enhancement/data/adobe5k_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 |
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <DatasetMode>Dataset.py
8 | You need to implement the following functions:
9 | -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 | -- <__init__>: Initialize this dataset class.
11 | -- <__getitem__>: Return a data point and its metadata information.
12 | -- <__len__>: Return the number of images.
13 | """
14 | import os.path
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | import torch.nn.functional as F
18 | from data.base_dataset import BaseDataset, get_transform
19 | from data.image_folder import make_dataset
20 | from PIL import Image
21 | import numpy as np
22 | import torchvision.transforms as transforms
23 | from util import util
24 |
25 | class Adobe5kDataset(BaseDataset):
26 | @staticmethod
27 | def modify_commandline_options(parser, is_train):
28 | parser.add_argument('--is_train', type=bool, default=True, help='whether in the training phase')
29 | parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0) # specify dataset-specific default values
30 | return parser
31 |
32 | def __init__(self, opt):
33 | """Initialize this dataset class.
34 |
35 | Parameters:
36 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
37 |
38 | A few things can be done here.
39 | - save the options (have been done in BaseDataset)
40 | - get image paths and meta information of the dataset.
41 | - define the image transformation.
42 | """
43 | # save the option and dataset root
44 | BaseDataset.__init__(self, opt)
45 | self.fake_image_paths = []
46 | self.image_paths = []
47 | self.isTrain = opt.isTrain
48 | self.image_size = opt.crop_size
49 |
50 | if opt.isTrain:
51 | #self.real_ext='.jpg'
52 | print('loading training file')
53 | self.trainfile = opt.dataset_root+opt.dataset_name+'_train.txt'
54 | with open(self.trainfile,'r') as f:
55 | for line in f.readlines():
56 | self.image_paths.append(os.path.join(opt.dataset_root,'UPEresize',line.rstrip()))
57 | else:
58 | #self.real_ext='.jpg'
59 | print('loading test file')
60 | self.trainfile = opt.dataset_root+opt.dataset_name+'_test.txt'
61 | with open(self.trainfile,'r') as f:
62 | for line in f.readlines():
63 | self.image_paths.append(os.path.join(opt.dataset_root,'test_set/input/',line.rstrip()))
64 | # get the image paths of your dataset;
65 | # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
66 | # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
67 | transform_list = [
68 | transforms.ToTensor(),
69 | transforms.Normalize((0, 0, 0), (1, 1, 1))
70 | ]
71 | self.transforms = transforms.Compose(transform_list)
72 |
73 | def __getitem__(self, index):
74 | """Return a data point and its metadata information.
75 |
76 | Parameters:
77 | index -- a random integer for data indexing
78 |
79 | Returns:
80 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
81 |
82 | Step 1: get a random image path: e.g., path = self.image_paths[index]
83 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
84 | Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
85 | Step 4: return a data point as a dictionary.
86 | """
87 | path = self.image_paths[index]
88 | name_parts=path.split('/')
89 | if self.isTrain:
90 | target_path = self.image_paths[index].replace(name_parts[-2],'Expert_C_resize')
91 | else:
92 | target_path = self.image_paths[index].replace('input','expertC_gt')
93 |
94 | comp = Image.open(path).convert('RGB')
95 | real = Image.open(target_path).convert('RGB')
96 |
97 | if np.random.rand() > 0.5 and self.isTrain:
98 | comp, real = tf.hflip(comp), tf.hflip(real)
99 | if comp.size[0] != self.image_size:
100 | # assert 0
101 | comp = tf.resize(comp, [self.image_size, self.image_size])
102 | real = tf.resize(real, [self.image_size,self.image_size])
103 |
104 | comp = self.transforms(comp)
105 | real = self.transforms(real)
106 | return {'fake': comp, 'real': real,'img_path':path}
107 |
108 | def __len__(self):
109 | """Return the total number of images."""
110 | return len(self.image_paths)
111 |
112 |
--------------------------------------------------------------------------------
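From `__init__` and `__getitem__` above, the on-disk layout this dataset class expects can be read off directly. A sketch; only the names that literally appear in the code are certain:

```
<dataset_root>/
├── Adobe5k_train.txt        # one image filename per line (train split)
├── Adobe5k_test.txt         # one image filename per line (test split)
├── UPEresize/               # training inputs
├── Expert_C_resize/         # training targets (found by path replacement)
└── test_set/
    ├── input/               # test inputs
    └── expertC_gt/          # test targets (found by replacing 'input')
```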
/applications/iih_enhancement/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 |
12 |
13 | class BaseDataset(data.Dataset, ABC):
14 | """This class is an abstract base class (ABC) for datasets.
15 |
16 | To create a subclass, you need to implement the following four functions:
17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 | -- <__len__>: return the size of dataset.
19 | -- <__getitem__>: get a data point.
20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 | """
22 |
23 | def __init__(self, opt):
24 | """Initialize the class; save the options in the class
25 |
26 | Parameters:
27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 | """
29 | self.opt = opt
30 | # self.root = opt.dataroot
31 | self.root = opt.dataset_root
32 |
33 | @staticmethod
34 | def modify_commandline_options(parser, is_train):
35 | """Add new dataset-specific options, and rewrite default values for existing options.
36 |
37 | Parameters:
38 | parser -- original option parser
39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40 |
41 | Returns:
42 | the modified parser.
43 | """
44 | return parser
45 |
46 | @abstractmethod
47 | def __len__(self):
48 | """Return the total number of images in the dataset."""
49 | return 0
50 |
51 | @abstractmethod
52 | def __getitem__(self, index):
53 | """Return a data point and its metadata information.
54 |
55 | Parameters:
56 | index -- a random integer for data indexing
57 |
58 | Returns:
59 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
60 | """
61 | pass
62 |
63 |
64 | def get_params(opt, size):
65 | w, h = size
66 | new_h = h
67 | new_w = w
68 | if opt.preprocess == 'resize_and_crop':
69 | new_h = new_w = opt.load_size
70 | elif opt.preprocess == 'scale_width_and_crop':
71 | new_w = opt.load_size
72 | new_h = opt.load_size * h // w
73 |
74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
76 |
77 | flip = random.random() > 0.5
78 |
79 | return {'crop_pos': (x, y), 'flip': flip}
80 |
81 |
82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
83 | transform_list = []
84 | if grayscale:
85 | transform_list.append(transforms.Grayscale(1))
86 | if 'resize' in opt.preprocess:
87 | osize = [opt.load_size, opt.load_size]
88 | transform_list.append(transforms.Resize(osize, method))
89 | elif 'scale_width' in opt.preprocess:
90 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
91 |
92 | if 'crop' in opt.preprocess:
93 | if params is None:
94 | transform_list.append(transforms.RandomCrop(opt.crop_size))
95 | else:
96 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
97 |
98 | if opt.preprocess == 'none':
99 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
100 |
101 | if not opt.no_flip:
102 | if params is None:
103 | transform_list.append(transforms.RandomHorizontalFlip())
104 | elif params['flip']:
105 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
106 |
107 | if convert:
108 | transform_list += [transforms.ToTensor()]
109 | if grayscale:
110 | transform_list += [transforms.Normalize((0.5,), (0.5,))]
111 | else:
112 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
113 | transform_list += [transforms.Normalize((0, 0, 0), (1, 1, 1))]
114 | return transforms.Compose(transform_list)
115 |
116 |
117 | def __make_power_2(img, base, method=Image.BICUBIC):
118 | ow, oh = img.size
119 | h = int(round(oh / base) * base)
120 | w = int(round(ow / base) * base)
121 | if (h == oh) and (w == ow):
122 | return img
123 |
124 | __print_size_warning(ow, oh, w, h)
125 | return img.resize((w, h), method)
126 |
127 |
128 | def __scale_width(img, target_width, method=Image.BICUBIC):
129 | ow, oh = img.size
130 | if (ow == target_width):
131 | return img
132 | w = target_width
133 | h = int(target_width * oh / ow)
134 | return img.resize((w, h), method)
135 |
136 |
137 | def __crop(img, pos, size):
138 | ow, oh = img.size
139 | x1, y1 = pos
140 | tw = th = size
141 | if (ow > tw or oh > th):
142 | return img.crop((x1, y1, x1 + tw, y1 + th))
143 | return img
144 |
145 |
146 | def __flip(img, flip):
147 | if flip:
148 | return img.transpose(Image.FLIP_LEFT_RIGHT)
149 | return img
150 |
151 |
152 | def __print_size_warning(ow, oh, w, h):
153 | """Print warning information about image size(only print once)"""
154 | if not hasattr(__print_size_warning, 'has_printed'):
155 | print("The image size needs to be a multiple of 4. "
156 | "The loaded image size was (%d, %d), so it was adjusted to "
157 | "(%d, %d). This adjustment will be done to all images "
158 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
159 | __print_size_warning.has_printed = True
160 |
--------------------------------------------------------------------------------
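A minimal sketch of how `get_params` and `get_transform` above are meant to be combined so that an aligned image pair receives identical random crop and flip decisions; `opt`, `comp`, and `real` are placeholders (a parsed options object and two aligned PIL images):

```python
from data.base_dataset import get_params, get_transform  # module above

params = get_params(opt, comp.size)         # sample crop position + flip once
comp_t = get_transform(opt, params)(comp)   # same params ->
real_t = get_transform(opt, params)(real)   # -> identical geometry on both
```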
/applications/iih_enhancement/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 |
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 | import os.path
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | ]
17 |
18 |
19 | def is_image_file(filename):
20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21 |
22 |
23 | def make_dataset(dir, max_dataset_size=float("inf")):
24 | images = []
25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
26 |
27 | for root, _, fnames in sorted(os.walk(dir)):
28 | for fname in fnames:
29 | if is_image_file(fname):
30 | path = os.path.join(root, fname)
31 | images.append(path)
32 | return images[:min(max_dataset_size, len(images))]
33 |
34 |
35 | def default_loader(path):
36 | return Image.open(path).convert('RGB')
37 |
38 |
39 | class ImageFolder(data.Dataset):
40 |
41 | def __init__(self, root, transform=None, return_paths=False,
42 | loader=default_loader):
43 | imgs = make_dataset(root)
44 | if len(imgs) == 0:
45 | raise(RuntimeError("Found 0 images in: " + root + "\n"
46 | "Supported image extensions are: " +
47 | ",".join(IMG_EXTENSIONS)))
48 |
49 | self.root = root
50 | self.imgs = imgs
51 | self.transform = transform
52 | self.return_paths = return_paths
53 | self.loader = loader
54 |
55 | def __getitem__(self, index):
56 | path = self.imgs[index]
57 | img = self.loader(path)
58 | if self.transform is not None:
59 | img = self.transform(img)
60 | if self.return_paths:
61 | return img, path
62 | else:
63 | return img
64 |
65 | def __len__(self):
66 | return len(self.imgs)
67 |
--------------------------------------------------------------------------------
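A small usage sketch for `make_dataset` above; the directory path is an assumption:

```python
from data.image_folder import make_dataset

# Recursively collect image paths under the directory, capped at 100 entries.
paths = sorted(make_dataset('./datasets/Adobe5k/UPEresize', max_dataset_size=100))
```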
/applications/iih_enhancement/data/loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | """Data loader."""
5 | import os
6 | import itertools
7 | import numpy as np
8 | import torch
9 | from torch.utils.data._utils.collate import default_collate
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torch.utils.data.sampler import RandomSampler
12 |
13 | from slowfast.datasets.multigrid_helper import ShortCycleBatchSampler
14 |
15 | from . import utils as utils
16 |
17 | def build_dataset(dataset_name, cfg, split):
18 | image_paths = []
19 | if cfg.phase == 'train':
20 | print('loading training file')
21 | file = cfg.dataset_root+cfg.dataset_name+'_train.txt'
22 | with open(file,'r') as f:
23 | for line in f.readlines():
24 | image_paths.append(os.path.join(cfg.dataset_root,'composite_images',line.rstrip()))
25 |
26 |
27 | def construct_loader(cfg, split, is_precise_bn=False):
28 | """
29 | Constructs the data loader for the given dataset.
30 | Args:
31 | cfg (CfgNode): configs. Details can be found in
32 | slowfast/config/defaults.py
33 | split (str): the split of the data loader. Options include `train`,
34 | `val`, and `test`.
35 | """
36 | assert split in ["train", "val", "test"]
37 | if split in ["train"]:
38 | dataset_name = cfg.TRAIN.DATASET
39 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
40 | shuffle = True
41 | drop_last = True
42 | elif split in ["val"]:
43 | dataset_name = cfg.TRAIN.DATASET
44 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
45 | shuffle = False
46 | drop_last = False
47 | elif split in ["test"]:
48 | dataset_name = cfg.TEST.DATASET
49 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
50 | shuffle = False
51 | drop_last = False
52 |
53 | # Construct the dataset
54 | dataset = build_dataset(dataset_name, cfg, split)
55 |
56 | if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn:
57 | # Create a sampler for multi-process training
58 | sampler = utils.create_sampler(dataset, shuffle, cfg)
59 | batch_sampler = ShortCycleBatchSampler(
60 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
61 | )
62 | # Create a loader
63 | loader = torch.utils.data.DataLoader(
64 | dataset,
65 | batch_sampler=batch_sampler,
66 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
67 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
68 | worker_init_fn=utils.loader_worker_init_fn(dataset),
69 | )
70 | else:
71 | # Create a sampler for multi-process training
72 | sampler = utils.create_sampler(dataset, shuffle, cfg)
73 | # Create a loader
74 | loader = torch.utils.data.DataLoader(
75 | dataset,
76 | batch_size=batch_size,
77 | shuffle=(False if sampler else shuffle),
78 | sampler=sampler,
79 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
80 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
81 | drop_last=drop_last,
82 | collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
83 | worker_init_fn=utils.loader_worker_init_fn(dataset),
84 | )
85 | return loader
86 |
87 |
88 | def shuffle_dataset(loader, cur_epoch):
89 | """ "
90 | Shuffles the data.
91 | Args:
92 | loader (loader): data loader to perform shuffle.
93 | cur_epoch (int): number of the current epoch.
94 | """
95 | sampler = (
96 | loader.batch_sampler.sampler
97 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
98 | else loader.sampler
99 | )
100 | assert isinstance(
101 | sampler, (RandomSampler, DistributedSampler)
102 | ), "Sampler type '{}' not supported".format(type(sampler))
103 | # RandomSampler handles shuffling automatically
104 | if isinstance(sampler, DistributedSampler):
105 | # DistributedSampler shuffles data based on epoch
106 | sampler.set_epoch(cur_epoch)
107 |
--------------------------------------------------------------------------------
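Why `shuffle_dataset` exists: a `DistributedSampler` only reshuffles when told the current epoch. A sketch of the intended call pattern; `train_loader` and `num_epochs` are placeholders:

```python
for epoch in range(num_epochs):
    # For DistributedSampler this calls sampler.set_epoch(epoch), which reseeds
    # the per-rank shuffle; a plain RandomSampler reshuffles on its own.
    shuffle_dataset(train_loader, epoch)
    for batch in train_loader:
        ...  # training step
```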
/applications/iih_enhancement/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 | -- <set_input>: unpack data from dataset and apply preprocessing.
7 | -- <forward>: produce intermediate results.
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 | In the file, the class called [ModelName]Model() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(0)
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 | This function wraps the model class selected by opt.model.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
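How the name lookup above resolves in this application, traced by hand (a sketch; `opt` is the parsed options object):

```python
# '--model iih_base_gd' -> imports models/iih_base_gd_model.py, then matches
# 'iih_base_gd'.replace('_', '') + 'model' == 'iihbasegdmodel'
# against each class's lowercased name -> IIHBaseGDModel is instantiated.
from models import create_model

model = create_model(opt)   # prints: model [IIHBaseGDModel] was created
```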
/applications/iih_enhancement/models/iih_base_gd_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import itertools
4 | import torch.nn.functional as F
5 | from util import distributed as du
6 | # import pytorch_colors as colors
7 | from .base_model import BaseModel
8 | from util import util
9 | from . import harmony_networks as networks
10 | import util.ssim as ssim
11 |
12 |
13 | class IIHBaseGDModel(BaseModel):
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train=True):
16 | parser.set_defaults(norm='instance', netG='base_gd', dataset_mode='adobe5k')
17 | if is_train:
18 | parser.add_argument('--lambda_L1', type=float, default=50.0, help='weight for L1 loss')
19 | parser.add_argument('--lambda_R', type=float, default=100., help='weight for R gradient loss')
20 | parser.add_argument('--lambda_ssim', type=float, default=50., help='weight for SSIM loss')
21 | parser.add_argument('--lambda_ifm', type=float, default=100, help='weight for the ifm guiding loss')
22 |
23 | return parser
24 |
25 | def __init__(self, opt):
26 | BaseModel.__init__(self, opt)
27 | self.opt = opt
28 | # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
29 | self.loss_names = ['G','G_L1','G_R','G_R_SSIM',"IF"]
30 |
31 | # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
32 | self.visual_names = ['enhanced','real','fake','reconstruct','illumination']
33 | # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
34 | self.model_names = ['G']
35 | self.opt.device = self.device
36 | self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
37 | self.cur_device = torch.cuda.current_device()
38 | self.ismaster = du.is_master_proc(opt.NUM_GPUS)
39 | print(self.netG)
40 |
41 | if self.isTrain:
42 | # if self.ismaster == 0:
43 | util.saveprint(self.opt, 'netG', str(self.netG))
44 | # define loss functions
45 | self.criterionL1 = torch.nn.L1Loss()
46 | self.criterionL2 = torch.nn.MSELoss()
47 | self.criterionSSIM = ssim.SSIM()
48 | self.criterionDSSIM_CS = ssim.DSSIM(mode='c_s').to(self.device)
49 | # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
50 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
51 | self.optimizers.append(self.optimizer_G)
52 |
53 | def set_input(self, input):
54 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
55 |
56 | Parameters:
57 | input (dict): include the data itself and its metadata information.
58 |
59 |
60 | """
61 | self.fake = input['fake'].to(self.device)
62 | self.real = input['real'].to(self.device)
63 | self.image_paths = input['img_path']
64 | self.real_r = F.interpolate(self.real, size=[32,32])
65 | self.real_gray = util.rgbtogray(self.real_r)
66 | def forward(self):
67 | self.reconstruct, self.enhanced, self.illumination, self.ifm_mean = self.netG(self.fake)
68 | def backward_G(self):
69 | self.loss_IF = self.criterionDSSIM_CS(self.ifm_mean, self.real_gray)*self.opt.lambda_ifm
70 |
71 | self.loss_G_L1 = self.criterionL1(self.reconstruct, self.fake)*self.opt.lambda_L1
72 | self.loss_G_R = self.criterionL2(self.enhanced, self.real)*self.opt.lambda_R
73 | self.loss_G_R_SSIM = (1-self.criterionSSIM(self.enhanced, self.real))*self.opt.lambda_ssim
74 | self.loss_G = self.loss_G_L1 + self.loss_G_R + self.loss_G_R_SSIM + self.loss_IF
75 | self.loss_G.backward()
76 |
77 | def optimize_parameters(self):
78 | self.forward() # compute fake images: G(A)
79 | # update G
80 | self.optimizer_G.zero_grad() # set G's gradients to zero
81 | self.backward_G() # calculate gradients for G
82 | self.optimizer_G.step() # update G's weights
--------------------------------------------------------------------------------
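The per-iteration flow and the total objective implied by the class above, summarized as a sketch; the names are the class's own, and `batch` is a placeholder for one dataloader item:

```python
model.set_input(batch)        # expects {'fake': ..., 'real': ..., 'img_path': ...}
model.optimize_parameters()   # forward() -> backward_G() -> optimizer_G.step()

# backward_G() minimizes:
#   loss_G = lambda_L1   * L1(reconstruct, fake)
#          + lambda_R    * MSE(enhanced, real)
#          + lambda_ssim * (1 - SSIM(enhanced, real))
#          + lambda_ifm  * DSSIM_cs(ifm_mean, real_gray)
```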
/applications/iih_enhancement/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
2 |
--------------------------------------------------------------------------------
/applications/iih_enhancement/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | """This class includes test options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 | # Dropout and Batchnorm have different behavior during training and test.
17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 | parser.add_argument('--test_epoch', type=str, default="0", help='which epoch to load for testing')
20 | # rewrite default values
21 | parser.set_defaults(model='test')
22 | # To avoid cropping, the load_size should be the same as crop_size
23 | parser.set_defaults(load_size=parser.get_default('crop_size'))
24 | self.isTrain = False
25 | return parser
26 |
--------------------------------------------------------------------------------
/applications/iih_enhancement/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | """This class includes training options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser)
12 | # visdom and HTML visualization parameters
13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
15 | parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
22 | # network saving and loading parameters
23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
24 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
25 | parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
28 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
29 | # training parameters
30 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
31 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
32 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
33 | parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
34 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
35 | parser.add_argument('--g_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of generator')
36 | parser.add_argument('--d_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of discriminator')
37 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
38 | parser.add_argument('--pool_size', type=int, default=40, help='the size of image buffer that stores previously generated images')
39 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
40 | parser.add_argument('--lr_decay_iters', type=int, default=40, help='multiply by a gamma every lr_decay_iters iterations')
41 | parser.add_argument('--save_iter_model', action='store_true', help='whether to save the model by iteration')
42 |
43 |
44 |
45 |
46 | self.isTrain = True
47 | return parser
48 |
--------------------------------------------------------------------------------
/applications/iih_enhancement/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import test
6 |
7 | from options.test_options import TestOptions
8 |
9 | def main():
10 | cfg = TestOptions().parse() # get training options
11 | cfg.NUM_GPUS = torch.cuda.device_count()
12 | cfg.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
13 | cfg.no_flip = True # no flip; comment this line if results on flipped images are needed.
14 | cfg.display_id = -1 # no visdom display; the test code saves the results to an HTML file.
15 |
16 | cfg.phase = 'test'
17 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
18 | launch_job(cfg=cfg, init_method=cfg.init_method, func=test)
19 |
20 |
21 | if __name__=="__main__":
22 | main()
--------------------------------------------------------------------------------
/applications/iih_enhancement/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import train
6 |
7 | from options.train_options import TrainOptions
8 |
9 | def main():
10 | cfg = TrainOptions().parse() # get training options
11 | cfg.NUM_GPUS = torch.cuda.device_count()
12 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
13 | cfg.phase = 'train'
14 | launch_job(cfg=cfg, init_method=cfg.init_method, func=train)
15 |
16 |
17 | if __name__=="__main__":
18 | main()
--------------------------------------------------------------------------------
/applications/iih_enhancement/util/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as f
3 |
4 |
5 | def evaluation(name, fake, real, mask):
6 | b, c, h, w = real.size()
7 | mse_score = f.mse_loss(fake, real)
8 | fore_area = torch.sum(mask)
9 | fmse_score = f.mse_loss(fake*mask, real*mask)*h*w/fore_area
10 | mse_score = mse_score.item()
11 | fmse_score = fmse_score.item()
12 | # score_str = "%s MSE %0.2f | fMSE %0.2f" % (name, mse_score,fmse_score)
13 | image_fmse_info = (name, round(fmse_score,2), round(mse_score, 2))
14 | return mse_score, fmse_score, image_fmse_info
--------------------------------------------------------------------------------
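A self-contained sketch of calling `evaluation` above; tensor shapes and values are illustrative:

```python
import torch
from util.evaluation import evaluation  # assumes the application's package layout

fake = torch.rand(1, 3, 256, 256)
real = torch.rand(1, 3, 256, 256)
mask = (torch.rand(1, 1, 256, 256) > 0.5).float()  # 1 = foreground region

# fMSE rescales the masked MSE by total pixels / foreground pixels.
mse, fmse, info = evaluation('example.jpg', fake, real, mask)
print(info)  # ('example.jpg', fMSE rounded to 2 places, MSE rounded to 2 places)
```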
/applications/iih_enhancement/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 |
5 |
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 | It consists of functions such as <add_header> (add a text header to the HTML file),
10 | <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 |
14 | def __init__(self, web_dir, title, refresh=0):
15 | """Initialize the HTML classes
16 |
17 | Parameters:
18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 | title (str) -- the webpage name
20 | refresh (int) -- how often the website refresh itself; if 0; no refreshing
21 | """
22 | self.title = title
23 | self.web_dir = web_dir
24 | self.img_dir = os.path.join(self.web_dir, 'images')
25 | if not os.path.exists(self.web_dir):
26 | os.makedirs(self.web_dir)
27 | if not os.path.exists(self.img_dir):
28 | os.makedirs(self.img_dir)
29 |
30 | self.doc = dominate.document(title=title)
31 | if refresh > 0:
32 | with self.doc.head:
33 | meta(http_equiv="refresh", content=str(refresh))
34 |
35 | def get_image_dir(self):
36 | """Return the directory that stores images"""
37 | return self.img_dir
38 |
39 | def add_header(self, text):
40 | """Insert a header to the HTML file
41 |
42 | Parameters:
43 | text (str) -- the header text
44 | """
45 | with self.doc:
46 | h3(text)
47 |
48 | def add_images(self, ims, txts, links, width=400):
49 | """add images to the HTML file
50 |
51 | Parameters:
52 | ims (str list) -- a list of image paths
53 | txts (str list) -- a list of image names shown on the website
54 | links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page
55 | """
56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57 | self.doc.add(self.t)
58 | with self.t:
59 | with tr():
60 | for im, txt, link in zip(ims, txts, links):
61 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 | with p():
63 | with a(href=os.path.join('images', link)):
64 | img(style="width:%dpx" % width, src=os.path.join('images', im))
65 | br()
66 | p(txt)
67 |
68 | def save(self):
69 | """save the current content to the HMTL file"""
70 | html_file = '%s/index.html' % self.web_dir
71 | f = open(html_file, 'wt')
72 | f.write(self.doc.render())
73 | f.close()
74 |
75 |
76 | if __name__ == '__main__': # we show an example usage here.
77 | html = HTML('web/', 'test_html')
78 | html.add_header('hello world')
79 |
80 | ims, txts, links = [], [], []
81 | for n in range(4):
82 | ims.append('image_%d.png' % n)
83 | txts.append('text_%d' % n)
84 | links.append('image_%d.png' % n)
85 | html.add_images(ims, txts, links)
86 | html.save()
87 |
--------------------------------------------------------------------------------
/applications/iih_enhancement/util/multiprocessing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | """Multiprocessing helpers."""
5 |
6 | import torch
7 |
8 |
9 | def run(
10 | local_rank,
11 | num_proc,
12 | func,
13 | init_method,
14 | shard_id,
15 | num_shards,
16 | backend,
17 | cfg,
18 | output_queue=None,
19 | ):
20 | """
21 | Runs a function from a child process.
22 | Args:
23 | local_rank (int): rank of the current process on the current machine.
24 | num_proc (int): number of processes per machine.
25 | func (function): function to execute on each of the process.
26 | init_method (string): method to initialize the distributed training.
27 | TCP initialization: requiring a network address reachable from all
28 | processes followed by the port.
29 | Shared file-system initialization: makes use of a file system that
30 | is shared and visible from all machines. The URL should start with
31 | file:// and contain a path to a non-existent file on a shared file
32 | system.
33 | shard_id (int): the rank of the current machine.
34 | num_shards (int): number of overall machines for the distributed
35 | training job.
36 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
37 | supported, each with different capabilities. Details can be found
38 | here:
39 | https://pytorch.org/docs/stable/distributed.html
40 | cfg (CfgNode): configs. Details can be found in
41 | slowfast/config/defaults.py
42 | output_queue (queue): can optionally be used to return values from the
43 | master process.
44 | """
45 | # Initialize the process group.
46 | world_size = num_proc * num_shards
47 | rank = shard_id * num_proc + local_rank
48 | try:
49 | torch.distributed.init_process_group(
50 | backend=backend,
51 | init_method=init_method,
52 | world_size=world_size,
53 | rank=rank,
54 | )
55 |
56 | except Exception as e:
57 | raise e
58 |
59 | torch.cuda.set_device(local_rank)
60 | ret = func(cfg)
61 | if output_queue is not None and local_rank == 0:
62 | output_queue.put(ret)
63 |
--------------------------------------------------------------------------------
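A sketch of how `run` above is typically dispatched with `torch.multiprocessing.spawn` for single-machine training; the actual launcher in this repo is `launch_job` in `util/misc.py` (not shown), and `train` and `cfg` are placeholders for an entry point taking the options object and the parsed options themselves:

```python
import torch
from util.multiprocessing import run  # the module above

num_gpus = torch.cuda.device_count()
# spawn() calls run(local_rank, *args) once per child process;
# the port and backend below are illustrative.
torch.multiprocessing.spawn(
    run,
    nprocs=num_gpus,
    args=(num_gpus,                  # num_proc: processes on this machine
          train,                     # func: entry point taking cfg
          'tcp://localhost:9999',    # init_method
          0,                         # shard_id: this machine's rank
          1,                         # num_shards: total machines
          'nccl',                    # backend
          cfg),                      # experiment options
)
```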
/applications/iih_enhancement/util/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if size_average:
35 | return ssim_map.mean()
36 | else:
37 | return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | def _ssim_c_s(img1, img2, window, window_size, channel, size_average = True):
40 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
41 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
42 |
43 | mu1_sq = mu1.pow(2)
44 | mu2_sq = mu2.pow(2)
45 | mu1_mu2 = mu1*mu2
46 |
47 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
48 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
49 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
50 |
51 | C1 = 0.01**2
52 | C2 = 0.03**2
53 |
54 | ssim_map = ((2*sigma12 + C2))/((sigma1_sq + sigma2_sq + C2))
55 |
56 | if size_average:
57 | return ssim_map.mean()
58 | else:
59 | return ssim_map.mean(1).mean(1).mean(1)
60 | def _ssim_l(img1, img2, window, window_size, channel, size_average = True):
61 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
62 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
63 |
64 | mu1_sq = mu1.pow(2)
65 | mu2_sq = mu2.pow(2)
66 | mu1_mu2 = mu1*mu2
67 |
68 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
69 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
70 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
71 |
72 | C1 = 0.01**2
73 | C2 = 0.03**2
74 |
75 | ssim_map = (2*mu1_mu2 + C1)/(mu1_sq + mu2_sq + C1)
76 |
77 | if size_average:
78 | return ssim_map.mean()
79 | else:
80 | return ssim_map.mean(1).mean(1).mean(1)
81 |
82 |
83 | def ssim(img1, img2, window_size = 11, size_average = True):
84 | (_, channel, _, _) = img1.size()
85 | window = create_window(window_size, channel)
86 |
87 | if img1.is_cuda:
88 | window = window.cuda(img1.get_device())
89 | window = window.type_as(img1)
90 |
91 | return _ssim(img1, img2, window, window_size, channel, size_average)
92 |
93 | class SSIM(torch.nn.Module):
94 | def __init__(self, window_size = 11, size_average = True, mode='all'):
95 | super(SSIM, self).__init__()
96 | self.window_size = window_size
97 | self.size_average = size_average
98 | self.channel = 1
99 | self.window = create_window(window_size, self.channel)
100 | self.mode = mode
101 | def forward(self, img1, img2):
102 | (_, channel, _, _) = img1.size()
103 |
104 | if channel == self.channel and self.window.data.type() == img1.data.type():
105 | window = self.window
106 | else:
107 | window = create_window(self.window_size, channel)
108 |
109 | # if img1.is_cuda:
110 | # window = window.cuda(img1.get_device())
111 | window = window.type_as(img1)
112 |
113 | self.window = window
114 | self.channel = channel
115 | if self.mode == 'all':
116 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
117 | elif self.mode == 'c_s':
118 | return _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average)
119 | else:
120 | return _ssim_l(img1, img2, window, self.window_size, channel, self.size_average)
121 |
122 | class DSSIM(torch.nn.Module):
123 | def __init__(self, window_size = 11, size_average = True, mode='all'):
124 | super(DSSIM, self).__init__()
125 | self.window_size = window_size
126 | self.size_average = size_average
127 | self.channel = 1
128 | self.window = create_window(window_size, self.channel)
129 | self.mode = mode
130 | def forward(self, img1, img2):
131 | (_, channel, _, _) = img1.size()
132 |
133 | if channel == self.channel and self.window.data.type() == img1.data.type():
134 | window = self.window
135 | else:
136 | window = create_window(self.window_size, channel)
137 |
138 | # if img1.is_cuda:
139 | # window = window.cuda(img1.get_device())
140 | window = window.type_as(img1)
141 |
142 | self.window = window
143 | self.channel = channel
144 | if self.mode == 'all':
145 | ssim_v = _ssim(img1, img2, window, self.window_size, channel, self.size_average)
146 | elif self.mode == 'c_s':
147 | ssim_v = _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average)
148 | else:
149 | ssim_v = _ssim_l(img1, img2, window, self.window_size, channel, self.size_average)
150 | return (1-ssim_v)/2
--------------------------------------------------------------------------------
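A short sketch of the three entry points above; inputs are illustrative random tensors:

```python
import torch
from util.ssim import ssim, SSIM, DSSIM  # assumes the application's package layout

a = torch.rand(1, 3, 64, 64)
b = torch.rand(1, 3, 64, 64)

print(ssim(a, b))               # functional form: full SSIM, scalar mean
print(SSIM(mode='c_s')(a, b))   # contrast/structure term only
print(DSSIM(mode='c_s')(a, b))  # (1 - SSIM_cs) / 2, used for the model's loss_IF
```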
/applications/iih_mef/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/*
2 | checkpoints/*
3 | results/*
4 | */__pycache__/*
5 | tmp/*
6 | __pycache__/distribute.cpython-37.pyc
7 | __pycache__/options.cpython-37.pyc
8 | __pycache__/train_net.cpython-37.pyc
9 | __pycache__/train.cpython-37.pyc
--------------------------------------------------------------------------------
/applications/iih_mef/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Multi-Exposure Image Fusion
5 |
6 | Here we provide the PyTorch implementation and pre-trained model of our latest version.
7 |
8 | ## Prerequisites
9 |
10 | - Linux
11 | - Python 3
12 | - CPU or NVIDIA GPU + CUDA CuDNN
13 |
14 | ## Base Model with Guiding
15 | - Download SICE dataset.
16 |
17 | - Train
18 | ```bash
19 | CUDA_VISIBLE_DEVICES=0 python train.py --model iih_base_gd --name base_gd_sice_test --dataset_root <dataset_dir> --dataset_name mef --batch_size xx --init_port xxxx
20 | ```
21 | - Test
22 | ```bash
23 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_gd --name base_gd_sice_test --dataset_root <dataset_dir> --dataset_name mef --batch_size xx --init_port xxxx
24 | ```
25 | - Apply pre-trained model
26 |
27 | Download pre-trained model from [Google Drive](https://drive.google.com/file/d/17SIkVhRFW5LTuX2PXDPkVw2IwKWDpO-B/view?usp=sharing) or [BaiduCloud](https://pan.baidu.com/s/1V4ulhcC1eqM6EfVbxRIz1g) (access code: 15vn), and put `latest_net_G.pth` in the directory `checkpoints/base_gd_mef`. Run:
28 | ```bash
29 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_gd --name base_gd_mef --dataset_root <dataset_dir> --dataset_name mef --batch_size xx --init_port xxxx
30 | ```
--------------------------------------------------------------------------------
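A filled-in version of the training command above, for orientation only; the dataset path, batch size, and port are illustrative assumptions, and the trailing slash assumes the mef dataset follows the same `dataset_root + dataset_name + '_train.txt'` convention used by the other applications:

```bash
# All values after the flags are placeholders; adjust to your setup.
CUDA_VISIBLE_DEVICES=0 python train.py --model iih_base_gd --name base_gd_sice_test \
    --dataset_root ./datasets/SICE/ --dataset_name mef --batch_size 8 --init_port 50000
```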
/applications/iih_mef/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 | from torch.utils.data.distributed import DistributedSampler
17 | from torch.utils.data.sampler import RandomSampler
18 |
19 |
20 | def find_dataset_using_name(dataset_name):
21 | """Import the module "data/[dataset_name]_dataset.py".
22 |
23 | In the file, the class called DatasetNameDataset() will
24 | be instantiated. It has to be a subclass of BaseDataset,
25 | and it is case-insensitive.
26 | """
27 | dataset_filename = "data." + dataset_name + "_dataset"
28 | datasetlib = importlib.import_module(dataset_filename)
29 |
30 | dataset = None
31 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
32 | for name, cls in datasetlib.__dict__.items():
33 | if name.lower() == target_dataset_name.lower() \
34 | and issubclass(cls, BaseDataset):
35 | dataset = cls
36 |
37 | if dataset is None:
38 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
39 |
40 | return dataset
41 |
42 |
43 | def get_option_setter(dataset_name):
44 | """Return the static method of the dataset class."""
45 | dataset_class = find_dataset_using_name(dataset_name)
46 | return dataset_class.modify_commandline_options
47 |
48 |
49 | def create_dataset(opt):
50 | """Create a dataset given the option.
51 |
52 | This function creates the dataset and wraps it in a torch DataLoader.
53 | This is the main interface between this package and 'train.py'/'test.py'
54 |
55 | Example:
56 | >>> from data import create_dataset
57 | >>> dataset = create_dataset(opt)
58 | """
59 | # data_loader = CustomDatasetDataLoader(opt)
60 | # dataset = data_loader.load_data()
61 |
62 | dataset_class = find_dataset_using_name(opt.dataset_mode)
63 | dataset = dataset_class(opt)
64 | print("dataset [%s] was created" % type(dataset).__name__)
65 |
66 | # batch_size = int(opt.batch_size / max(1, opt.NUM_GPUS))
67 | if opt.isTrain:
68 | shuffle = True
69 | drop_last = True
70 | else:
71 | shuffle = False
72 | drop_last = False
73 |
74 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) if opt.NUM_GPUS > 1 else None
75 |
76 | # Create a loader
77 | dataloader = torch.utils.data.DataLoader(
78 | dataset,
79 | batch_size=opt.batch_size,
80 | shuffle=(False if sampler else shuffle),
81 | sampler=sampler,
82 | num_workers=int(opt.num_threads),
83 | drop_last=drop_last,
84 | pin_memory=True,
85 | )
86 | return dataloader
87 |
88 |
89 | class CustomDatasetDataLoader():
90 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
91 |
92 | def __init__(self, opt):
93 | """Initialize this class
94 |
95 | Step 1: create a dataset instance given the name [dataset_mode]
96 | Step 2: create a multi-threaded data loader.
97 | """
98 | self.opt = opt
99 | dataset_class = find_dataset_using_name(opt.dataset_mode)
100 | self.dataset = dataset_class(opt)
101 | print("dataset [%s] was created" % type(self.dataset).__name__)
102 |
103 | batch_size = int(opt.batch_size / max(1, opt.NUM_GPUS))
104 | if opt.isTrain:
105 | shuffle = True
106 | drop_last = True
107 | else:
108 | shuffle = False
109 | drop_last = False
110 |
111 | self.sampler = torch.utils.data.distributed.DistributedSampler(self.dataset) if opt.NUM_GPUS > 1 else None
112 |
113 | # Create a loader
114 | self.dataloader = torch.utils.data.DataLoader(
115 | self.dataset,
116 | batch_size=batch_size,
117 | shuffle=(False if self.sampler else shuffle),
118 | sampler=self.sampler,
119 | num_workers=int(opt.num_threads),
120 | drop_last=drop_last,
121 | )
122 |
123 | # self.dataloader = torch.utils.data.DataLoader(
124 | # self.dataset,
125 | # batch_size=opt.batch_size,
126 | # shuffle=not opt.serial_batches,
127 | # num_workers=int(opt.num_threads))
128 |
129 | def load_data(self):
130 | return self
131 |
132 | def __len__(self):
133 | """Return the number of data in the dataset"""
134 | return min(len(self.dataset), self.opt.max_dataset_size)
135 |
136 | # def __iter__(self):
137 | # """Return a batch of data"""
138 | # for i, data in enumerate(self.dataloader):
139 | # if i * self.opt.batch_size >= self.opt.max_dataset_size:
140 | # break
141 | # yield data
142 |
143 | def shuffle_dataset(loader, cur_epoch):
144 | """ "
145 | Shuffles the data.
146 | Args:
147 | loader (loader): data loader to perform shuffle.
148 | cur_epoch (int): number of the current epoch.
149 | """
150 | # sampler = (
151 | # loader.batch_sampler.sampler
152 | # if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
153 | # else loader.sampler
154 | # )
155 | sampler = loader.sampler
156 | assert isinstance(
157 | sampler, (RandomSampler, DistributedSampler)
158 | ), "Sampler type '{}' not supported".format(type(sampler))
159 | # RandomSampler handles shuffling automatically
160 | if isinstance(sampler, DistributedSampler):
161 | # DistributedSampler shuffles data based on epoch
162 | sampler.set_epoch(cur_epoch)
163 |
164 |
165 |
166 |
167 |
168 |
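169 | # A minimal custom dataset following the module docstring above (a hypothetical
170 | # sketch, not part of this repository), selected with '--dataset_mode dummy':
171 | #
172 | # from data.image_folder import make_dataset
173 | #
174 | # class DummyDataset(BaseDataset):
175 | #     def __init__(self, opt):
176 | #         BaseDataset.__init__(self, opt)
177 | #         # collect all image paths under the dataset root
178 | #         self.image_paths = sorted(make_dataset(opt.dataset_root))
179 | #
180 | #     def __len__(self):
181 | #         return len(self.image_paths)
182 | #
183 | #     def __getitem__(self, index):
184 | #         return {'img_path': self.image_paths[index]}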
--------------------------------------------------------------------------------
/applications/iih_mef/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 |
12 |
13 | class BaseDataset(data.Dataset, ABC):
14 | """This class is an abstract base class (ABC) for datasets.
15 |
16 | To create a subclass, you need to implement the following four functions:
17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 | -- <__len__>: return the size of dataset.
19 | -- <__getitem__>: get a data point.
20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 | """
22 |
23 | def __init__(self, opt):
24 | """Initialize the class; save the options in the class
25 |
26 | Parameters:
27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 | """
29 | self.opt = opt
30 | # self.root = opt.dataroot
31 | self.root = opt.dataset_root
32 |
33 | @staticmethod
34 | def modify_commandline_options(parser, is_train):
35 | """Add new dataset-specific options, and rewrite default values for existing options.
36 |
37 | Parameters:
38 | parser -- original option parser
39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40 |
41 | Returns:
42 | the modified parser.
43 | """
44 | return parser
45 |
46 | @abstractmethod
47 | def __len__(self):
48 | """Return the total number of images in the dataset."""
49 | return 0
50 |
51 | @abstractmethod
52 | def __getitem__(self, index):
53 | """Return a data point and its metadata information.
54 |
55 | Parameters:
56 | index - - a random integer for data indexing
57 |
58 | Returns:
59 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
60 | """
61 | pass
62 |
63 |
64 | def get_params(opt, size):
65 | w, h = size
66 | new_h = h
67 | new_w = w
68 | if opt.preprocess == 'resize_and_crop':
69 | new_h = new_w = opt.load_size
70 | elif opt.preprocess == 'scale_width_and_crop':
71 | new_w = opt.load_size
72 | new_h = opt.load_size * h // w
73 |
74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
76 |
77 | flip = random.random() > 0.5
78 |
79 | return {'crop_pos': (x, y), 'flip': flip}
80 |
81 |
82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
83 | transform_list = []
84 | if grayscale:
85 | transform_list.append(transforms.Grayscale(1))
86 | if 'resize' in opt.preprocess:
87 | osize = [opt.load_size, opt.load_size]
88 | transform_list.append(transforms.Resize(osize, method))
89 | elif 'scale_width' in opt.preprocess:
90 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
91 |
92 | if 'crop' in opt.preprocess:
93 | if params is None:
94 | transform_list.append(transforms.RandomCrop(opt.crop_size))
95 | else:
96 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
97 |
98 | if opt.preprocess == 'none':
99 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
100 |
101 | if not opt.no_flip:
102 | if params is None:
103 | transform_list.append(transforms.RandomHorizontalFlip())
104 | elif params['flip']:
105 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
106 |
107 | if convert:
108 | transform_list += [transforms.ToTensor()]
109 | if grayscale:
110 | transform_list += [transforms.Normalize((0.5,), (0.5,))]
111 | else:
112 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
113 | transform_list += [transforms.Normalize((0, 0, 0), (1, 1, 1))]
114 | return transforms.Compose(transform_list)
115 |
116 |
117 | def __make_power_2(img, base, method=Image.BICUBIC):
118 | ow, oh = img.size
119 | h = int(round(oh / base) * base)
120 | w = int(round(ow / base) * base)
121 | if (h == oh) and (w == ow):
122 | return img
123 |
124 | __print_size_warning(ow, oh, w, h)
125 | return img.resize((w, h), method)
126 |
127 |
128 | def __scale_width(img, target_width, method=Image.BICUBIC):
129 | ow, oh = img.size
130 | if (ow == target_width):
131 | return img
132 | w = target_width
133 | h = int(target_width * oh / ow)
134 | return img.resize((w, h), method)
135 |
136 |
137 | def __crop(img, pos, size):
138 | ow, oh = img.size
139 | x1, y1 = pos
140 | tw = th = size
141 | if (ow > tw or oh > th):
142 | return img.crop((x1, y1, x1 + tw, y1 + th))
143 | return img
144 |
145 |
146 | def __flip(img, flip):
147 | if flip:
148 | return img.transpose(Image.FLIP_LEFT_RIGHT)
149 | return img
150 |
151 |
152 | def __print_size_warning(ow, oh, w, h):
153 | """Print warning information about image size(only print once)"""
154 | if not hasattr(__print_size_warning, 'has_printed'):
155 | print("The image size needs to be a multiple of 4. "
156 | "The loaded image size was (%d, %d), so it was adjusted to "
157 | "(%d, %d). This adjustment will be done to all images "
158 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
159 | __print_size_warning.has_printed = True
160 |
--------------------------------------------------------------------------------
/applications/iih_mef/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 |
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 | import os.path
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | ]
17 |
18 |
19 | def is_image_file(filename):
20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21 |
22 |
23 | def make_dataset(dir, max_dataset_size=float("inf")):
24 | images = []
25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
26 |
27 | for root, _, fnames in sorted(os.walk(dir)):
28 | for fname in fnames:
29 | if is_image_file(fname):
30 | path = os.path.join(root, fname)
31 | images.append(path)
32 | return images[:min(max_dataset_size, len(images))]
33 |
34 |
35 | def default_loader(path):
36 | return Image.open(path).convert('RGB')
37 |
38 |
39 | class ImageFolder(data.Dataset):
40 |
41 | def __init__(self, root, transform=None, return_paths=False,
42 | loader=default_loader):
43 | imgs = make_dataset(root)
44 | if len(imgs) == 0:
45 | raise(RuntimeError("Found 0 images in: " + root + "\n"
46 | "Supported image extensions are: " +
47 | ",".join(IMG_EXTENSIONS)))
48 |
49 | self.root = root
50 | self.imgs = imgs
51 | self.transform = transform
52 | self.return_paths = return_paths
53 | self.loader = loader
54 |
55 | def __getitem__(self, index):
56 | path = self.imgs[index]
57 | img = self.loader(path)
58 | if self.transform is not None:
59 | img = self.transform(img)
60 | if self.return_paths:
61 | return img, path
62 | else:
63 | return img
64 |
65 | def __len__(self):
66 | return len(self.imgs)
67 |
--------------------------------------------------------------------------------
/applications/iih_mef/data/loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | """Data loader."""
5 | import itertools
6 | import os  # used by build_dataset below
7 | import numpy as np
8 | import torch
9 | from torch.utils.data._utils.collate import default_collate
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torch.utils.data.sampler import RandomSampler
12 |
13 | from slowfast.datasets.multigrid_helper import ShortCycleBatchSampler
14 |
15 | from . import utils as utils
16 |
17 | def build_dataset(cfg):
18 | image_paths = []
19 | if cfg.phase == 'train':
20 | print('loading training file')
21 | file = cfg.dataset_root+cfg.dataset_name+'_train.txt'
22 | with open(file,'r') as f:
23 | for line in f.readlines():
24 | image_paths.append(os.path.join(cfg.dataset_root,'composite_images',line.rstrip()))
25 |
26 |
27 | def construct_loader(cfg, split, is_precise_bn=False):
28 | """
29 | Constructs the data loader for the given dataset.
30 | Args:
31 | cfg (CfgNode): configs. Details can be found in
32 | slowfast/config/defaults.py
33 | split (str): the split of the data loader. Options include `train`,
34 | `val`, and `test`.
35 | """
36 | assert split in ["train", "val", "test"]
37 | if split in ["train"]:
38 | dataset_name = cfg.TRAIN.DATASET
39 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
40 | shuffle = True
41 | drop_last = True
42 | elif split in ["val"]:
43 | dataset_name = cfg.TRAIN.DATASET
44 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
45 | shuffle = False
46 | drop_last = False
47 | elif split in ["test"]:
48 | dataset_name = cfg.TEST.DATASET
49 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
50 | shuffle = False
51 | drop_last = False
52 |
53 | # Construct the dataset
54 | dataset = build_dataset(dataset_name, cfg, split)
55 |
56 | if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn:
57 | # Create a sampler for multi-process training
58 | sampler = utils.create_sampler(dataset, shuffle, cfg)
59 | batch_sampler = ShortCycleBatchSampler(
60 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
61 | )
62 | # Create a loader
63 | loader = torch.utils.data.DataLoader(
64 | dataset,
65 | batch_sampler=batch_sampler,
66 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
67 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
68 | worker_init_fn=utils.loader_worker_init_fn(dataset),
69 | )
70 | else:
71 | # Create a sampler for multi-process training
72 | sampler = utils.create_sampler(dataset, shuffle, cfg)
73 | # Create a loader
74 | loader = torch.utils.data.DataLoader(
75 | dataset,
76 | batch_size=batch_size,
77 | shuffle=(False if sampler else shuffle),
78 | sampler=sampler,
79 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
80 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
81 | drop_last=drop_last,
82 | collate_fn=detection_collate if cfg.DETECTION.ENABLE else None,
83 | worker_init_fn=utils.loader_worker_init_fn(dataset),
84 | )
85 | return loader
86 |
87 |
88 | def shuffle_dataset(loader, cur_epoch):
89 | """ "
90 | Shuffles the data.
91 | Args:
92 | loader (loader): data loader to perform shuffle.
93 | cur_epoch (int): number of the current epoch.
94 | """
95 | sampler = (
96 | loader.batch_sampler.sampler
97 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
98 | else loader.sampler
99 | )
100 | assert isinstance(
101 | sampler, (RandomSampler, DistributedSampler)
102 | ), "Sampler type '{}' not supported".format(type(sampler))
103 | # RandomSampler handles shuffling automatically
104 | if isinstance(sampler, DistributedSampler):
105 | # DistributedSampler shuffles data based on epoch
106 | sampler.set_epoch(cur_epoch)
107 |
--------------------------------------------------------------------------------
/applications/iih_mef/data/mef_dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import os.path
3 | import torch
4 | import torchvision.transforms.functional as tf
5 | import torch.nn.functional as F
6 | import random
7 | from torchvision.transforms.transforms import RandomCrop, RandomResizedCrop
8 | from data.base_dataset import BaseDataset, get_transform
9 | from data.image_folder import make_dataset
10 | from PIL import Image
11 | import numpy as np
12 | import torchvision.transforms as transforms
13 | from util import util
14 |
15 | class MefDataset(BaseDataset):
16 | @staticmethod
17 | def modify_commandline_options(parser, is_train):
18 | parser.add_argument('--is_train', type=bool, default=True, help='whether in the training phase')
19 | parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0) # specify dataset-specific default values
20 | return parser
21 |
22 | def __init__(self, opt):
23 | # save the option and dataset root
24 | BaseDataset.__init__(self, opt)
25 | self.fake_image_paths = []
26 | self.image_paths = []
27 | self.isTrain = opt.isTrain
28 | self.image_size = opt.crop_size
29 |
30 | if opt.isTrain:
31 | print('loading training file')
32 | self.trainfile = opt.dataset_root+'Dataset_Part1_resize/'+'part1_train.txt'
33 | with open(self.trainfile,'r') as f:
34 | for line in f.readlines():
35 | name = line.rstrip().split('.')
36 | self.image_paths.append(os.path.join(opt.dataset_root,'Dataset_Part1_resize/',name[0]))
37 | self.trainfile = opt.dataset_root+'Dataset_Part2_resize/'+'part2_train.txt'
38 | with open(self.trainfile,'r') as f:
39 | for line in f.readlines():
40 | name = line.rstrip().split('.')
41 | self.image_paths.append(os.path.join(opt.dataset_root,'Dataset_Part2_resize/',name[0]))
42 | else:
43 | print('loading test file')
44 | self.trainfile = opt.dataset_root+'Dataset_Part1_resize/'+'part1_test.txt'
45 | with open(self.trainfile,'r') as f:
46 | for line in f.readlines():
47 | name = line.rstrip().split('.')
48 | self.image_paths.append(os.path.join(opt.dataset_root,'Dataset_Part1_resize/',name[0]))
49 | self.trainfile = opt.dataset_root+'Dataset_Part2_resize/'+'part2_test.txt'
50 | with open(self.trainfile,'r') as f:
51 | for line in f.readlines():
52 | name = line.rstrip().split('.')
53 | self.image_paths.append(os.path.join(opt.dataset_root,'Dataset_Part2_resize/',name[0]))
54 | transform_list = [
55 | # transforms.RandomCrop(self.image_size),
56 | transforms.ToTensor(),
57 | transforms.Normalize((0, 0, 0), (1, 1, 1))
58 | ]
59 | self.transforms = transforms.Compose(transform_list)
60 | def __getitem__(self, index):
61 | path = self.image_paths[index]
62 | files = os.listdir(path)
63 |
64 | files.sort(key=lambda x: int(x[:-4]))
65 | # frames are named by exposure index, so after sorting files[0] is the most
66 | # under-exposed and files[-1] the most over-exposed frame; train and test
67 | # use the same pair
68 | max_file = files[-1]
69 | min_file = files[0]
70 | 
71 |
72 | u_path = os.path.join(path,min_file)
73 | o_path = os.path.join(path,max_file)
74 | file_name_path = path+".JPG"
75 | name_parts=file_name_path.split('/')
76 | target_path = file_name_path.replace(name_parts[-1],'Label/'+name_parts[-1])
77 | if not os.path.exists(target_path):
78 | target_path = target_path.replace(".JPG",".PNG")
79 | fake_u = Image.open(u_path).convert('RGB')
80 | fake_o = Image.open(o_path).convert('RGB')
81 | real = Image.open(target_path).convert('RGB')
82 | if np.random.rand() > 0.5 and self.isTrain:
83 | fake_u, fake_o, real = tf.hflip(fake_u), tf.hflip(fake_o), tf.hflip(real)
84 | fake_u = self.transforms(fake_u)
85 | fake_o = self.transforms(fake_o)
86 | real = self.transforms(real)
87 |
88 | return {'fake_u': fake_u, 'fake_o': fake_o, 'real': real,'img_path':path}
89 |
90 | def __len__(self):
91 | """Return the total number of images."""
92 | return len(self.image_paths)
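93 | 
94 | # Assumed directory layout, as implied by the path handling above:
95 | #   <dataset_root>/Dataset_Part1_resize/part1_train.txt    (and part1_test.txt) list scene folders
96 | #   <dataset_root>/Dataset_Part1_resize/<scene>/<k>.JPG    exposure sequence, sorted by integer filename
97 | #   <dataset_root>/Dataset_Part1_resize/Label/<scene>.JPG  fused ground truth (falls back to .PNG)
98 | # and likewise for Dataset_Part2_resize.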
--------------------------------------------------------------------------------
/applications/iih_mef/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 | -- <set_input>: unpack data from dataset and apply preprocessing.
7 | -- <forward>: produce intermediate results.
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for a usage example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 | In the file, the class called DatasetNameModel() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(0)
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 | This function wraps the model class.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
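69 | # Typical call sequence from the training script (a sketch; 'setup' comes from
70 | # BaseModel and opt from options.train_options.TrainOptions().parse()):
71 | #   model = create_model(opt)            # e.g. opt.model == 'iih_base_gd'
72 | #   model.setup(opt)                     # load/print networks, build schedulers
73 | #   model.set_input(data)                # unpack a batch from the dataloader
74 | #   model.optimize_parameters()          # forward + backward + optimizer step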
--------------------------------------------------------------------------------
/applications/iih_mef/models/iih_base_gd_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import itertools
4 | import torch.nn.functional as F
5 | from util import distributed as du
6 | # import pytorch_colors as colors
7 | from .base_model import BaseModel
8 | from util import util
9 | from . import harmony_networks as networks
10 | import util.ssim as ssim
11 |
12 |
13 | class IIHBaseGDModel(BaseModel):
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train=True):
16 | parser.set_defaults(norm='instance', netG='base_gd', dataset_mode='mef')
17 | if is_train:
18 | parser.add_argument('--lambda_L1', type=float, default=50.0, help='weight for L1 loss')
19 | parser.add_argument('--lambda_R', type=float, default=100., help='weight for R gradient loss')
20 | parser.add_argument('--lambda_ssim', type=float, default=50., help='weight for SSIM loss')
21 | parser.add_argument('--lambda_ifm', type=float, default=100., help='weight for IF (fusion map) loss')
22 |
23 | return parser
24 |
25 | def __init__(self, opt):
26 | BaseModel.__init__(self, opt)
27 | self.opt = opt
28 | # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
29 | self.loss_names = ['G', 'G_L1', 'G_R', 'G_R_SSIM', 'IF']
30 |
31 | self.visual_names = ['hdr','real','fake_u','fake_o']
32 | self.model_names = ['G']
33 | self.opt.device = self.device
34 | self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
35 | self.cur_device = torch.cuda.current_device()
36 | self.ismaster = du.is_master_proc(opt.NUM_GPUS)
37 | print(self.netG)
38 |
39 | if self.isTrain:
40 | util.saveprint(self.opt, 'netG', str(self.netG))
41 | # define loss functions
42 | self.criterionL1 = torch.nn.L1Loss()
43 | self.criterionL2 = torch.nn.MSELoss()
44 | self.criterionSSIM = ssim.SSIM()
45 | self.criterionDSSIM_CS = ssim.DSSIM(mode='c_s').to(self.device)
46 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
47 | self.optimizers.append(self.optimizer_G)
48 |
49 | def set_input(self, input):
50 | self.fake_u = input['fake_u'].to(self.device)
51 | self.fake_o = input['fake_o'].to(self.device)
52 | self.real = input['real'].to(self.device)
53 | self.image_paths = input['img_path']
54 | self.real_r = F.interpolate(self.real, size=[32,32])
55 | self.real_gray = util.rgbtogray(self.real_r)
56 | def forward(self):
57 | self.reconstruct_u, self.reconstruct_o, self.hdr, self.ifm_mean = self.netG(self.fake_u, self.fake_o)
58 | def backward_G(self):
59 | self.loss_IF = (self.criterionDSSIM_CS(self.ifm_mean, self.real_gray))*self.opt.lambda_ifm
60 |
61 | self.loss_G_L1 = (self.criterionL1(self.reconstruct_u, self.fake_u)+self.criterionL1(self.reconstruct_o, self.fake_o))*self.opt.lambda_L1
62 | self.loss_G_R = self.criterionL2(self.hdr, self.real)*self.opt.lambda_R
63 | self.loss_G_R_SSIM = (1-self.criterionSSIM(self.hdr, self.real))*self.opt.lambda_ssim
64 | self.loss_G = self.loss_G_L1 + self.loss_G_R + self.loss_G_R_SSIM + self.loss_IF
65 | self.loss_G.backward()
66 |
67 | def optimize_parameters(self):
68 | self.forward() # compute fake images: G(A)
69 | # update G
70 | self.optimizer_G.zero_grad() # set G's gradients to zero
71 | self.backward_G() # calculate gradients for G
72 | self.optimizer_G.step() # update G's weights
73 |
--------------------------------------------------------------------------------
/applications/iih_mef/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
2 |
--------------------------------------------------------------------------------
/applications/iih_mef/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | """This class includes test options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 | # Dropout and Batchnorm have different behavior during training and test.
17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 | parser.add_argument('--test_epoch', type=str, default="0", help='which epoch to load for testing')
20 | # rewrite default values
21 | parser.set_defaults(model='test')
22 | # To avoid cropping, the load_size should be the same as crop_size
23 | parser.set_defaults(load_size=parser.get_default('crop_size'))
24 | self.isTrain = False
25 | return parser
26 |
--------------------------------------------------------------------------------
/applications/iih_mef/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | """This class includes training options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser)
12 | # visdom and HTML visualization parameters
13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
15 | parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
22 | # network saving and loading parameters
23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
24 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
25 | parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
28 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
29 | # training parameters
30 | parser.add_argument('--niter', type=int, default=300, help='# of iter at starting learning rate')
31 | parser.add_argument('--niter_decay', type=int, default=300, help='# of iter to linearly decay learning rate to zero')
32 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
33 | parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
34 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
35 | parser.add_argument('--g_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of generator')
36 | parser.add_argument('--d_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of discriminator')
37 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
38 | parser.add_argument('--pool_size', type=int, default=40, help='the size of image buffer that stores previously generated images')
39 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
40 | parser.add_argument('--lr_decay_iters', type=int, default=40, help='multiply by a gamma every lr_decay_iters iterations')
41 | parser.add_argument('--save_iter_model', action='store_true', help='whether to save the model by iteration')
42 |
43 |
44 |
45 |
46 | self.isTrain = True
47 | return parser
48 |
--------------------------------------------------------------------------------
/applications/iih_mef/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import test
6 |
7 | from options.test_options import TestOptions
8 |
9 | def main():
10 | cfg = TestOptions().parse() # get training options
11 | cfg.NUM_GPUS = torch.cuda.device_count()
12 | cfg.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
13 | cfg.no_flip = True # no flip; comment this line if results on flipped images are needed.
14 | cfg.display_id = -1 # no visdom display; the test code saves the results to an HTML file.
15 |
16 | cfg.phase = 'test'
17 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
18 | launch_job(cfg=cfg, init_method=cfg.init_method, func=test)
19 |
20 |
21 | if __name__=="__main__":
22 | main()
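23 | 
24 | # Note: launch_job spawns one process per visible GPU and each process builds
25 | # its own DataLoader, so --batch_size is divided by NUM_GPUS here and acts as
26 | # the total batch size across all GPUs.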
--------------------------------------------------------------------------------
/applications/iih_mef/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import train
6 |
7 | from options.train_options import TrainOptions
8 |
9 | def main():
10 | cfg = TrainOptions().parse() # get training options
11 | cfg.NUM_GPUS = torch.cuda.device_count()
12 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
13 | cfg.phase = 'train'
14 | launch_job(cfg=cfg, init_method=cfg.init_method, func=train)
15 |
16 |
17 | if __name__=="__main__":
18 | main()
--------------------------------------------------------------------------------
/applications/iih_mef/util/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as f
3 |
4 |
5 | def evaluation(name, fake, real, mask):
6 | b, c, h, w = real.size()
7 | mse_score = f.mse_loss(fake, real)
8 | fore_area = torch.sum(mask)
9 | fmse_score = f.mse_loss(fake * mask, real * mask) * h * w / fore_area
10 | mse_score = mse_score.item()
11 | fmse_score = fmse_score.item()
12 | # score_str = "%s MSE %0.2f | fMSE %0.2f" % (name, mse_score,fmse_score)
13 | image_fmse_info = (name, round(fmse_score,2), round(mse_score, 2))
14 | return mse_score, fmse_score, image_fmse_info
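15 | 
16 | # fMSE rationale: f.mse_loss(fake * mask, real * mask) still averages over all
17 | # h*w pixels (background terms are zero), so scaling by h*w / sum(mask)
18 | # renormalizes the error to the number of foreground pixels.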
--------------------------------------------------------------------------------
/applications/iih_mef/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 |
5 |
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 | It consists of functions such as <add_header> (add a text header to the HTML file),
10 | <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 | It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 |
14 | def __init__(self, web_dir, title, refresh=0):
15 | """Initialize the HTML classes
16 |
17 | Parameters:
18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 | title (str)   -- the webpage name
20 | refresh (int) -- how often the website refresh itself; if 0; no refreshing
21 | """
22 | self.title = title
23 | self.web_dir = web_dir
24 | self.img_dir = os.path.join(self.web_dir, 'images')
25 | if not os.path.exists(self.web_dir):
26 | os.makedirs(self.web_dir)
27 | if not os.path.exists(self.img_dir):
28 | os.makedirs(self.img_dir)
29 | 
30 | self.doc = dominate.document(title=title)
31 | if refresh > 0:
32 | with self.doc.head:
33 | meta(http_equiv="refresh", content=str(refresh))
34 |
35 | def get_image_dir(self):
36 | """Return the directory that stores images"""
37 | return self.img_dir
38 |
39 | def add_header(self, text):
40 | """Insert a header to the HTML file
41 |
42 | Parameters:
43 | text (str) -- the header text
44 | """
45 | with self.doc:
46 | h3(text)
47 |
48 | def add_images(self, ims, txts, links, width=400):
49 | """add images to the HTML file
50 |
51 | Parameters:
52 | ims (str list) -- a list of image paths
53 | txts (str list) -- a list of image names shown on the website
54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 | """
56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57 | self.doc.add(self.t)
58 | with self.t:
59 | with tr():
60 | for im, txt, link in zip(ims, txts, links):
61 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 | with p():
63 | with a(href=os.path.join('images', link)):
64 | img(style="width:%dpx" % width, src=os.path.join('images', im))
65 | br()
66 | p(txt)
67 |
68 | def save(self):
69 | """save the current content to the HMTL file"""
70 | html_file = '%s/index.html' % self.web_dir
71 | f = open(html_file, 'wt')
72 | f.write(self.doc.render())
73 | f.close()
74 |
75 |
76 | if __name__ == '__main__': # we show an example usage here.
77 | html = HTML('web/', 'test_html')
78 | html.add_header('hello world')
79 |
80 | ims, txts, links = [], [], []
81 | for n in range(4):
82 | ims.append('image_%d.png' % n)
83 | txts.append('text_%d' % n)
84 | links.append('image_%d.png' % n)
85 | html.add_images(ims, txts, links)
86 | html.save()
87 |
--------------------------------------------------------------------------------
/applications/iih_mef/util/multiprocessing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | """Multiprocessing helpers."""
5 |
6 | import torch
7 |
8 |
9 | def run(
10 | local_rank,
11 | num_proc,
12 | func,
13 | init_method,
14 | shard_id,
15 | num_shards,
16 | backend,
17 | cfg,
18 | output_queue=None,
19 | ):
20 | """
21 | Runs a function from a child process.
22 | Args:
23 | local_rank (int): rank of the current process on the current machine.
24 | num_proc (int): number of processes per machine.
25 | func (function): function to execute on each of the process.
26 | init_method (string): method to initialize the distributed training.
27 | TCP initialization: requiring a network address reachable from all
28 | processes followed by the port.
29 | Shared file-system initialization: makes use of a file system that
30 | is shared and visible from all machines. The URL should start with
31 | file:// and contain a path to a non-existent file on a shared file
32 | system.
33 | shard_id (int): the rank of the current machine.
34 | num_shards (int): number of overall machines for the distributed
35 | training job.
36 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
37 | supported, each with different capabilities. Details can be found
38 | here:
39 | https://pytorch.org/docs/stable/distributed.html
40 | cfg (CfgNode): configs. Details can be found in
41 | slowfast/config/defaults.py
42 | output_queue (queue): can optionally be used to return values from the
43 | master process.
44 | """
45 | # Initialize the process group.
46 | world_size = num_proc * num_shards
47 | rank = shard_id * num_proc + local_rank
48 | try:
49 | torch.distributed.init_process_group(
50 | backend=backend,
51 | init_method=init_method,
52 | world_size=world_size,
53 | rank=rank,
54 | )
55 |
56 | except Exception as e:
57 | raise e
58 |
59 | torch.cuda.set_device(local_rank)
60 | ret = func(cfg)
61 | if output_queue is not None and local_rank == 0:
62 | output_queue.put(ret)
63 |
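64 | 
65 | # run() is meant to be the target of torch.multiprocessing.spawn; a sketch of
66 | # a single-machine launch (shard_id=0, num_shards=1, NCCL backend):
67 | #   torch.multiprocessing.spawn(
68 | #       run, nprocs=cfg.NUM_GPUS,
69 | #       args=(cfg.NUM_GPUS, func, init_method, 0, 1, 'nccl', cfg),
70 | #   )
71 | # spawn passes each child its process index, which run() receives as local_rank.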
--------------------------------------------------------------------------------
/applications/iih_mef/util/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if size_average:
35 | return ssim_map.mean()
36 | else:
37 | return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | def _ssim_c_s(img1, img2, window, window_size, channel, size_average = True):
40 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
41 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
42 |
43 | mu1_sq = mu1.pow(2)
44 | mu2_sq = mu2.pow(2)
45 | mu1_mu2 = mu1*mu2
46 |
47 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
48 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
49 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
50 |
51 | C1 = 0.01**2
52 | C2 = 0.03**2
53 |
54 | ssim_map = ((2*sigma12 + C2))/((sigma1_sq + sigma2_sq + C2))
55 |
56 | if size_average:
57 | return ssim_map.mean()
58 | else:
59 | return ssim_map.mean(1).mean(1).mean(1)
60 | def _ssim_l(img1, img2, window, window_size, channel, size_average = True):
61 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
62 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
63 |
64 | mu1_sq = mu1.pow(2)
65 | mu2_sq = mu2.pow(2)
66 | mu1_mu2 = mu1*mu2
67 |
68 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
69 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
70 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
71 |
72 | C1 = 0.01**2
73 | C2 = 0.03**2
74 |
75 | ssim_map = (2*mu1_mu2 + C1)/(mu1_sq + mu2_sq + C1)
76 |
77 | if size_average:
78 | return ssim_map.mean()
79 | else:
80 | return ssim_map.mean(1).mean(1).mean(1)
81 |
82 |
83 | def ssim(img1, img2, window_size = 11, size_average = True):
84 | (_, channel, _, _) = img1.size()
85 | window = create_window(window_size, channel)
86 |
87 | if img1.is_cuda:
88 | window = window.cuda(img1.get_device())
89 | window = window.type_as(img1)
90 |
91 | return _ssim(img1, img2, window, window_size, channel, size_average)
92 |
93 | class SSIM(torch.nn.Module):
94 | def __init__(self, window_size = 11, size_average = True, mode='all'):
95 | super(SSIM, self).__init__()
96 | self.window_size = window_size
97 | self.size_average = size_average
98 | self.channel = 1
99 | self.window = create_window(window_size, self.channel)
100 | self.mode = mode
101 | def forward(self, img1, img2):
102 | (_, channel, _, _) = img1.size()
103 |
104 | if channel == self.channel and self.window.data.type() == img1.data.type():
105 | window = self.window
106 | else:
107 | window = create_window(self.window_size, channel)
108 |
109 | # if img1.is_cuda:
110 | # window = window.cuda(img1.get_device())
111 | window = window.type_as(img1)
112 |
113 | self.window = window
114 | self.channel = channel
115 | if self.mode == 'all':
116 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
117 | elif self.mode == 'c_s':
118 | return _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average)
119 | else:
120 | return _ssim_l(img1, img2, window, self.window_size, channel, self.size_average)
121 |
122 | class DSSIM(torch.nn.Module):
123 | def __init__(self, window_size = 11, size_average = True, mode='all'):
124 | super(DSSIM, self).__init__()
125 | self.window_size = window_size
126 | self.size_average = size_average
127 | self.channel = 1
128 | self.window = create_window(window_size, self.channel)
129 | self.mode = mode
130 | def forward(self, img1, img2):
131 | (_, channel, _, _) = img1.size()
132 |
133 | if channel == self.channel and self.window.data.type() == img1.data.type():
134 | window = self.window
135 | else:
136 | window = create_window(self.window_size, channel)
137 |
138 | # if img1.is_cuda:
139 | # window = window.cuda(img1.get_device())
140 | window = window.type_as(img1)
141 |
142 | self.window = window
143 | self.channel = channel
144 | if self.mode == 'all':
145 | ssim_v = _ssim(img1, img2, window, self.window_size, channel, self.size_average)
146 | elif self.mode == 'c_s':
147 | ssim_v = _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average)
148 | else:
149 | ssim_v = _ssim_l(img1, img2, window, self.window_size, channel, self.size_average)
150 | return (1-ssim_v)/2
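151 | 
152 | # Reference for the variants above:
153 | #   _ssim computes SSIM(x, y) = ((2*mu_x*mu_y + C1) * (2*sigma_xy + C2))
154 | #                             / ((mu_x^2 + mu_y^2 + C1) * (sigma_x^2 + sigma_y^2 + C2));
155 | #   _ssim_l keeps only the luminance term (2*mu_x*mu_y + C1) / (mu_x^2 + mu_y^2 + C1);
156 | #   _ssim_c_s keeps only the contrast/structure term (2*sigma_xy + C2) / (sigma_x^2 + sigma_y^2 + C2);
157 | #   DSSIM converts similarity into a distance via (1 - SSIM) / 2.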
--------------------------------------------------------------------------------
/applications/iih_relighting/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/*
2 | checkpoints/*
3 | results/*
4 | */__pycache__/*
5 | tmp/*
6 | __pycache__/distribute.cpython-37.pyc
7 | __pycache__/options.cpython-37.pyc
8 | __pycache__/train_net.cpython-37.pyc
9 | __pycache__/train.cpython-37.pyc
--------------------------------------------------------------------------------
/applications/iih_relighting/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Portrait Relighting
5 |
6 | Here we provide the PyTorch implementation and pre-trained model of our latest version.
7 |
8 | ## Prerequisites
9 |
10 | - Linux
11 | - Python 3
12 | - CPU or NVIDIA GPU + CUDA CuDNN
13 |
14 | ## Base Model with Lighting
15 | - Download DPR dataset.
16 |
17 | - Train
18 | ```bash
19 | CUDA_VISIBLE_DEVICES=0 python train.py --model iih_base_lt --name base_lt_relighting_test --dataset_root --dataset_name DPR --batch_size xx --init_port xxxx
20 | ```
21 | - Test
22 | ```bash
23 | # SH-based relighting
24 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_lt --name base_lt_relighting_test --dataset_root --dataset_name DPR --batch_size xx --init_port xxxx
25 | # Image-based relighting
26 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_lt --name base_lt_relighting_test --relighting_action transfer --dataset_root --dataset_name DPR --dataset_mode dprtransfer --batch_size xx --init_port xxxx
27 | ```
28 |
29 | - Apply pre-trained model
30 |
31 | Download the pre-trained model from [Google Drive](https://drive.google.com/file/d/11yGZvo-gLDRyfnO0A6xuqPmDaPcMB1en/view?usp=sharing) or [BaiduCloud](https://pan.baidu.com/s/1yrUZ2YkT2bY9ThfYn_gJAg) (access code: bjqb), put `latest_net_G.pth` in the directory `checkpoints/base_lt_relighting`, and run:
32 |
33 | ```bash
34 | # SH-based relighting
35 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_lt --name base_lt_relighting --dataset_root --dataset_name DPR --batch_size xx --init_port xxxx
36 | # Image-based relighting
37 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_lt --name base_lt_relighting --relighting_action transfer --dataset_root --dataset_name DPR --dataset_mode dprtransfer --batch_size xx --init_port xxxx
38 | ```
39 |
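40 | For example, with the DPR data under `./datasets/DPR/` (the path, batch size, and port are illustrative placeholders):
41 | 
42 | ```bash
43 | # example values only; adjust the dataset path, batch size, and port to your setup
44 | CUDA_VISIBLE_DEVICES=0 python test.py --model iih_base_lt --name base_lt_relighting --dataset_root ./datasets/DPR/ --dataset_name DPR --batch_size 1 --init_port 50002
45 | ```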
--------------------------------------------------------------------------------
/applications/iih_relighting/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 |
12 |
13 | class BaseDataset(data.Dataset, ABC):
14 | """This class is an abstract base class (ABC) for datasets.
15 |
16 | To create a subclass, you need to implement the following four functions:
17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 | -- <__len__>: return the size of dataset.
19 | -- <__getitem__>: get a data point.
20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 | """
22 |
23 | def __init__(self, opt):
24 | """Initialize the class; save the options in the class
25 |
26 | Parameters:
27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 | """
29 | self.opt = opt
30 | # self.root = opt.dataroot
31 | self.root = opt.dataset_root
32 |
33 | @staticmethod
34 | def modify_commandline_options(parser, is_train):
35 | """Add new dataset-specific options, and rewrite default values for existing options.
36 |
37 | Parameters:
38 | parser -- original option parser
39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40 |
41 | Returns:
42 | the modified parser.
43 | """
44 | return parser
45 |
46 | @abstractmethod
47 | def __len__(self):
48 | """Return the total number of images in the dataset."""
49 | return 0
50 |
51 | @abstractmethod
52 | def __getitem__(self, index):
53 | """Return a data point and its metadata information.
54 |
55 | Parameters:
56 | index - - a random integer for data indexing
57 |
58 | Returns:
59 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
60 | """
61 | pass
62 |
63 |
64 | def get_params(opt, size):
65 | w, h = size
66 | new_h = h
67 | new_w = w
68 | if opt.preprocess == 'resize_and_crop':
69 | new_h = new_w = opt.load_size
70 | elif opt.preprocess == 'scale_width_and_crop':
71 | new_w = opt.load_size
72 | new_h = opt.load_size * h // w
73 |
74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
76 |
77 | flip = random.random() > 0.5
78 |
79 | return {'crop_pos': (x, y), 'flip': flip}
80 |
81 |
82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
83 | transform_list = []
84 | if grayscale:
85 | transform_list.append(transforms.Grayscale(1))
86 | if 'resize' in opt.preprocess:
87 | osize = [opt.load_size, opt.load_size]
88 | transform_list.append(transforms.Resize(osize, method))
89 | elif 'scale_width' in opt.preprocess:
90 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
91 |
92 | if 'crop' in opt.preprocess:
93 | if params is None:
94 | transform_list.append(transforms.RandomCrop(opt.crop_size))
95 | else:
96 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
97 |
98 | if opt.preprocess == 'none':
99 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
100 |
101 | if not opt.no_flip:
102 | if params is None:
103 | transform_list.append(transforms.RandomHorizontalFlip())
104 | elif params['flip']:
105 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
106 |
107 | if convert:
108 | transform_list += [transforms.ToTensor()]
109 | if grayscale:
110 | transform_list += [transforms.Normalize((0.5,), (0.5,))]
111 | else:
112 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
113 | transform_list += [transforms.Normalize((0, 0, 0), (1, 1, 1))]
114 | return transforms.Compose(transform_list)
115 |
116 |
117 | def __make_power_2(img, base, method=Image.BICUBIC):
118 | ow, oh = img.size
119 | h = int(round(oh / base) * base)
120 | w = int(round(ow / base) * base)
121 | if (h == oh) and (w == ow):
122 | return img
123 |
124 | __print_size_warning(ow, oh, w, h)
125 | return img.resize((w, h), method)
126 |
127 |
128 | def __scale_width(img, target_width, method=Image.BICUBIC):
129 | ow, oh = img.size
130 | if (ow == target_width):
131 | return img
132 | w = target_width
133 | h = int(target_width * oh / ow)
134 | return img.resize((w, h), method)
135 |
136 |
137 | def __crop(img, pos, size):
138 | ow, oh = img.size
139 | x1, y1 = pos
140 | tw = th = size
141 | if (ow > tw or oh > th):
142 | return img.crop((x1, y1, x1 + tw, y1 + th))
143 | return img
144 |
145 |
146 | def __flip(img, flip):
147 | if flip:
148 | return img.transpose(Image.FLIP_LEFT_RIGHT)
149 | return img
150 |
151 |
152 | def __print_size_warning(ow, oh, w, h):
153 | """Print warning information about image size(only print once)"""
154 | if not hasattr(__print_size_warning, 'has_printed'):
155 | print("The image size needs to be a multiple of 4. "
156 | "The loaded image size was (%d, %d), so it was adjusted to "
157 | "(%d, %d). This adjustment will be done to all images "
158 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
159 | __print_size_warning.has_printed = True
160 |
--------------------------------------------------------------------------------
/applications/iih_relighting/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 |
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 | import os.path
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | ]
17 |
18 |
19 | def is_image_file(filename):
20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21 |
22 |
23 | def make_dataset(dir, max_dataset_size=float("inf")):
24 | images = []
25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
26 |
27 | for root, _, fnames in sorted(os.walk(dir)):
28 | for fname in fnames:
29 | if is_image_file(fname):
30 | path = os.path.join(root, fname)
31 | images.append(path)
32 | return images[:min(max_dataset_size, len(images))]
33 |
34 |
35 | def default_loader(path):
36 | return Image.open(path).convert('RGB')
37 |
38 |
39 | class ImageFolder(data.Dataset):
40 |
41 | def __init__(self, root, transform=None, return_paths=False,
42 | loader=default_loader):
43 | imgs = make_dataset(root)
44 | if len(imgs) == 0:
45 | raise(RuntimeError("Found 0 images in: " + root + "\n"
46 | "Supported image extensions are: " +
47 | ",".join(IMG_EXTENSIONS)))
48 |
49 | self.root = root
50 | self.imgs = imgs
51 | self.transform = transform
52 | self.return_paths = return_paths
53 | self.loader = loader
54 |
55 | def __getitem__(self, index):
56 | path = self.imgs[index]
57 | img = self.loader(path)
58 | if self.transform is not None:
59 | img = self.transform(img)
60 | if self.return_paths:
61 | return img, path
62 | else:
63 | return img
64 |
65 | def __len__(self):
66 | return len(self.imgs)
67 |
--------------------------------------------------------------------------------
/applications/iih_relighting/data/loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | """Data loader."""
5 | import itertools
6 | import os  # used by build_dataset below
7 | import numpy as np
8 | import torch
9 | from torch.utils.data._utils.collate import default_collate
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torch.utils.data.sampler import RandomSampler
12 |
13 | from slowfast.datasets.multigrid_helper import ShortCycleBatchSampler
14 |
15 | from . import utils as utils
16 |
17 | def build_dataset(cfg):
18 | image_paths = []
19 | if cfg.phase == 'train':
20 | print('loading training file')
21 | file = cfg.dataset_root+cfg.dataset_name+'_train.txt'
22 | with open(file,'r') as f:
23 | for line in f.readlines():
24 | image_paths.append(os.path.join(cfg.dataset_root,'composite_images',line.rstrip()))
25 |
26 |
27 | def construct_loader(cfg, split, is_precise_bn=False):
28 | """
29 | Constructs the data loader for the given dataset.
30 | Args:
31 | cfg (CfgNode): configs. Details can be found in
32 | slowfast/config/defaults.py
33 | split (str): the split of the data loader. Options include `train`,
34 | `val`, and `test`.
35 | """
36 | assert split in ["train", "val", "test"]
37 | if split in ["train"]:
38 | dataset_name = cfg.TRAIN.DATASET
39 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
40 | shuffle = True
41 | drop_last = True
42 | elif split in ["val"]:
43 | dataset_name = cfg.TRAIN.DATASET
44 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
45 | shuffle = False
46 | drop_last = False
47 | elif split in ["test"]:
48 | dataset_name = cfg.TEST.DATASET
49 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
50 | shuffle = False
51 | drop_last = False
52 |
53 | # Construct the dataset
54 | dataset = build_dataset(dataset_name, cfg, split)
55 |
56 | if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn:
57 | # Create a sampler for multi-process training
58 | sampler = utils.create_sampler(dataset, shuffle, cfg)
59 | batch_sampler = ShortCycleBatchSampler(
60 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
61 | )
62 | # Create a loader
63 | loader = torch.utils.data.DataLoader(
64 | dataset,
65 | batch_sampler=batch_sampler,
66 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
67 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
68 | worker_init_fn=utils.loader_worker_init_fn(dataset),
69 | )
70 | else:
71 | # Create a sampler for multi-process training
72 | sampler = utils.create_sampler(dataset, shuffle, cfg)
73 | # Create a loader
74 | loader = torch.utils.data.DataLoader(
75 | dataset,
76 | batch_size=batch_size,
77 | shuffle=(False if sampler else shuffle),
78 | sampler=sampler,
79 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
80 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
81 | drop_last=drop_last,
82 | collate_fn=None, # slowfast's detection_collate and cfg.DETECTION are not available in this repo
83 | worker_init_fn=utils.loader_worker_init_fn(dataset),
84 | )
85 | return loader
86 |
87 |
88 | def shuffle_dataset(loader, cur_epoch):
89 | """
90 | Shuffles the data.
91 | Args:
92 | loader (loader): data loader to perform shuffle.
93 | cur_epoch (int): number of the current epoch.
94 | """
95 | sampler = (
96 | loader.batch_sampler.sampler
97 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
98 | else loader.sampler
99 | )
100 | assert isinstance(
101 | sampler, (RandomSampler, DistributedSampler)
102 | ), "Sampler type '{}' not supported".format(type(sampler))
103 | # RandomSampler handles shuffling automatically
104 | if isinstance(sampler, DistributedSampler):
105 | # DistributedSampler shuffles data based on epoch
106 | sampler.set_epoch(cur_epoch)
107 |
--------------------------------------------------------------------------------
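The shuffle_dataset() helper above encodes the standard per-epoch shuffling contract: with a DistributedSampler, set_epoch() must be called before each epoch so every process draws a fresh permutation. A minimal, self-contained sketch of that contract (the toy dataset and loop are illustrative assumptions, not repo code):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.sampler import RandomSampler

dataset = TensorDataset(torch.arange(16).float().unsqueeze(1))
loader = DataLoader(dataset, batch_size=4, sampler=RandomSampler(dataset))

for cur_epoch in range(2):
    # A RandomSampler reshuffles implicitly each epoch; a DistributedSampler
    # would need sampler.set_epoch(cur_epoch) here, which is exactly what
    # shuffle_dataset(loader, cur_epoch) does.
    for (batch,) in loader:
        print(cur_epoch, batch.view(-1).tolist())
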
/applications/iih_relighting/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 | -- <set_input>: unpack data from dataset and apply preprocessing.
7 | -- <forward>: produce intermediate results.
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 | In the file, the class called [ModelName]Model() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(1)
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 | This function wraps the model class.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
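The package docstring above spells out the contract for adding a model. A minimal sketch of the 'dummy' example it describes (illustrative only; assuming it is saved as models/dummy_model.py, find_model_using_name('dummy') resolves it and '--model dummy' selects it):

from models.base_model import BaseModel

class DummyModel(BaseModel):
    @staticmethod
    def modify_commandline_options(parser, is_train=True):
        # (optionally) add model-specific options and set default options
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = []    # training losses to plot and save
        self.model_names = []   # networks to save/load from disk
        self.visual_names = []  # images to display and save
        self.optimizers = []    # one optimizer per network

    def set_input(self, input):
        self.data = input       # unpack data from the dataset here

    def forward(self):
        pass                    # produce intermediate results here

    def optimize_parameters(self):
        self.forward()          # compute losses, gradients, and update weights here
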
/applications/iih_relighting/models/iih_base_lt_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import itertools
3 | import torch.nn.functional as F
4 | from util import distributed as du
5 | from .base_model import BaseModel
6 | from util import util
7 | from . import relighting_networks as networks
8 | from . import networks as network_init
9 | import util.ssim as ssim
10 |
11 |
12 | class IIHBaseLTModel(BaseModel):
13 |
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train=True):
16 |
17 | parser.set_defaults(norm='instance', netG='base_lt', dataset_mode='dpr')
18 | parser.add_argument('--action', type=str, default='relighting', help='action to perform at test time: relighting or transfer')
19 | if is_train:
20 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
21 | parser.add_argument('--lambda_R_gradient', type=float, default=10., help='weight for R gradient loss')
22 | parser.add_argument('--lambda_ssim', type=float, default=50., help='weight for SSIM loss')
23 | parser.add_argument('--lambda_I_smooth', type=float, default=1., help='weight for illumination smoothness loss')
24 | parser.add_argument('--lambda_I_L2', type=float, default=10., help='weight for illumination L2 loss')
25 | parser.add_argument('--lambda_ifm', type=float, default=100, help='weight for ifm loss')
26 | parser.add_argument('--lambda_L', type=float, default=100.0, help='weight for light L2 loss')
27 |
28 | return parser
29 |
30 | def __init__(self, opt):
31 | """Initialize the IIHBaseLTModel class.
32 |
33 | Parameters:
34 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
35 | """
36 | BaseModel.__init__(self, opt)
37 | self.opt = opt
38 | # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
39 | self.loss_names = ['G','G_L1','G_R','G_I_L2','G_I_smooth',"G_L"]
40 |
41 | # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
42 | self.visual_names = ['harmonized','real','fake','reflectance','illumination']
43 | # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
44 | self.model_names = ['G']
45 | self.opt.device = self.device
46 | self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
47 | self.cur_device = torch.cuda.current_device()
48 | self.ismaster = du.is_master_proc(opt.NUM_GPUS)
49 | if self.ismaster:
50 | print(self.netG)
51 |
52 | if self.isTrain:
53 | util.saveprint(self.opt, 'netG', str(self.netG))
54 | # define loss functions
55 | self.criterionL1 = torch.nn.L1Loss()
56 | self.criterionL2 = torch.nn.MSELoss()
57 | self.criterionSSIM = ssim.SSIM()
58 | # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
59 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
60 | self.optimizers.append(self.optimizer_G)
61 |
62 | def set_input(self, input):
63 |
64 | self.fake = input['fake'].to(self.device)
65 | self.real = input['real'].to(self.device)
66 | if input['target'] is not None:
67 | self.target = input['target'].to(self.device)
68 | if input['fake_light'] is not None:
69 | self.light_fake = input['fake_light'].to(self.device)
70 | if input['real_light'] is not None:
71 | self.light_real = input['real_light'].to(self.device)
72 | self.image_paths = input['img_path']
73 |
74 | def forward(self):
75 | """Run forward pass; called by both functions <optimize_parameters> and <test>."""
76 | if self.isTrain:
77 | self.harmonized, self.reflectance, self.illumination, self.light_gen_fake = self.netG(self.fake, self.light_real)
78 | else:
79 | if self.opt.action == "relighting":
80 | self.harmonized, self.reflectance, self.illumination, self.light_gen_fake = self.netG(self.fake, isTest=True, light=self.light_real)
81 | else:
82 | self.harmonized, self.reflectance, self.illumination, self.light_gen_fake = self.netG(self.fake, isTest=True, target=self.target)
83 | def backward_G(self):
84 | """Calculate the losses for the generator"""
85 | self.loss_G_L1 = self.criterionL1(self.harmonized, self.real)*self.opt.lambda_L1
86 | self.loss_G_R = (self.gradient_loss(self.reflectance, self.fake)+self.gradient_loss(self.reflectance, self.real))*self.opt.lambda_R_gradient
87 | self.loss_G_I_smooth = util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
88 | self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
89 | self.loss_G_L = self.criterionL2(self.light_gen_fake, self.light_fake)*self.opt.lambda_L
90 |
91 | self.loss_G = self.loss_G_L1 + self.loss_G_R + self.loss_G_I_smooth + self.loss_G_I_L2 + self.loss_G_L
92 | self.loss_G.backward()
93 |
94 | def optimize_parameters(self):
95 | self.forward() # compute fake images: G(A)
96 | # update G
97 | self.optimizer_G.zero_grad() # set G's gradients to zero
98 | self.backward_G() # calculate gradients for G
99 | self.optimizer_G.step() # update G's weights
100 |
101 | def gradient_loss(self, input_1, input_2):
102 | g_x = self.criterionL1(util.gradient(input_1, 'x'), util.gradient(input_2, 'x'))
103 | g_y = self.criterionL1(util.gradient(input_1, 'y'), util.gradient(input_2, 'y'))
104 | return g_x+g_y
105 |
106 | def __compute_kl(self, mu):
107 | mu_2 = torch.pow(mu, 2)
108 | encoding_loss = torch.mean(mu_2)
109 | return encoding_loss
110 |
111 |
--------------------------------------------------------------------------------
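backward_G() above sums five weighted terms; the only hand-rolled one is gradient_loss(), which compares finite-difference image gradients under L1. A self-contained sketch of that computation, assuming util.gradient implements forward differences like the grad() helper below (an assumption; the repo's util.gradient may differ in detail):

import torch
import torch.nn.functional as F

def grad(img, direction):
    # forward differences along width ('x') or height ('y') of an NCHW tensor
    if direction == 'x':
        return img[:, :, :, 1:] - img[:, :, :, :-1]
    return img[:, :, 1:, :] - img[:, :, :-1, :]

a = torch.rand(1, 3, 8, 8)
b = torch.rand(1, 3, 8, 8)
loss = F.l1_loss(grad(a, 'x'), grad(b, 'x')) + F.l1_loss(grad(a, 'y'), grad(b, 'y'))
print(loss.item())
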
/applications/iih_relighting/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes option modules: training options, test options, and basic options (used in both training and test)."""
2 |
--------------------------------------------------------------------------------
/applications/iih_relighting/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | """This class includes test options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 | # Dropout and BatchNorm have different behavior during training and test.
17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 | parser.add_argument('--test_epoch', type=str, default="0", help='which epoch checkpoint to load for testing')
20 | # rewrite default values
21 | parser.set_defaults(model='test')
22 | # To avoid cropping, the load_size should be the same as crop_size
23 | parser.set_defaults(load_size=parser.get_default('crop_size'))
24 | self.isTrain = False
25 | return parser
26 |
--------------------------------------------------------------------------------
/applications/iih_relighting/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | """This class includes training options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser)
12 | # visdom and HTML visualization parameters
13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
15 | parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
22 | # network saving and loading parameters
23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
24 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
25 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
28 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
29 | # training parameters
30 | parser.add_argument('--niter', type=int, default=5, help='# of iter at starting learning rate')
31 | parser.add_argument('--niter_decay', type=int, default=5, help='# of iter to linearly decay learning rate to zero')
32 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
33 | parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
34 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
35 | parser.add_argument('--g_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of generator')
36 | parser.add_argument('--d_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of discriminator')
37 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
38 | parser.add_argument('--pool_size', type=int, default=40, help='the size of image buffer that stores previously generated images')
39 | parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
40 | parser.add_argument('--lr_decay_iters', type=int, default=40, help='multiply by a gamma every lr_decay_iters iterations')
41 | parser.add_argument('--save_iter_model', action='store_true', help='whether saves model by iteration')
42 |
43 |
44 |
45 |
46 | self.isTrain = True
47 | return parser
48 |
--------------------------------------------------------------------------------
/applications/iih_relighting/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import test
6 |
7 | from options.test_options import TestOptions
8 |
9 | def main():
10 | cfg = TestOptions().parse() # get training options
11 | cfg.NUM_GPUS = torch.cuda.device_count()
12 | cfg.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
13 | cfg.no_flip = True # no flip; comment this line if results on flipped images are needed.
14 | cfg.display_id = -1 # no visdom display; the test code saves the results to an HTML file.
15 |
16 | cfg.phase = 'test'
17 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
18 | launch_job(cfg=cfg, init_method=cfg.init_method, func=test)
19 |
20 |
21 | if __name__=="__main__":
22 | main()
--------------------------------------------------------------------------------
/applications/iih_relighting/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import train
6 |
7 | from options.train_options import TrainOptions
8 |
9 | def main():
10 | cfg = TrainOptions().parse() # get training options
11 | cfg.NUM_GPUS = torch.cuda.device_count()
12 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
13 | cfg.phase = 'train'
14 | launch_job(cfg=cfg, init_method=cfg.init_method, func=train)
15 |
16 |
17 | if __name__=="__main__":
18 | main()
--------------------------------------------------------------------------------
/applications/iih_relighting/util/evaluation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as f
3 |
4 |
5 | def evaluation(name, fake, real, mask):
6 | b, c, h, w = real.size()
7 | mse_score = f.mse_loss(fake, real)
8 | fore_area = torch.sum(mask)
9 | fmse_score = f.mse_loss(fake*mask,real*mask)*w*h/fore_area
10 | mse_score = mse_score.item()
11 | fmse_score = fmse_score.item()
12 | # score_str = "%s MSE %0.2f | fMSE %0.2f" % (name, mse_score,fmse_score)
13 | image_fmse_info = (name, round(fmse_score,2), round(mse_score, 2))
14 | return mse_score, fmse_score, image_fmse_info
--------------------------------------------------------------------------------
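The fMSE above rescales a full-image MSE of masked tensors by (H*W / mask area), which recovers the MSE over foreground pixels only. A self-contained check of that identity (the toy tensors are assumptions):

import torch
import torch.nn.functional as F

h = w = 4
fake = torch.rand(1, 3, h, w)
real = torch.rand(1, 3, h, w)
mask = torch.zeros(1, 1, h, w)
mask[..., :2, :2] = 1.0  # 4 foreground pixels

fmse = F.mse_loss(fake * mask, real * mask) * w * h / mask.sum()
fg = mask.bool().expand_as(fake)
print(fmse.item(), F.mse_loss(fake[fg], real[fg]).item())  # the two values match
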
/applications/iih_relighting/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 |
5 |
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 | It consists of functions such as <add_header> (add a text header to the HTML file),
10 | <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 |
14 | def __init__(self, web_dir, title, refresh=0):
15 | """Initialize the HTML classes
16 |
17 | Parameters:
18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 | title (str) -- the webpage name
20 | refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 | """
22 | self.title = title
23 | self.web_dir = web_dir
24 | self.img_dir = os.path.join(self.web_dir, 'images')
25 | if not os.path.exists(self.web_dir):
26 | os.makedirs(self.web_dir)
27 | if not os.path.exists(self.img_dir):
28 | os.makedirs(self.img_dir)
29 |
30 | self.doc = dominate.document(title=title)
31 | if refresh > 0:
32 | with self.doc.head:
33 | meta(http_equiv="refresh", content=str(refresh))
34 |
35 | def get_image_dir(self):
36 | """Return the directory that stores images"""
37 | return self.img_dir
38 |
39 | def add_header(self, text):
40 | """Insert a header to the HTML file
41 |
42 | Parameters:
43 | text (str) -- the header text
44 | """
45 | with self.doc:
46 | h3(text)
47 |
48 | def add_images(self, ims, txts, links, width=400):
49 | """add images to the HTML file
50 |
51 | Parameters:
52 | ims (str list) -- a list of image paths
53 | txts (str list) -- a list of image names shown on the website
54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 | """
56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57 | self.doc.add(self.t)
58 | with self.t:
59 | with tr():
60 | for im, txt, link in zip(ims, txts, links):
61 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 | with p():
63 | with a(href=os.path.join('images', link)):
64 | img(style="width:%dpx" % width, src=os.path.join('images', im))
65 | br()
66 | p(txt)
67 |
68 | def save(self):
69 | """save the current content to the HTML file"""
70 | html_file = '%s/index.html' % self.web_dir
71 | f = open(html_file, 'wt')
72 | f.write(self.doc.render())
73 | f.close()
74 |
75 |
76 | if __name__ == '__main__': # we show an example usage here.
77 | html = HTML('web/', 'test_html')
78 | html.add_header('hello world')
79 |
80 | ims, txts, links = [], [], []
81 | for n in range(4):
82 | ims.append('image_%d.png' % n)
83 | txts.append('text_%d' % n)
84 | links.append('image_%d.png' % n)
85 | html.add_images(ims, txts, links)
86 | html.save()
87 |
--------------------------------------------------------------------------------
/applications/iih_relighting/util/multiprocessing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | """Multiprocessing helpers."""
5 |
6 | import torch
7 |
8 |
9 | def run(
10 | local_rank,
11 | num_proc,
12 | func,
13 | init_method,
14 | shard_id,
15 | num_shards,
16 | backend,
17 | cfg,
18 | output_queue=None,
19 | ):
20 | """
21 | Runs a function from a child process.
22 | Args:
23 | local_rank (int): rank of the current process on the current machine.
24 | num_proc (int): number of processes per machine.
25 | func (function): function to execute on each of the process.
26 | init_method (string): method to initialize the distributed training.
27 | TCP initialization: requiring a network address reachable from all
28 | processes followed by the port.
29 | Shared file-system initialization: makes use of a file system that
30 | is shared and visible from all machines. The URL should start with
31 | file:// and contain a path to a non-existent file on a shared file
32 | system.
33 | shard_id (int): the rank of the current machine.
34 | num_shards (int): number of overall machines for the distributed
35 | training job.
36 | backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
37 | supported, each with different capabilities. Details can be found
38 | here:
39 | https://pytorch.org/docs/stable/distributed.html
40 | cfg (CfgNode): configs. Details can be found in
41 | slowfast/config/defaults.py
42 | output_queue (queue): can optionally be used to return values from the
43 | master process.
44 | """
45 | # Initialize the process group.
46 | world_size = num_proc * num_shards
47 | rank = shard_id * num_proc + local_rank
48 | try:
49 | torch.distributed.init_process_group(
50 | backend=backend,
51 | init_method=init_method,
52 | world_size=world_size,
53 | rank=rank,
54 | )
55 |
56 | except Exception as e:
57 | raise e
58 |
59 | torch.cuda.set_device(local_rank)
60 | ret = func(cfg)
61 | if output_queue is not None and local_rank == 0:
62 | output_queue.put(ret)
63 |
--------------------------------------------------------------------------------
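run() above expects to be spawned once per GPU. A hedged sketch of how it is typically driven (the repo's launch_job in util/misc.py presumably does something similar; the address, port, and defaults here are placeholders, not repo values):

import torch

def launch(cfg, func, init_method="tcp://localhost:9999",
           shard_id=0, num_shards=1, backend="nccl"):
    # torch.multiprocessing.spawn calls run(local_rank, *args)
    # with local_rank in [0, cfg.NUM_GPUS)
    torch.multiprocessing.spawn(
        run,
        nprocs=cfg.NUM_GPUS,
        args=(cfg.NUM_GPUS, func, init_method, shard_id, num_shards, backend, cfg),
    )
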
/applications/iih_relighting/util/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if size_average:
35 | return ssim_map.mean()
36 | else:
37 | return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | def _ssim_c_s(img1, img2, window, window_size, channel, size_average = True):
40 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
41 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
42 |
43 | mu1_sq = mu1.pow(2)
44 | mu2_sq = mu2.pow(2)
45 | mu1_mu2 = mu1*mu2
46 |
47 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
48 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
49 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
50 |
51 | C1 = 0.01**2
52 | C2 = 0.03**2
53 |
54 | ssim_map = ((2*sigma12 + C2))/((sigma1_sq + sigma2_sq + C2))
55 |
56 | if size_average:
57 | return ssim_map.mean()
58 | else:
59 | return ssim_map.mean(1).mean(1).mean(1)
60 | def _ssim_l(img1, img2, window, window_size, channel, size_average = True):
61 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
62 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
63 |
64 | mu1_sq = mu1.pow(2)
65 | mu2_sq = mu2.pow(2)
66 | mu1_mu2 = mu1*mu2
67 |
68 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
69 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
70 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
71 |
72 | C1 = 0.01**2
73 | C2 = 0.03**2
74 |
75 | ssim_map = (2*mu1_mu2 + C1)/(mu1_sq + mu2_sq + C1)
76 |
77 | if size_average:
78 | return ssim_map.mean()
79 | else:
80 | return ssim_map.mean(1).mean(1).mean(1)
81 |
82 |
83 | def ssim(img1, img2, window_size = 11, size_average = True):
84 | (_, channel, _, _) = img1.size()
85 | window = create_window(window_size, channel)
86 |
87 | if img1.is_cuda:
88 | window = window.cuda(img1.get_device())
89 | window = window.type_as(img1)
90 |
91 | return _ssim(img1, img2, window, window_size, channel, size_average)
92 |
93 | class SSIM(torch.nn.Module):
94 | def __init__(self, window_size = 11, size_average = True, mode='all'):
95 | super(SSIM, self).__init__()
96 | self.window_size = window_size
97 | self.size_average = size_average
98 | self.channel = 1
99 | self.window = create_window(window_size, self.channel)
100 | self.mode = mode
101 | def forward(self, img1, img2):
102 | (_, channel, _, _) = img1.size()
103 |
104 | if channel == self.channel and self.window.data.type() == img1.data.type():
105 | window = self.window
106 | else:
107 | window = create_window(self.window_size, channel)
108 |
109 | # if img1.is_cuda:
110 | # window = window.cuda(img1.get_device())
111 | window = window.type_as(img1)
112 |
113 | self.window = window
114 | self.channel = channel
115 | if self.mode == 'all':
116 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
117 | elif self.mode == 'c_s':
118 | return _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average)
119 | else:
120 | return _ssim_l(img1, img2, window, self.window_size, channel, self.size_average)
121 |
122 | class DSSIM(torch.nn.Module):
123 | def __init__(self, window_size = 11, size_average = True, mode='all'):
124 | super(DSSIM, self).__init__()
125 | self.window_size = window_size
126 | self.size_average = size_average
127 | self.channel = 1
128 | self.window = create_window(window_size, self.channel)
129 | self.mode = mode
130 | def forward(self, img1, img2):
131 | (_, channel, _, _) = img1.size()
132 |
133 | if channel == self.channel and self.window.data.type() == img1.data.type():
134 | window = self.window
135 | else:
136 | window = create_window(self.window_size, channel)
137 |
138 | # if img1.is_cuda:
139 | # window = window.cuda(img1.get_device())
140 | window = window.type_as(img1)
141 |
142 | self.window = window
143 | self.channel = channel
144 | if self.mode == 'all':
145 | ssim_v = _ssim(img1, img2, window, self.window_size, channel, self.size_average)
146 | elif self.mode == 'c_s':
147 | ssim_v = _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average)
148 | else:
149 | ssim_v = _ssim_l(img1, img2, window, self.window_size, channel, self.size_average)
150 | return (1-ssim_v)/2
--------------------------------------------------------------------------------
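A short usage sketch for the modules above (random inputs are illustrative, assumed to run in the same module). Pointwise, the 'all' SSIM map is the product of the 'l' (luminance) and 'c_s' (contrast-structure) maps, which is what the three modes expose:

import torch

img1 = torch.rand(1, 3, 32, 32)
img2 = torch.rand(1, 3, 32, 32)

print(SSIM(mode='all')(img1, img2).item())   # full SSIM
print(SSIM(mode='c_s')(img1, img2).item())   # contrast-structure term only
print(SSIM(mode='l')(img1, img2).item())     # luminance term only
print(DSSIM(mode='all')(img1, img2).item())  # (1 - SSIM) / 2, usable as a loss
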
/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from data.base_dataset import BaseDataset
16 | from torch.utils.data.distributed import DistributedSampler
17 | from torch.utils.data.sampler import RandomSampler
18 |
19 |
20 | def find_dataset_using_name(dataset_name):
21 | """Import the module "data/[dataset_name]_dataset.py".
22 |
23 | In the file, the class called DatasetNameDataset() will
24 | be instantiated. It has to be a subclass of BaseDataset,
25 | and it is case-insensitive.
26 | """
27 | dataset_filename = "data." + dataset_name + "_dataset"
28 | datasetlib = importlib.import_module(dataset_filename)
29 |
30 | dataset = None
31 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
32 | for name, cls in datasetlib.__dict__.items():
33 | if name.lower() == target_dataset_name.lower() \
34 | and issubclass(cls, BaseDataset):
35 | dataset = cls
36 |
37 | if dataset is None:
38 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
39 |
40 | return dataset
41 |
42 |
43 | def get_option_setter(dataset_name):
44 | """Return the static method of the dataset class."""
45 | dataset_class = find_dataset_using_name(dataset_name)
46 | return dataset_class.modify_commandline_options
47 |
48 |
49 | def create_dataset(opt):
50 | """Create a dataset given the option.
51 |
52 | This function finds the dataset class and wraps it in a torch.utils.data.DataLoader.
53 | This is the main interface between this package and 'train.py'/'test.py'
54 |
55 | Example:
56 | >>> from data import create_dataset
57 | >>> dataset = create_dataset(opt)
58 | """
59 | # data_loader = CustomDatasetDataLoader(opt)
60 | # dataset = data_loader.load_data()
61 |
62 | dataset_class = find_dataset_using_name(opt.dataset_mode)
63 | dataset = dataset_class(opt)
64 | print("dataset [%s] was created" % type(dataset).__name__)
65 |
66 | # batch_size = int(opt.batch_size / max(1, opt.NUM_GPUS))
67 | if opt.isTrain:
68 | shuffle = True
69 | drop_last = True
70 | else:
71 | shuffle = False
72 | drop_last = False
73 |
74 | sampler = torch.utils.data.distributed.DistributedSampler(dataset) if opt.NUM_GPUS > 1 else None
75 |
76 | # Create a loader
77 | dataloader = torch.utils.data.DataLoader(
78 | dataset,
79 | batch_size=opt.batch_size,
80 | shuffle=(False if sampler else shuffle),
81 | sampler=sampler,
82 | num_workers=int(opt.num_threads),
83 | drop_last=drop_last,
84 | pin_memory=True,
85 | )
86 | return dataloader
87 |
88 |
89 | class CustomDatasetDataLoader():
90 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
91 |
92 | def __init__(self, opt):
93 | """Initialize this class
94 |
95 | Step 1: create a dataset instance given the name [dataset_mode]
96 | Step 2: create a multi-threaded data loader.
97 | """
98 | self.opt = opt
99 | dataset_class = find_dataset_using_name(opt.dataset_mode)
100 | self.dataset = dataset_class(opt)
101 | print("dataset [%s] was created" % type(self.dataset).__name__)
102 |
103 | batch_size = int(opt.batch_size / max(1, opt.NUM_GPUS))
104 | if opt.isTrain:
105 | shuffle = True
106 | drop_last = True
107 | else:
108 | shuffle = False
109 | drop_last = False
110 |
111 | self.sampler = torch.utils.data.distributed.DistributedSampler(self.dataset) if opt.NUM_GPUS > 1 else None
112 |
113 | # Create a loader
114 | self.dataloader = torch.utils.data.DataLoader(
115 | self.dataset,
116 | batch_size=batch_size,
117 | shuffle=(False if self.sampler else shuffle),
118 | sampler=self.sampler,
119 | num_workers=int(opt.num_threads),
120 | drop_last=drop_last,
121 | )
122 |
123 | # self.dataloader = torch.utils.data.DataLoader(
124 | # self.dataset,
125 | # batch_size=opt.batch_size,
126 | # shuffle=not opt.serial_batches,
127 | # num_workers=int(opt.num_threads))
128 |
129 | def load_data(self):
130 | return self
131 |
132 | def __len__(self):
133 | """Return the number of data in the dataset"""
134 | return min(len(self.dataset), self.opt.max_dataset_size)
135 |
136 | # def __iter__(self):
137 | # """Return a batch of data"""
138 | # for i, data in enumerate(self.dataloader):
139 | # if i * self.opt.batch_size >= self.opt.max_dataset_size:
140 | # break
141 | # yield data
142 |
143 | def shuffle_dataset(loader, cur_epoch):
144 | """
145 | Shuffles the data.
146 | Args:
147 | loader (loader): data loader to perform shuffle.
148 | cur_epoch (int): number of the current epoch.
149 | """
150 | # sampler = (
151 | # loader.batch_sampler.sampler
152 | # if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
153 | # else loader.sampler
154 | # )
155 | sampler = loader.sampler
156 | assert isinstance(
157 | sampler, (RandomSampler, DistributedSampler)
158 | ), "Sampler type '{}' not supported".format(type(sampler))
159 | # RandomSampler handles shuffling automatically
160 | if isinstance(sampler, DistributedSampler):
161 | # DistributedSampler shuffles data based on epoch
162 | sampler.set_epoch(cur_epoch)
163 |
164 |
165 |
166 |
167 |
168 |
--------------------------------------------------------------------------------
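create_dataset() reads only a handful of option fields. A hedged sketch of the minimal namespace it needs (paths and values are placeholders; real runs get these fields from BaseOptions.parse()):

from argparse import Namespace
from data import create_dataset

opt = Namespace(
    dataset_mode='ihd',            # resolves data/ihd_dataset.py -> IhdDataset
    dataset_root='/path/to/IHD/',  # placeholder dataset location
    dataset_name='IHD',
    isTrain=True,
    crop_size=256,
    NUM_GPUS=1,
    batch_size=8,
    num_threads=4,
)
loader = create_dataset(opt)
for batch in loader:
    print(batch['comp'].shape, batch['mask'].shape)
    break
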
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 |
12 |
13 | class BaseDataset(data.Dataset, ABC):
14 | """This class is an abstract base class (ABC) for datasets.
15 |
16 | To create a subclass, you need to implement the following four functions:
17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 | -- <__len__>: return the size of dataset.
19 | -- <__getitem__>: get a data point.
20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 | """
22 |
23 | def __init__(self, opt):
24 | """Initialize the class; save the options in the class
25 |
26 | Parameters:
27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 | """
29 | self.opt = opt
30 | # self.root = opt.dataroot
31 | self.root = opt.dataset_root
32 |
33 | @staticmethod
34 | def modify_commandline_options(parser, is_train):
35 | """Add new dataset-specific options, and rewrite default values for existing options.
36 |
37 | Parameters:
38 | parser -- original option parser
39 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
40 |
41 | Returns:
42 | the modified parser.
43 | """
44 | return parser
45 |
46 | @abstractmethod
47 | def __len__(self):
48 | """Return the total number of images in the dataset."""
49 | return 0
50 |
51 | @abstractmethod
52 | def __getitem__(self, index):
53 | """Return a data point and its metadata information.
54 |
55 | Parameters:
56 | index - - a random integer for data indexing
57 |
58 | Returns:
59 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
60 | """
61 | pass
62 |
63 |
64 | def get_params(opt, size):
65 | w, h = size
66 | new_h = h
67 | new_w = w
68 | if opt.preprocess == 'resize_and_crop':
69 | new_h = new_w = opt.load_size
70 | elif opt.preprocess == 'scale_width_and_crop':
71 | new_w = opt.load_size
72 | new_h = opt.load_size * h // w
73 |
74 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
75 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
76 |
77 | flip = random.random() > 0.5
78 |
79 | return {'crop_pos': (x, y), 'flip': flip}
80 |
81 |
82 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
83 | transform_list = []
84 | if grayscale:
85 | transform_list.append(transforms.Grayscale(1))
86 | if 'resize' in opt.preprocess:
87 | osize = [opt.load_size, opt.load_size]
88 | transform_list.append(transforms.Resize(osize, method))
89 | elif 'scale_width' in opt.preprocess:
90 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
91 |
92 | if 'crop' in opt.preprocess:
93 | if params is None:
94 | transform_list.append(transforms.RandomCrop(opt.crop_size))
95 | else:
96 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
97 |
98 | if opt.preprocess == 'none':
99 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
100 |
101 | if not opt.no_flip:
102 | if params is None:
103 | transform_list.append(transforms.RandomHorizontalFlip())
104 | elif params['flip']:
105 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
106 |
107 | if convert:
108 | transform_list += [transforms.ToTensor()]
109 | if grayscale:
110 | transform_list += [transforms.Normalize((0.5,), (0.5,))]
111 | else:
112 | # transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
113 | transform_list += [transforms.Normalize((0, 0, 0), (1, 1, 1))]
114 | return transforms.Compose(transform_list)
115 |
116 |
117 | def __make_power_2(img, base, method=Image.BICUBIC):
118 | ow, oh = img.size
119 | h = int(round(oh / base) * base)
120 | w = int(round(ow / base) * base)
121 | if (h == oh) and (w == ow):
122 | return img
123 |
124 | __print_size_warning(ow, oh, w, h)
125 | return img.resize((w, h), method)
126 |
127 |
128 | def __scale_width(img, target_width, method=Image.BICUBIC):
129 | ow, oh = img.size
130 | if (ow == target_width):
131 | return img
132 | w = target_width
133 | h = int(target_width * oh / ow)
134 | return img.resize((w, h), method)
135 |
136 |
137 | def __crop(img, pos, size):
138 | ow, oh = img.size
139 | x1, y1 = pos
140 | tw = th = size
141 | if (ow > tw or oh > th):
142 | return img.crop((x1, y1, x1 + tw, y1 + th))
143 | return img
144 |
145 |
146 | def __flip(img, flip):
147 | if flip:
148 | return img.transpose(Image.FLIP_LEFT_RIGHT)
149 | return img
150 |
151 |
152 | def __print_size_warning(ow, oh, w, h):
153 | """Print warning information about image size(only print once)"""
154 | if not hasattr(__print_size_warning, 'has_printed'):
155 | print("The image size needs to be a multiple of 4. "
156 | "The loaded image size was (%d, %d), so it was adjusted to "
157 | "(%d, %d). This adjustment will be done to all images "
158 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
159 | __print_size_warning.has_printed = True
160 |
--------------------------------------------------------------------------------
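get_params() and get_transform() above are designed as a pair: computing the random crop/flip parameters once and passing them to two transforms keeps an image and its mask aligned. A self-contained sketch (blank PIL images stand in for real data):

from argparse import Namespace
from PIL import Image
from data.base_dataset import get_params, get_transform

opt = Namespace(preprocess='resize_and_crop', load_size=286, crop_size=256, no_flip=False)
img = Image.new('RGB', (500, 400))
mask = Image.new('L', (500, 400))

params = get_params(opt, img.size)                   # one crop position + flip decision
t_rgb = get_transform(opt, params)                   # reused for the image ...
t_mask = get_transform(opt, params, grayscale=True)  # ... and for its mask
print(t_rgb(img).shape, t_mask(mask).shape)          # [3, 256, 256] and [1, 256, 256]
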
/data/ihd_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 |
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <DatasetMode>Dataset.py
8 | You need to implement the following functions:
9 | -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 | -- <__init__>: Initialize this dataset class.
11 | -- <__getitem__>: Return a data point and its metadata information.
12 | -- <__len__>: Return the number of images.
13 | """
14 | import os.path
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | import torch.nn.functional as F
18 | from data.base_dataset import BaseDataset, get_transform
19 | from data.image_folder import make_dataset
20 | from PIL import Image
21 | import numpy as np
22 | import torchvision.transforms as transforms
23 | from util import util
24 |
25 | class IhdDataset(BaseDataset):
26 | """A template dataset class for you to implement custom datasets."""
27 | @staticmethod
28 | def modify_commandline_options(parser, is_train):
29 | """Add new dataset-specific options, and rewrite default values for existing options.
30 |
31 | Parameters:
32 | parser -- original option parser
33 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
34 |
35 | Returns:
36 | the modified parser.
37 | """
38 | parser.add_argument('--is_train', type=bool, default=True, help='whether in the training phase')
39 | parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0) # specify dataset-specific default values
40 | return parser
41 |
42 | def __init__(self, opt):
43 | """Initialize this dataset class.
44 |
45 | Parameters:
46 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
47 |
48 | A few things can be done here.
49 | - save the options (have been done in BaseDataset)
50 | - get image paths and meta information of the dataset.
51 | - define the image transformation.
52 | """
53 | # save the option and dataset root
54 | BaseDataset.__init__(self, opt)
55 | self.image_paths = []
56 | self.isTrain = opt.isTrain
57 | self.image_size = opt.crop_size
58 |
59 | if opt.isTrain:
60 | print('loading training file')
61 | self.trainfile = opt.dataset_root+opt.dataset_name+'_train.txt'
62 | with open(self.trainfile,'r') as f:
63 | for line in f.readlines():
64 | self.image_paths.append(os.path.join(opt.dataset_root,'composite_images',line.rstrip()))
65 | else:
66 | #self.real_ext='.jpg'
67 | print('loading test file')
68 | self.trainfile = opt.dataset_root+opt.dataset_name+'_test.txt'
69 | with open(self.trainfile,'r') as f:
70 | for line in f.readlines():
71 | self.image_paths.append(os.path.join(opt.dataset_root,'composite_images',line.rstrip()))
72 | # get the image paths of your dataset;
73 | # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
74 | # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
75 | transform_list = [
76 | transforms.ToTensor(),
77 | transforms.Normalize((0, 0, 0), (1, 1, 1))
78 | ]
79 | self.transforms = transforms.Compose(transform_list)
80 |
81 | def __getitem__(self, index):
82 | """Return a data point and its metadata information.
83 |
84 | Parameters:
85 | index -- a random integer for data indexing
86 |
87 | Returns:
88 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
89 |
90 | Step 1: get a random image path: e.g., path = self.image_paths[index]
91 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
92 | Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
93 | Step 4: return a data point as a dictionary.
94 | """
95 | path = self.image_paths[index]
96 | name_parts=path.split('_')
97 | mask_path = self.image_paths[index].replace('composite_images','masks')
98 | mask_path = mask_path.replace(('_'+name_parts[-1]),'.png')
99 | target_path = self.image_paths[index].replace('composite_images','real_images')
100 | target_path = target_path.replace(('_'+name_parts[-2]+'_'+name_parts[-1]),'.jpg')
101 |
102 | comp = Image.open(path).convert('RGB')
103 | real = Image.open(target_path).convert('RGB')
104 | mask = Image.open(mask_path).convert('1')
105 |
106 | if np.random.rand() > 0.5 and self.isTrain:
107 | comp, mask, real = tf.hflip(comp), tf.hflip(mask), tf.hflip(real)
108 |
109 | if comp.size[0] != self.image_size:
110 | comp = tf.resize(comp, [self.image_size, self.image_size])
111 | mask = tf.resize(mask, [self.image_size, self.image_size])
112 | real = tf.resize(real, [self.image_size,self.image_size])
113 |
114 | comp = self.transforms(comp)
115 | mask = tf.to_tensor(mask)
116 | real = self.transforms(real)
117 |
118 | inputs=torch.cat([comp,mask],0)
119 |
120 | return {'inputs': inputs, 'comp': comp, 'real': real,'img_path':path,'mask':mask}
121 |
122 | def __len__(self):
123 | """Return the total number of images."""
124 | return len(self.image_paths)
125 |
126 |
--------------------------------------------------------------------------------
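The path rewriting in __getitem__ above derives the mask and real-image paths from the composite filename. A worked example under the iHarmony4-style naming convention (the sample name is an assumption):

# composite 'xxx_<mask-id>_<variant>.jpg' -> mask 'xxx_<mask-id>.png', real 'xxx.jpg'
path = '/IHD/composite_images/d1000_1_2.jpg'
name_parts = path.split('_')
mask_path = path.replace('composite_images', 'masks').replace('_' + name_parts[-1], '.png')
target_path = path.replace('composite_images', 'real_images').replace('_' + name_parts[-2] + '_' + name_parts[-1], '.jpg')
print(mask_path)    # /IHD/masks/d1000_1.png
print(target_path)  # /IHD/real_images/d1000.jpg
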
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 |
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 | import os.path
12 |
13 | IMG_EXTENSIONS = [
14 | '.jpg', '.JPG', '.jpeg', '.JPEG',
15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
16 | ]
17 |
18 |
19 | def is_image_file(filename):
20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21 |
22 |
23 | def make_dataset(dir, max_dataset_size=float("inf")):
24 | images = []
25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
26 |
27 | for root, _, fnames in sorted(os.walk(dir)):
28 | for fname in fnames:
29 | if is_image_file(fname):
30 | path = os.path.join(root, fname)
31 | images.append(path)
32 | return images[:min(max_dataset_size, len(images))]
33 |
34 |
35 | def default_loader(path):
36 | return Image.open(path).convert('RGB')
37 |
38 |
39 | class ImageFolder(data.Dataset):
40 |
41 | def __init__(self, root, transform=None, return_paths=False,
42 | loader=default_loader):
43 | imgs = make_dataset(root)
44 | if len(imgs) == 0:
45 | raise(RuntimeError("Found 0 images in: " + root + "\n"
46 | "Supported image extensions are: " +
47 | ",".join(IMG_EXTENSIONS)))
48 |
49 | self.root = root
50 | self.imgs = imgs
51 | self.transform = transform
52 | self.return_paths = return_paths
53 | self.loader = loader
54 |
55 | def __getitem__(self, index):
56 | path = self.imgs[index]
57 | img = self.loader(path)
58 | if self.transform is not None:
59 | img = self.transform(img)
60 | if self.return_paths:
61 | return img, path
62 | else:
63 | return img
64 |
65 | def __len__(self):
66 | return len(self.imgs)
67 |
--------------------------------------------------------------------------------
/data/loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | """Data loader."""
5 |
6 | import os
7 | import numpy as np
8 | import torch
9 | from torch.utils.data._utils.collate import default_collate
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torch.utils.data.sampler import RandomSampler
12 |
13 | from slowfast.datasets.multigrid_helper import ShortCycleBatchSampler
14 |
15 | from . import utils as utils
16 |
17 | def build_dataset(cfg):
18 | image_paths = []
19 | if cfg.phase == 'train':
20 | print('loading training file')
21 | file = cfg.dataset_root+cfg.dataset_name+'_train.txt'
22 | with open(file,'r') as f:
23 | for line in f.readlines():
24 | image_paths.append(os.path.join(cfg.dataset_root,'composite_images',line.rstrip()))
25 | return image_paths
26 |
27 | def construct_loader(cfg, split, is_precise_bn=False):
28 | """
29 | Constructs the data loader for the given dataset.
30 | Args:
31 | cfg (CfgNode): configs. Details can be found in
32 | slowfast/config/defaults.py
33 | split (str): the split of the data loader. Options include `train`,
34 | `val`, and `test`.
35 | """
36 | assert split in ["train", "val", "test"]
37 | if split in ["train"]:
38 | dataset_name = cfg.TRAIN.DATASET
39 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
40 | shuffle = True
41 | drop_last = True
42 | elif split in ["val"]:
43 | dataset_name = cfg.TRAIN.DATASET
44 | batch_size = int(cfg.TRAIN.BATCH_SIZE / max(1, cfg.NUM_GPUS))
45 | shuffle = False
46 | drop_last = False
47 | elif split in ["test"]:
48 | dataset_name = cfg.TEST.DATASET
49 | batch_size = int(cfg.TEST.BATCH_SIZE / max(1, cfg.NUM_GPUS))
50 | shuffle = False
51 | drop_last = False
52 |
53 | # Construct the dataset (build_dataset above takes cfg only; dataset_name and split are unused slowfast residue)
54 | dataset = build_dataset(cfg)
55 |
56 | if cfg.MULTIGRID.SHORT_CYCLE and split in ["train"] and not is_precise_bn:
57 | # Create a sampler for multi-process training
58 | sampler = utils.create_sampler(dataset, shuffle, cfg)
59 | batch_sampler = ShortCycleBatchSampler(
60 | sampler, batch_size=batch_size, drop_last=drop_last, cfg=cfg
61 | )
62 | # Create a loader
63 | loader = torch.utils.data.DataLoader(
64 | dataset,
65 | batch_sampler=batch_sampler,
66 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
67 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
68 | worker_init_fn=utils.loader_worker_init_fn(dataset),
69 | )
70 | else:
71 | # Create a sampler for multi-process training
72 | sampler = utils.create_sampler(dataset, shuffle, cfg)
73 | # Create a loader
74 | loader = torch.utils.data.DataLoader(
75 | dataset,
76 | batch_size=batch_size,
77 | shuffle=(False if sampler else shuffle),
78 | sampler=sampler,
79 | num_workers=cfg.DATA_LOADER.NUM_WORKERS,
80 | pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
81 | drop_last=drop_last,
82 | collate_fn=None, # slowfast's detection_collate and cfg.DETECTION are not available in this repo
83 | worker_init_fn=utils.loader_worker_init_fn(dataset),
84 | )
85 | return loader
86 |
87 |
88 | def shuffle_dataset(loader, cur_epoch):
89 | """
90 | Shuffles the data.
91 | Args:
92 | loader (loader): data loader to perform shuffle.
93 | cur_epoch (int): number of the current epoch.
94 | """
95 | sampler = (
96 | loader.batch_sampler.sampler
97 | if isinstance(loader.batch_sampler, ShortCycleBatchSampler)
98 | else loader.sampler
99 | )
100 | assert isinstance(
101 | sampler, (RandomSampler, DistributedSampler)
102 | ), "Sampler type '{}' not supported".format(type(sampler))
103 | # RandomSampler handles shuffling automatically
104 | if isinstance(sampler, DistributedSampler):
105 | # DistributedSampler shuffles data based on epoch
106 | sampler.set_epoch(cur_epoch)
107 |
--------------------------------------------------------------------------------
/data/real_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 |
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <DatasetMode>Dataset.py
8 | You need to implement the following functions:
9 | -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 | -- <__init__>: Initialize this dataset class.
11 | -- <__getitem__>: Return a data point and its metadata information.
12 | -- <__len__>: Return the number of images.
13 | """
14 | import os.path
15 | import torch
16 | import torchvision.transforms.functional as tf
17 | from data.base_dataset import BaseDataset, get_transform
18 | from data.image_folder import make_dataset
19 | from PIL import Image
20 | import numpy as np
21 | import torchvision.transforms as transforms
22 | from scipy import sparse
23 | from util import util
24 |
25 | class RealDataset(BaseDataset):
26 | """A template dataset class for you to implement custom datasets."""
27 | @staticmethod
28 | def modify_commandline_options(parser, is_train):
29 | """Add new dataset-specific options, and rewrite default values for existing options.
30 |
31 | Parameters:
32 | parser -- original option parser
33 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
34 |
35 | Returns:
36 | the modified parser.
37 | """
38 | parser.add_argument('--is_train', type=bool, default=True, help='whether in the training phase')
39 | parser.set_defaults(max_dataset_size=float("inf"), new_dataset_option=2.0) # specify dataset-specific default values
40 | return parser
41 |
42 | def __init__(self, opt):
43 | """Initialize this dataset class.
44 |
45 | Parameters:
46 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
47 |
48 | A few things can be done here.
49 | - save the options (have been done in BaseDataset)
50 | - get image paths and meta information of the dataset.
51 | - define the image transformation.
52 | """
53 | # save the option and dataset root
54 | BaseDataset.__init__(self, opt)
55 | self.image_paths = []
56 | self.isTrain = opt.isTrain
57 | self.image_size = opt.crop_size
58 | if opt.isTrain:
59 | print('loading training file')
60 | self.trainfile = opt.dataset_root+opt.dataset_name+'_train.txt'
61 | with open(self.trainfile,'r') as f:
62 | for line in f.readlines():
63 | self.image_paths.append(os.path.join(opt.dataset_root,'composite_images',line.rstrip()))
64 | else:
65 | print('loading test file')
66 | self.trainfile = opt.dataset_root+opt.dataset_name+'_test.txt'
67 | with open(self.trainfile,'r') as f:
68 | for line in f.readlines():
69 | self.image_paths.append(os.path.join(opt.dataset_root,'composite_images',line.rstrip()))
70 | # get the image paths of your dataset;
71 | # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
72 | # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
73 | self.transform = get_transform(opt)
74 |
75 | def __getitem__(self, index):
76 | path = self.image_paths[index]
77 | mask_path = self.image_paths[index].replace('composite_images','masks')
78 |
79 | comp = Image.open(path).convert('RGB')
80 | mask = Image.open(mask_path).convert('1')
81 |
82 | if np.random.rand() > 0.5 and self.isTrain:
83 | comp, mask = tf.hflip(comp), tf.hflip(mask)
84 |
85 | comp = tf.resize(comp, [self.image_size, self.image_size])
86 | mask = tf.resize(mask, [self.image_size, self.image_size])
87 | comp = self.transform(comp)
88 | mask = tf.to_tensor(mask)
89 | inputs=torch.cat([comp,mask],0)
90 |
91 | return {'inputs': inputs, 'comp': comp, 'real': comp,'img_path':path,'mask':mask}
92 |
93 | def __len__(self):
94 | """Return the total number of images."""
95 | return len(self.image_paths)
96 |
--------------------------------------------------------------------------------
/evaluation/pytorch_ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 | import os
7 |
8 |
9 | def gaussian(window_size, sigma):
10 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
11 | return gauss/gauss.sum()
12 |
13 | def create_window(window_size, channel):
14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
16 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
17 | return window
18 |
19 | def _ssim(img1, img2, window, window_size, channel, size_average = True, mask=None):
20 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
21 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
22 |
23 | mu1_sq = mu1.pow(2)
24 | mu2_sq = mu2.pow(2)
25 | mu1_mu2 = mu1*mu2
26 |
27 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
28 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
29 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
30 |
31 | C1 = 0.01**2
32 | C2 = 0.03**2
33 |
34 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
35 |
36 | if mask is not None:
37 | mask_sum = mask.sum()
38 | fg_ssim_map = ssim_map*mask
39 | fg_ssim_map_sum = fg_ssim_map.sum(3).sum(2)
40 | fg_ssim = fg_ssim_map_sum/mask_sum
41 | fg_ssim_mu = fg_ssim.mean()
42 | ssim_mu = ssim_map.mean()
43 | return ssim_mu.item(), fg_ssim_mu.item()
44 |
45 |     if size_average:
46 |         return ssim_map.mean()
47 |     else:
48 |         return ssim_map.mean(1).mean(1).mean(1)
49 |
50 | class SSIM(torch.nn.Module):
51 | def __init__(self, window_size = 11, size_average = True):
52 | super(SSIM, self).__init__()
53 | self.window_size = window_size
54 | self.size_average = size_average
55 | self.channel = 1
56 | self.window = create_window(window_size, self.channel)
57 |
58 | def forward(self, img1, img2):
59 | (_, channel, _, _) = img1.size()
60 |
61 | if channel == self.channel and self.window.data.type() == img1.data.type():
62 | window = self.window
63 | else:
64 | window = create_window(self.window_size, channel)
65 |
66 | if img1.is_cuda:
67 | window = window.cuda(img1.get_device())
68 | window = window.type_as(img1)
69 |
70 | self.window = window
71 | self.channel = channel
72 |
73 |
74 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
75 |
76 | def ssim(img1, img2, window_size = 11, size_average = True, mask=None):
77 | (_, channel, _, _) = img1.size()
78 | window = create_window(window_size, channel)
79 |
80 | if img1.is_cuda:
81 | window = window.cuda(img1.get_device())
82 | window = window.type_as(img1)
83 |
84 | return _ssim(img1, img2, window, window_size, channel, size_average, mask)
--------------------------------------------------------------------------------
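When a mask is passed, _ssim returns a pair of Python floats, the whole-image SSIM and the SSIM averaged over the masked foreground, instead of a single tensor. A small usage sketch with random tensors:

    import torch
    from evaluation.pytorch_ssim import ssim

    img1 = torch.rand(1, 3, 64, 64)
    img2 = torch.rand(1, 3, 64, 64)
    mask = torch.zeros(1, 1, 64, 64)
    mask[..., 16:48, 16:48] = 1.0                   # foreground region

    ssim_all, ssim_fg = ssim(img1, img2, mask=mask)  # (global, foreground) floats
    ssim_global = ssim(img1, img2)                   # no mask: a single tensor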
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 |     -- <set_input>: unpack data from dataset and apply preprocessing.
7 |     -- <forward>: produce intermediate results.
8 |     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 |     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from models.base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 |     In the file, the class whose name matches model_name will
29 |     be instantiated (the match is case-insensitive). It has to
30 |     be a subclass of BaseModel.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(0)
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 |     This function instantiates the model class given by opt.model.
58 |     It is the main interface between this package and 'train.py'/'test.py'.
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
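The lookup is purely name-based: '--model iih_base_gd' resolves to models/iih_base_gd_model.py and, within it, to the class whose lowercased name equals 'iihbasegdmodel'. For example:

    from models import find_model_using_name

    # Resolve the class without instantiating it (instantiation needs opt).
    model_cls = find_model_using_name('iih_base_gd')
    print(model_cls.__name__)   # IIHBaseGDModel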
/models/iih_base_gd_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import itertools
4 | import torch.nn.functional as F
5 | from util import distributed as du
6 | from .base_model import BaseModel
7 | from util import util
8 | from . import harmony_networks as networks
9 | import util.ssim as ssim
10 |
11 |
12 | class IIHBaseGDModel(BaseModel):
13 | @staticmethod
14 | def modify_commandline_options(parser, is_train=True):
15 | parser.set_defaults(norm='instance', netG='base_gd', dataset_mode='ihd')
16 | if is_train:
17 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
18 | parser.add_argument('--lambda_R_gradient', type=float, default=20., help='weight for reflectance gradient loss')
19 | parser.add_argument('--lambda_I_L2', type=float, default=10., help='weight for illumination L2 loss')
20 | parser.add_argument('--lambda_I_smooth', type=float, default=1, help='weight for Illumination smooth loss')
21 |             parser.add_argument('--lambda_ifm', type=float, default=100, help='weight for the ifm (IF) loss')
22 | return parser
23 |
24 | def __init__(self, opt):
25 | BaseModel.__init__(self, opt)
26 | self.opt = opt
27 |         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
28 |         self.loss_names = ['G', 'G_L1', 'IF']
29 | if opt.loss_RH:
30 |             self.loss_names.append("G_R_gradient")
31 | if opt.loss_IH:
32 | self.loss_names.append("G_I_L2")
33 | if opt.loss_IS:
34 | self.loss_names.append("G_I_smooth")
35 |
36 |         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
37 |         self.visual_names = ['mask', 'harmonized','comp','real','reflectance','illumination','ifm_mean']
38 |         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
39 | self.model_names = ['G']
40 | self.opt.device = self.device
41 | self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
42 | self.cur_device = torch.cuda.current_device()
43 | self.ismaster = du.is_master_proc(opt.NUM_GPUS)
44 | if self.ismaster:
45 | print(self.netG)
46 |
47 | if self.isTrain:
48 | util.saveprint(self.opt, 'netG', str(self.netG))
49 | # define loss functions
50 | self.criterionL1 = torch.nn.L1Loss().cuda(self.cur_device)
51 | self.criterionL2 = torch.nn.MSELoss().cuda(self.cur_device)
52 | self.criterionDSSIM_CS = ssim.DSSIM(mode='c_s').to(self.device)
53 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
54 | self.optimizers.append(self.optimizer_G)
55 |
56 | def set_input(self, input):
57 | self.comp = input['comp'].to(self.device)
58 | self.real = input['real'].to(self.device)
59 | self.inputs = input['inputs'].to(self.device)
60 | self.mask = input['mask'].to(self.device)
61 | self.image_paths = input['img_path']
62 | self.mask_r = F.interpolate(self.mask, size=[64,64])
63 | self.mask_r_32 = F.interpolate(self.mask, size=[32,32])
64 | self.real_r = F.interpolate(self.real, size=[32,32])
65 | self.real_gray = util.rgbtogray(self.real_r)
66 |
67 | def forward(self):
68 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
69 | self.harmonized, self.reflectance, self.illumination, self.ifm_mean = self.netG(self.inputs, self.mask_r, self.mask_r_32)
70 | if not self.isTrain:
71 | self.harmonized = self.comp*(1-self.mask) + self.harmonized*self.mask
72 | def backward_G(self):
73 |         """Calculate the losses for the generator"""
74 | self.loss_IF = self.criterionDSSIM_CS(self.ifm_mean, self.real_gray)*self.opt.lambda_ifm
75 |
76 | self.loss_G_L1 = self.criterionL1(self.harmonized, self.real)*self.opt.lambda_L1
77 | self.loss_G = self.loss_G_L1+self.loss_IF
78 | if self.opt.loss_RH:
79 |             self.loss_G_R_gradient = self.gradient_loss(self.reflectance, self.real)*self.opt.lambda_R_gradient
80 |             self.loss_G = self.loss_G + self.loss_G_R_gradient
81 | if self.opt.loss_IH:
82 | self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
83 | self.loss_G = self.loss_G + self.loss_G_I_L2
84 | if self.opt.loss_IS:
85 | self.loss_G_I_smooth = util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
86 | self.loss_G = self.loss_G + self.loss_G_I_smooth
87 | self.loss_G.backward()
88 |
89 | def optimize_parameters(self):
90 | self.forward() # compute fake images: G(A)
91 | # update G
92 | self.optimizer_G.zero_grad() # set G's gradients to zero
93 |         self.backward_G()                   # calculate gradients for G
94 |         self.optimizer_G.step()             # update G's weights
95 |
96 | def gradient_loss(self, input_1, input_2):
97 | g_x = self.criterionL1(util.gradient(input_1, 'x'), util.gradient(input_2, 'x'))
98 | g_y = self.criterionL1(util.gradient(input_1, 'y'), util.gradient(input_2, 'y'))
99 | return g_x+g_y
100 |
101 |
102 |
--------------------------------------------------------------------------------
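util.gradient is defined in util/util.py, which is outside this excerpt; a plausible forward-difference version consistent with how gradient_loss consumes it (x- and y-derivatives compared under L1) would be:

    import torch

    def gradient(img, direction):
        # Sketch of a forward-difference image gradient; the exact
        # definition in util/util.py may differ.
        if direction == 'x':
            return img[:, :, :, 1:] - img[:, :, :, :-1]
        return img[:, :, 1:, :] - img[:, :, :-1, :]

    # gradient_loss then penalizes mismatched edges between the predicted
    # reflectance and the ground-truth image:
    l1 = torch.nn.L1Loss()
    a, b = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
    loss = l1(gradient(a, 'x'), gradient(b, 'x')) + l1(gradient(a, 'y'), gradient(b, 'y'))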
/models/iih_base_lt_gd_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import itertools
4 | import torch.nn.functional as F
5 | from util import distributed as du
6 | from .base_model import BaseModel
7 | from util import util
8 | from . import harmony_networks as networks
9 | import util.ssim as ssim
10 |
11 | class IIHBaseLTGDModel(BaseModel):
12 | @staticmethod
13 | def modify_commandline_options(parser, is_train=True):
14 | parser.set_defaults(norm='instance', netG='base_lt_gd', dataset_mode='ihd')
15 | if is_train:
16 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
17 | parser.add_argument('--lambda_R_gradient', type=float, default=20., help='weight for reflectance gradient loss')
18 | parser.add_argument('--lambda_I_L2', type=float, default=10, help='weight for illumination L2 loss')
19 | parser.add_argument('--lambda_I_smooth', type=float, default=1, help='weight for Illumination smooth loss')
20 |             parser.add_argument('--lambda_ifm', type=float, default=100, help='weight for the ifm (IF) loss')
21 |
22 | return parser
23 |
24 | def __init__(self, opt):
25 |         """Initialize the IIHBaseLTGDModel class.
26 |
27 | Parameters:
28 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
29 | """
30 | BaseModel.__init__(self, opt)
31 | self.opt = opt
32 |         self.loss_names = ['G', 'G_L1', 'G_R_gradient', 'G_I_L2', 'G_I_smooth', 'IF']
33 |
34 |         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
35 |         self.visual_names = ['mask', 'harmonized','comp','real','reflectance','illumination','ifm_mean']
36 |         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
37 | self.model_names = ['G']
38 | self.opt.device = self.device
39 | self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
40 | self.cur_device = torch.cuda.current_device()
41 | self.ismaster = du.is_master_proc(opt.NUM_GPUS)
42 | if self.ismaster:
43 | print(self.netG)
44 |
45 | if self.isTrain:
46 |             if self.ismaster:  # only the master process writes the network summary
47 | util.saveprint(self.opt, 'netG', str(self.netG))
48 | # define loss functions
49 | self.criterionL1 = torch.nn.L1Loss().cuda(self.cur_device)
50 | self.criterionL2 = torch.nn.MSELoss().cuda(self.cur_device)
51 | self.criterionDSSIM_CS = ssim.DSSIM(mode='c_s').to(self.device)
52 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
53 | self.optimizers.append(self.optimizer_G)
54 |
55 | def set_input(self, input):
56 | self.comp = input['comp'].to(self.device)
57 | self.real = input['real'].to(self.device)
58 | self.inputs = input['inputs'].to(self.device)
59 | self.mask = input['mask'].to(self.device)
60 | self.image_paths = input['img_path']
61 |
62 | self.mask_r = F.interpolate(self.mask, size=[64,64])
63 | self.mask_r_32 = F.interpolate(self.mask, size=[32,32])
64 | self.real_r = F.interpolate(self.real, size=[32,32])
65 | self.real_gray = util.rgbtogray(self.real_r)
66 |
67 | def forward(self):
68 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
69 | self.harmonized, self.reflectance, self.illumination, self.ifm_mean = self.netG(self.inputs, self.mask_r, self.mask_r_32)
70 | if not self.isTrain:
71 | self.harmonized = self.comp*(1-self.mask) + self.harmonized*self.mask
72 | def backward_G(self):
73 |         """Calculate the losses for the generator"""
74 | self.loss_IF = self.criterionDSSIM_CS(self.ifm_mean, self.real_gray)*self.opt.lambda_ifm
75 |
76 | self.loss_G_L1 = self.criterionL1(self.harmonized, self.real)*self.opt.lambda_L1
77 |         self.loss_G_R_gradient = self.gradient_loss(self.reflectance, self.real)*self.opt.lambda_R_gradient
78 |         self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
79 |         self.loss_G_I_smooth = util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
80 |
81 |         self.loss_G = self.loss_G_L1 + self.loss_G_R_gradient + self.loss_G_I_L2 + self.loss_G_I_smooth + self.loss_IF
82 | self.loss_G.backward()
83 |
84 | def optimize_parameters(self):
85 | self.forward() # compute fake images: G(A)
86 | # update G
87 | self.optimizer_G.zero_grad() # set G's gradients to zero
88 |         self.backward_G()                   # calculate gradients for G
89 |         self.optimizer_G.step()             # update G's weights
90 |
91 | def gradient_loss(self, input_1, input_2):
92 | g_x = self.criterionL1(util.gradient(input_1, 'x'), util.gradient(input_2, 'x'))
93 | g_y = self.criterionL1(util.gradient(input_1, 'y'), util.gradient(input_2, 'y'))
94 | return g_x+g_y
95 |
96 |
--------------------------------------------------------------------------------
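As with the other model variants, one optimization step is driven entirely through set_input and optimize_parameters. A schematic loop (the opt and dataloader plumbing is elided; model.setup is the BaseModel hook that builds schedulers and loads checkpoints, per the comments in the options code):

    from models import create_model

    model = create_model(opt)          # opt from TrainOptions().parse()
    model.setup(opt)
    for batch in dataloader:           # batches carry the keys read by set_input
        model.set_input(batch)         # move comp/real/inputs/mask to the device
        model.optimize_parameters()    # forward, backward_G, optimizer step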
/models/iih_base_lt_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import itertools
3 | import torch.nn.functional as F
4 | from util import distributed as du
5 | from .base_model import BaseModel
6 | from util import util
7 | from . import harmony_networks as networks
8 | from . import networks as network_init
9 |
10 |
11 | class IIHBaseLTModel(BaseModel):
12 | @staticmethod
13 | def modify_commandline_options(parser, is_train=True):
14 | parser.set_defaults(norm='instance', netG='base_lt', dataset_mode='ihd')
15 | if is_train:
16 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
17 | parser.add_argument('--lambda_R_gradient', type=float, default=20., help='weight for reflectance gradient loss')
18 | parser.add_argument('--lambda_I_L2', type=float, default=10., help='weight for illumination L2 loss')
19 | parser.add_argument('--lambda_I_smooth', type=float, default=1, help='weight for Illumination smooth loss')
20 |
21 | return parser
22 |
23 | def __init__(self, opt):
24 | BaseModel.__init__(self, opt)
25 | self.opt = opt
26 | self.loss_names = ['G','G_L1']
27 | if opt.loss_RH:
28 |             self.loss_names.append("G_R_gradient")
29 | if opt.loss_IH:
30 | self.loss_names.append("G_I_L2")
31 | if opt.loss_IS:
32 | self.loss_names.append("G_I_smooth")
33 |
34 | self.visual_names = ['mask', 'harmonized','comp','real','reflectance','illumination']
35 | self.model_names = ['G']
36 | self.opt.device = self.device
37 | self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
38 | self.cur_device = torch.cuda.current_device()
39 | self.ismaster = du.is_master_proc(opt.NUM_GPUS)
40 | if self.ismaster:
41 | print(self.netG)
42 |
43 | if self.isTrain:
44 |             if self.ismaster:  # only the master process writes the network summary
45 | util.saveprint(self.opt, 'netG', str(self.netG))
46 | self.criterionL1 = torch.nn.L1Loss().cuda(self.cur_device)
47 | self.criterionL2 = torch.nn.MSELoss().cuda(self.cur_device)
48 |             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
49 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
50 | self.optimizers.append(self.optimizer_G)
51 |
52 | def set_input(self, input):
53 | self.comp = input['comp'].to(self.device)
54 | self.real = input['real'].to(self.device)
55 | self.inputs = input['inputs'].to(self.device)
56 | self.mask = input['mask'].to(self.device)
57 | self.image_paths = input['img_path']
58 | self.mask_r = F.interpolate(self.mask, size=[64,64])
59 |
60 | def forward(self):
61 |         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
62 | self.harmonized, self.reflectance, self.illumination = self.netG(self.inputs, self.mask, self.mask_r)
63 | if not self.isTrain:
64 | self.harmonized = self.comp*(1-self.mask) + self.harmonized*self.mask
65 |
66 | def backward_G(self):
67 |         """Calculate the losses for the generator"""
68 | self.loss_G_L1 = self.criterionL1(self.harmonized, self.real)*self.opt.lambda_L1
69 | self.loss_G = self.loss_G_L1
70 | if self.opt.loss_RH:
71 |             self.loss_G_R_gradient = self.gradient_loss(self.reflectance, self.real)*self.opt.lambda_R_gradient
72 |             self.loss_G = self.loss_G + self.loss_G_R_gradient
73 | if self.opt.loss_IH:
74 | self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
75 | self.loss_G = self.loss_G + self.loss_G_I_L2
76 | if self.opt.loss_IS:
77 | self.loss_G_I_smooth = util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
78 | self.loss_G = self.loss_G + self.loss_G_I_smooth
79 |         # self.loss_G_R_gradient = self.gradient_loss(self.reflectance, self.real)*self.opt.lambda_R_gradient
80 |         # self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
81 |         # self.loss_G_I_smooth = util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
82 |         # self.loss_G = self.loss_G_L1 + self.loss_G_R_gradient + self.loss_G_I_L2 + self.loss_G_I_smooth
83 |
84 | self.loss_G.backward()
85 |
86 | def optimize_parameters(self):
87 | self.forward() # compute fake images: G(A)
88 | # update G
89 | self.optimizer_G.zero_grad() # set G's gradients to zero
90 |         self.backward_G()                   # calculate gradients for G
91 |         self.optimizer_G.step()             # update G's weights
92 |
93 | def gradient_loss(self, input_1, input_2):
94 | g_x = self.criterionL1(util.gradient(input_1, 'x'), util.gradient(input_2, 'x'))
95 | g_y = self.criterionL1(util.gradient(input_1, 'y'), util.gradient(input_2, 'y'))
96 | return g_x+g_y
97 |
98 |
--------------------------------------------------------------------------------
/models/iih_base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import itertools
4 | import torch.nn.functional as F
5 | from util import distributed as du
6 | from .base_model import BaseModel
7 | from util import util
8 | from . import harmony_networks as networks
9 |
10 |
11 | class IIHBaseModel(BaseModel):
12 | @staticmethod
13 | def modify_commandline_options(parser, is_train=True):
14 | parser.set_defaults(norm='instance', netG='base', dataset_mode='ihd')
15 | if is_train:
16 | parser.add_argument('--lambda_L1', type=float, default=100.0, help='weight for L1 loss')
17 | parser.add_argument('--lambda_R_gradient', type=float, default=50., help='weight for reflectance gradient loss')
18 | parser.add_argument('--lambda_I_L2', type=float, default=10., help='weight for illumination L2 loss')
19 | parser.add_argument('--lambda_I_smooth', type=float, default=1, help='weight for Illumination smooth loss')
20 | return parser
21 |
22 | def __init__(self, opt):
23 | BaseModel.__init__(self, opt)
24 | self.opt = opt
25 | self.loss_names = ['G','G_L1']
26 | if opt.loss_RH:
27 |             self.loss_names.append("G_R_gradient")
28 | if opt.loss_IH:
29 | self.loss_names.append("G_I_L2")
30 | if opt.loss_IS:
31 | self.loss_names.append("G_I_smooth")
32 |
33 | self.visual_names = ['mask', 'harmonized','comp','real','reflectance','illumination']
34 | self.model_names = ['G']
35 | self.opt.device = self.device
36 | self.netG = networks.define_G(opt.netG, opt.init_type, opt.init_gain, self.opt)
37 | self.cur_device = torch.cuda.current_device()
38 | self.ismaster = du.is_master_proc(opt.NUM_GPUS)
39 | if self.ismaster:
40 | print(self.netG)
41 | if self.isTrain:
42 | util.saveprint(self.opt, 'netG', str(self.netG))
43 | self.criterionL1 = torch.nn.L1Loss().cuda(self.cur_device)
44 | self.criterionL2 = torch.nn.MSELoss().cuda(self.cur_device)
45 |             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
46 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
47 | self.optimizers.append(self.optimizer_G)
48 |
49 | def set_input(self, input):
50 | self.comp = input['comp'].to(self.device)
51 | self.real = input['real'].to(self.device)
52 | self.inputs = input['inputs'].to(self.device)
53 | self.mask = input['mask'].to(self.device)
54 | self.image_paths = input['img_path']
55 |
56 | def forward(self):
57 | self.harmonized, self.reflectance, self.illumination = self.netG(self.inputs, self.mask)
58 | if not self.isTrain:
59 | self.harmonized = self.comp*(1-self.mask) + self.harmonized*self.mask
60 |
61 | def backward_G(self):
62 | self.loss_G_L1 = self.criterionL1(self.harmonized, self.real)*self.opt.lambda_L1
63 | self.loss_G = self.loss_G_L1
64 | if self.opt.loss_RH:
65 |             self.loss_G_R_gradient = self.gradient_loss(self.reflectance, self.real)*self.opt.lambda_R_gradient
66 |             self.loss_G = self.loss_G + self.loss_G_R_gradient
67 | if self.opt.loss_IH:
68 | self.loss_G_I_L2 = self.criterionL2(self.illumination, self.real)*self.opt.lambda_I_L2
69 | self.loss_G = self.loss_G + self.loss_G_I_L2
70 | if self.opt.loss_IS:
71 | self.loss_G_I_smooth = util.compute_smooth_loss(self.illumination)*self.opt.lambda_I_smooth
72 | self.loss_G = self.loss_G + self.loss_G_I_smooth
73 | self.loss_G.backward()
74 |
75 | def optimize_parameters(self):
76 | self.forward() # compute fake images: G(A)
77 | # update G
78 | self.optimizer_G.zero_grad() # set G's gradients to zero
79 |         self.backward_G()                   # calculate gradients for G
80 |         self.optimizer_G.step()             # update G's weights
81 |
82 |
83 | def gradient_loss(self, input_1, input_2):
84 | g_x = self.criterionL1(util.gradient(input_1, 'x'), util.gradient(input_2, 'x'))
85 | g_y = self.criterionL1(util.gradient(input_1, 'y'), util.gradient(input_2, 'y'))
86 | return g_x+g_y
--------------------------------------------------------------------------------
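At test time every model variant pastes the network output back into the composite so that only the masked foreground changes. The blend in forward behaves like this:

    import torch

    comp = torch.rand(1, 3, 4, 4)    # composite input
    pred = torch.rand(1, 3, 4, 4)    # raw network output
    mask = torch.zeros(1, 1, 4, 4)
    mask[..., :2, :] = 1.0           # foreground rows

    out = comp * (1 - mask) + pred * mask
    assert torch.allclose(out[..., 2:, :], comp[..., 2:, :])  # background untouched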
/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
2 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | """This class includes test options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
13 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
14 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
15 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
16 |         # Dropout and Batchnorm have different behavior during training and test.
17 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.')
18 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run')
19 |         parser.add_argument('--test_epoch', type=str, default="0", help='which epoch to load for testing')
20 |         # rewrite default values
21 | parser.set_defaults(model='test')
22 | # To avoid cropping, the load_size should be the same as crop_size
23 | parser.set_defaults(load_size=parser.get_default('crop_size'))
24 | self.isTrain = False
25 | return parser
26 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | """This class includes training options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser)
12 | # visdom and HTML visualization parameters
13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
14 | parser.add_argument('--display_ncols', type=int, default=6, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
15 | parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
22 | # network saving and loading parameters
23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
24 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
25 |         parser.add_argument('--save_by_iter', action='store_true', help='whether to save the model by iteration')
26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
27 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
28 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
29 | # training parameters
30 |         parser.add_argument('--niter', type=int, default=50, help='# of epochs at the starting learning rate')
31 |         parser.add_argument('--niter_decay', type=int, default=50, help='# of epochs to linearly decay learning rate to zero')
32 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
33 | parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
34 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
35 | parser.add_argument('--g_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of generator')
36 | parser.add_argument('--d_lr_ratio', type=float, default=1.0, help='a ratio for changing learning rate of discriminator')
37 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
38 | parser.add_argument('--pool_size', type=int, default=40, help='the size of image buffer that stores previously generated images')
39 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]')
40 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
41 |         parser.add_argument('--save_iter_model', action='store_true', help='whether to save the model by iteration')
42 |
43 | self.isTrain = True
44 | return parser
45 |
--------------------------------------------------------------------------------
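With the default lr_policy='linear', the learning rate is held at --lr for --niter epochs and then decays linearly to zero over --niter_decay more epochs. The scheduler assembled in models/networks.py (not shown in this excerpt) is conventionally a LambdaLR of this shape; a sketch:

    import torch
    from torch.optim import lr_scheduler

    net = torch.nn.Linear(4, 4)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, betas=(0.5, 0.999))
    niter, niter_decay, epoch_count = 50, 50, 1

    def lambda_rule(epoch):
        # Flat for the first `niter` epochs, then a straight line to zero
        # over the next `niter_decay` epochs.
        return 1.0 - max(0, epoch + epoch_count - niter) / float(niter_decay + 1)

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)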
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.4.0
2 | torchvision>=0.5.0
3 | dominate>=2.4.0
4 | visdom>=0.1.8.8
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import test
6 |
7 | from options.test_options import TestOptions
8 |
9 | def main():
10 |     cfg = TestOptions().parse()  # get test options
11 | cfg.NUM_GPUS = torch.cuda.device_count()
12 | cfg.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
13 | cfg.no_flip = True # no flip; comment this line if results on flipped images are needed.
14 |     cfg.display_id = -1   # no visdom display; the test code saves the results to an HTML file.
15 |
16 | cfg.phase = 'test'
17 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
18 | launch_job(cfg=cfg, init_method=cfg.init_method, func=test)
19 |
20 |
21 | if __name__=="__main__":
22 | main()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.distributed as dist
4 | from util.misc import launch_job
5 | from train_net import train
6 |
7 | from options.train_options import TrainOptions
8 |
9 | def main():
10 | cfg = TrainOptions().parse() # get training options
11 | cfg.NUM_GPUS = torch.cuda.device_count()
12 | cfg.batch_size = int(cfg.batch_size / max(1, cfg.NUM_GPUS))
13 | cfg.phase = 'train'
14 | launch_job(cfg=cfg, init_method=cfg.init_method, func=train)
15 |
16 |
17 | if __name__=="__main__":
18 | main()
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br
3 | import os
4 |
5 |
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 |     It consists of functions such as <add_header> (add a text header to the HTML file),
10 |     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
11 |     It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 |
14 | def __init__(self, web_dir, title, refresh=0):
15 |         """Initialize the HTML class
16 |
17 | Parameters:
18 |             web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str)   -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 |
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 | with self.doc.head:
33 | meta(http_equiv="refresh", content=str(refresh))
34 |
35 | def get_image_dir(self):
36 | """Return the directory that stores images"""
37 | return self.img_dir
38 |
39 | def add_header(self, text):
40 | """Insert a header to the HTML file
41 |
42 | Parameters:
43 | text (str) -- the header text
44 | """
45 | with self.doc:
46 | h3(text)
47 |
48 | def add_images(self, ims, txts, links, width=400):
49 | """add images to the HTML file
50 |
51 | Parameters:
52 | ims (str list) -- a list of image paths
53 | txts (str list) -- a list of image names shown on the website
54 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
55 | """
56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
57 | self.doc.add(self.t)
58 | with self.t:
59 | with tr():
60 | for im, txt, link in zip(ims, txts, links):
61 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 | with p():
63 | with a(href=os.path.join('images', link)):
64 | img(style="width:%dpx" % width, src=os.path.join('images', im))
65 | br()
66 | p(txt)
67 |
68 | def save(self):
69 |         """save the current content to the HTML file"""
70 | html_file = '%s/index.html' % self.web_dir
71 |         with open(html_file, 'wt') as f:
72 |             f.write(self.doc.render())
73 |
74 |
75 |
76 | if __name__ == '__main__': # we show an example usage here.
77 | html = HTML('web/', 'test_html')
78 | html.add_header('hello world')
79 |
80 | ims, txts, links = [], [], []
81 | for n in range(4):
82 | ims.append('image_%d.png' % n)
83 | txts.append('text_%d' % n)
84 | links.append('image_%d.png' % n)
85 | html.add_images(ims, txts, links)
86 | html.save()
87 |
--------------------------------------------------------------------------------
/util/multiprocessing.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
3 |
4 | """Multiprocessing helpers."""
5 |
6 | import torch
7 |
8 |
9 | def run(
10 | local_rank,
11 | num_proc,
12 | func,
13 | init_method,
14 | shard_id,
15 | num_shards,
16 | backend,
17 | cfg,
18 | output_queue=None,
19 | ):
20 | """
21 | Runs a function from a child process.
22 | Args:
23 | local_rank (int): rank of the current process on the current machine.
24 | num_proc (int): number of processes per machine.
25 | func (function): function to execute on each of the process.
26 | init_method (string): method to initialize the distributed training.
27 |             TCP initialization: requiring a network address reachable from all
28 | processes followed by the port.
29 | Shared file-system initialization: makes use of a file system that
30 | is shared and visible from all machines. The URL should start with
31 | file:// and contain a path to a non-existent file on a shared file
32 | system.
33 | shard_id (int): the rank of the current machine.
34 | num_shards (int): number of overall machines for the distributed
35 | training job.
36 |         backend (string): three distributed backends ('nccl', 'gloo', 'mpi') are
37 |             supported, each with different capabilities. Details can be found
38 | here:
39 | https://pytorch.org/docs/stable/distributed.html
40 |         cfg (CfgNode): configs; here, the parsed options namespace passed
41 |             through from launch_job.
42 | output_queue (queue): can optionally be used to return values from the
43 | master process.
44 | """
45 | # Initialize the process group.
46 | world_size = num_proc * num_shards
47 | rank = shard_id * num_proc + local_rank
48 | try:
49 | torch.distributed.init_process_group(
50 | backend=backend,
51 | init_method=init_method,
52 | world_size=world_size,
53 | rank=rank,
54 | )
55 |
56 | except Exception as e:
57 | raise e
58 |
59 | torch.cuda.set_device(local_rank)
60 | ret = func(cfg)
61 | if output_queue is not None and local_rank == 0:
62 | output_queue.put(ret)
63 |
--------------------------------------------------------------------------------
/util/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if size_average:
35 | return ssim_map.mean()
36 | else:
37 | return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | def _ssim_c_s(img1, img2, window, window_size, channel, size_average = True):
40 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
41 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
42 |
43 | mu1_sq = mu1.pow(2)
44 | mu2_sq = mu2.pow(2)
45 | mu1_mu2 = mu1*mu2
46 |
47 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
48 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
49 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
50 |
51 | C1 = 0.01**2
52 | C2 = 0.03**2
53 |
54 | ssim_map = ((2*sigma12 + C2))/((sigma1_sq + sigma2_sq + C2))
55 |
56 | if size_average:
57 | return ssim_map.mean()
58 | else:
59 | return ssim_map.mean(1).mean(1).mean(1)
60 | def _ssim_l(img1, img2, window, window_size, channel, size_average = True):
61 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
62 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
63 |
64 | mu1_sq = mu1.pow(2)
65 | mu2_sq = mu2.pow(2)
66 | mu1_mu2 = mu1*mu2
67 |
68 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
69 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
70 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
71 |
72 | C1 = 0.01**2
73 | C2 = 0.03**2
74 |
75 | ssim_map = (2*mu1_mu2 + C1)/(mu1_sq + mu2_sq + C1)
76 |
77 | if size_average:
78 | return ssim_map.mean()
79 | else:
80 | return ssim_map.mean(1).mean(1).mean(1)
81 |
82 |
83 | def ssim(img1, img2, window_size = 11, size_average = True):
84 | (_, channel, _, _) = img1.size()
85 | window = create_window(window_size, channel)
86 |
87 | if img1.is_cuda:
88 | window = window.cuda(img1.get_device())
89 | window = window.type_as(img1)
90 |
91 | return _ssim(img1, img2, window, window_size, channel, size_average)
92 |
93 | class SSIM(torch.nn.Module):
94 | def __init__(self, window_size = 11, size_average = True, mode='all'):
95 | super(SSIM, self).__init__()
96 | self.window_size = window_size
97 | self.size_average = size_average
98 | self.channel = 1
99 | self.window = create_window(window_size, self.channel)
100 | self.mode = mode
101 | def forward(self, img1, img2):
102 | (_, channel, _, _) = img1.size()
103 |
104 | if channel == self.channel and self.window.data.type() == img1.data.type():
105 | window = self.window
106 | else:
107 | window = create_window(self.window_size, channel)
108 |
109 | # if img1.is_cuda:
110 | # window = window.cuda(img1.get_device())
111 | window = window.type_as(img1)
112 |
113 | self.window = window
114 | self.channel = channel
115 | if self.mode == 'all':
116 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
117 | elif self.mode == 'c_s':
118 | return _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average)
119 | else:
120 | return _ssim_l(img1, img2, window, self.window_size, channel, self.size_average)
121 |
122 | class DSSIM(torch.nn.Module):
123 | def __init__(self, window_size = 11, size_average = True, mode='all'):
124 | super(DSSIM, self).__init__()
125 | self.window_size = window_size
126 | self.size_average = size_average
127 | self.channel = 1
128 | self.window = create_window(window_size, self.channel)
129 | self.mode = mode
130 | def forward(self, img1, img2):
131 | (_, channel, _, _) = img1.size()
132 |
133 | if channel == self.channel and self.window.data.type() == img1.data.type():
134 | window = self.window
135 | else:
136 | window = create_window(self.window_size, channel)
137 |
138 | # if img1.is_cuda:
139 | # window = window.cuda(img1.get_device())
140 | window = window.type_as(img1)
141 |
142 | self.window = window
143 | self.channel = channel
144 | if self.mode == 'all':
145 | ssim_v = _ssim(img1, img2, window, self.window_size, channel, self.size_average)
146 | elif self.mode == 'c_s':
147 | ssim_v = _ssim_c_s(img1, img2, window, self.window_size, channel, self.size_average)
148 | else:
149 | ssim_v = _ssim_l(img1, img2, window, self.window_size, channel, self.size_average)
150 | return (1-ssim_v)/2
--------------------------------------------------------------------------------
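The 'c_s' mode drops SSIM's luminance term and keeps only the contrast/structure comparison, which is what the harmonization models use for criterionDSSIM_CS against a grayscale target. Usage:

    import torch
    from util.ssim import DSSIM

    criterion = DSSIM(mode='c_s')
    pred = torch.rand(2, 1, 32, 32)     # e.g. ifm_mean
    target = torch.rand(2, 1, 32, 32)   # e.g. downsampled grayscale real image
    loss = criterion(pred, target)      # in [0, 1]; 0 means identical structure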