├── img
│   ├── fig1.jpg
│   ├── fig2.jpg
│   └── fig3.jpg
├── evaluation
│   └── evaluate_MAE_PSNR_SSIM.m
├── show_eval_result.py
├── script
│   ├── flist.py
│   ├── flist_train_val_test.py
│   ├── process_syn_data.py
│   ├── generate_flist_srd.py
│   └── generate_flist_istd.py
├── debug_ready.py
├── src
│   ├── config.py
│   ├── main.py
│   ├── image_pool.py
│   ├── models.py
│   ├── metrics.py
│   ├── dataset.py
│   ├── loss.py
│   ├── network
│   │   ├── networks.py
│   │   └── network_DMTN.py
│   ├── utils.py
│   └── model_top.py
├── deploy_ready.py
├── config
│   ├── config_SRD_da.yml
│   ├── config_ISTD_da.yml
│   ├── config_SRD.yml
│   ├── config_ISTD.yml
│   ├── run_ISTD.py
│   ├── run_SRD.py
│   ├── run_SRD_da.py
│   └── run_ISTD_da.py
├── install.yaml
└── README.md
/img/fig1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nachifur/DMTN/HEAD/img/fig1.jpg
--------------------------------------------------------------------------------
/img/fig2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nachifur/DMTN/HEAD/img/fig2.jpg
--------------------------------------------------------------------------------
/img/fig3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nachifur/DMTN/HEAD/img/fig3.jpg
--------------------------------------------------------------------------------
/evaluation/evaluate_MAE_PSNR_SSIM.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nachifur/DMTN/HEAD/evaluation/evaluate_MAE_PSNR_SSIM.m
--------------------------------------------------------------------------------
/show_eval_result.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | rmse = np.genfromtxt(
3 | 'checkpoints/ISTD/show_result.txt', dtype=str, encoding='utf-8').astype(float)
4 |
5 | print('running rmse-shadow: %.4f, rmse-non-shadow: %.4f, rmse-all: %.4f'
6 | % (rmse[0], rmse[1], rmse[2]))
7 |
--------------------------------------------------------------------------------
/script/flist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from pathlib import Path
4 |
5 | def gen_flist(data_path):
6 | ext = {'.JPG', '.JPEG', '.PNG', '.TIF', '.TIFF', '.JSON'}  # splitext()[1].upper() keeps the dot, so every entry needs one
7 | images = []
8 | for root, dirs, files in os.walk(data_path):
9 | print('loading ' + root)
10 | for file in files:
11 | if os.path.splitext(file)[1].upper() in ext:
12 | images.append(os.path.join(root, file))
13 |
14 | images = sorted(images)
15 | return images
16 |
--------------------------------------------------------------------------------
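For reference, a minimal usage sketch of gen_flist; the dataset path below is a placeholder:

    from flist import gen_flist
    import numpy as np

    # collect every image under one dataset folder, sorted for reproducibility
    images = gen_flist("/path/to/SRD_Dataset_arg/train/train_A")  # placeholder path
    np.savetxt("SRD_shadow_train.flist", images, fmt='%s')        # one path per line
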
/debug_ready.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import os
3 |
4 | # # clear
5 | # flag = os.system("rm -rf checkpoints")
6 | # if flag == 0:
7 | # print("clear checkpoints success")
8 | # flag = os.system("rm -rf src/__pycache__")
9 | # if flag == 0:
10 | # print("clear src/__pycache__ success")
11 |
12 | with open("config.yml", 'r') as fr:
13 | config = yaml.load(fr, Loader=yaml.FullLoader)
14 | config["DEBUG"] = 1
15 | config["GPU"] = [0]
16 | with open("config.yml", 'w') as f_obj:
17 | yaml.dump(config, f_obj)
18 | print("in debug mode")
--------------------------------------------------------------------------------
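Usage note: running "python debug_ready.py" rewrites the top-level config.yml in place, forcing DEBUG: 1 and GPU: [0] while keeping all other keys; the commented-out block at the top would additionally delete checkpoints/ and src/__pycache__/ if re-enabled.
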
/src/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import yaml
4 |
5 |
6 | class Config(dict):
7 | def __init__(self, config_path):
8 | with open(config_path, 'r') as f:
9 | self._yaml = f.read()
10 | self._dict = yaml.load(self._yaml, Loader=yaml.FullLoader)
11 | self._dict['PATH'] = os.path.dirname(config_path)
12 |
13 | def __getattr__(self, name):
14 | if self._dict.get(name) is not None:
15 | return self._dict[name]
16 |
17 | return None
18 |
19 | def print(self):
20 | print('Model configurations:')
21 | print('---------------------------------')
22 | print(self._yaml)
23 | print('')
24 | print('---------------------------------')
25 | print('')
26 |
--------------------------------------------------------------------------------
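A minimal usage sketch of Config; the checkpoint path below assumes a run has already placed a config there, as the config/run_*.py scripts do:

    from src.config import Config

    config = Config('checkpoints/ISTD/config.yml')
    config.print()              # dumps the raw YAML text
    print(config.LR)            # 0.0001; attribute access goes through __getattr__
    print(config.NO_SUCH_KEY)   # missing keys return None instead of raising
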
/deploy_ready.py:
--------------------------------------------------------------------------------
1 | import os
2 | import yaml
3 | import argparse
4 | from pathlib import Path
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument('--gpu', help="gpu")
8 | ARGS = parser.parse_args()
9 | gpu = ARGS.gpu
10 |
11 |
12 |
13 | with open("config.yml", 'r') as fr:
14 | config = yaml.load(fr, Loader=yaml.FullLoader)
15 | config["DEBUG"] = 0
16 | config["GPU"] = [gpu]
17 | with open("config.yml", 'w') as f_obj:
18 | yaml.dump(config, f_obj)
19 | print("in deploy mode")
20 |
21 | # clear
22 | # model_name = config["MODEL_NAME"] + '.pth'
23 | # checkpoints_path = str(Path('./checkpoints') / \
24 | # config["SUBJECT_WORD"]/model_name)
25 | # flag = os.system("rm "+checkpoints_path)
26 | # if flag == 0:
27 | # print("clear "+checkpoints_path+" success")
28 | flag = os.system("rm -rf src/__pycache__")
29 | if flag == 0:
30 | print("clear src/__pycache__ success")
31 |
32 |
--------------------------------------------------------------------------------
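Usage note: "python deploy_ready.py --gpu 0" switches config.yml to DEBUG: 0 and records the chosen GPU id. argparse delivers the id as a string, which is harmless here because the run_*.py scripts rebuild CUDA_VISIBLE_DEVICES with str(e) for each entry.
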
/config/config_SRD_da.yml:
--------------------------------------------------------------------------------
1 | ADV: 1
2 | BACKUP: 0
3 | BATCH_SIZE: 1
4 | BETA1: 0.9
5 | BETA2: 0.999
6 | DATA_ROOT: /home/xxx/data-set/shadow_removal_dataset/SRD_Dataset_arg
7 | DEBUG: 0
8 | DIS_NORM: spectral
9 | FORCE_EXIT: 1
10 | GAN_LOSS: nsgan
11 | GAN_NORM: batch_
12 | GPU:
13 | - '0'
14 | INIT_TYPE: identity
15 | INPUT_SIZE_H: 640
16 | INPUT_SIZE_W: 840
17 | LOG_INTERVAL: 10
18 | LOSS: L1Loss
19 | LR: 0.0001
20 | LR_D: 5.0e-05
21 | MIDDLE_RES_NUM: 4
22 | MODEL_NAME: ShadowRemoval
23 | NETWORK: DMTN
24 | POOL_SIZE: 50
25 | PRE_TRAIN_EVAL_LEN: 10
26 | RESULTS: ./checkpoints/SRD/results
27 | RESULTS_SAMPLE: 0
28 | SAMPLE_SIZE: 1
29 | SEED: 10
30 | SPLIT: 1
31 | SUBJECT_WORD: SRD
32 | TRAIN_EVAL_LEN: 10
33 | VERBOSE: 1
34 |
35 | TEST_FLIST: /data_val/SRD_shadow_test.flist
36 | TEST_GT_FLIST: /data_val/SRD_shadow_free_test.flist
37 | TEST_MASK_FLIST: /data_val/SRD_mask_test.flist
38 |
39 | TRAIN_FLIST: /data_val/SRD_shadow_train_train.flist
40 | TRAIN_GT_FLIST: /data_val/SRD_shadow_free_train_train.flist
41 | TRAIN_MASK_FLIST: /data_val/SRD_mask_train_train.flist
42 |
43 | TRAIN_FLIST_PRE: /data_val/SRD_shadow_pre_train.flist
44 | TRAIN_GT_FLIST_PRE: /data_val/SRD_shadow_free_pre_train.flist
45 | TRAIN_MASK_FLIST_PRE: /data_val/SRD_mask_pre_train.flist
46 |
47 | VAL_FLIST: /data_val/SRD_shadow_train_val.flist
48 | VAL_GT_FLIST: /data_val/SRD_shadow_free_train_val.flist
49 | VAL_MASK_FLIST: /data_val/SRD_mask_train_val.flist
50 |
51 | VAL_FLIST_PRE: /data_val/SRD_shadow_pre_val.flist
52 | VAL_GT_FLIST_PRE: /data_val/SRD_shadow_free_pre_val.flist
53 | VAL_MASK_FLIST_PRE: /data_val/SRD_mask_pre_val.flist
--------------------------------------------------------------------------------
/config/config_ISTD_da.yml:
--------------------------------------------------------------------------------
1 | ADV: 1
2 | BACKUP: 0
3 | BATCH_SIZE: 1
4 | BETA1: 0.9
5 | BETA2: 0.999
6 | DATA_ROOT: /home/xxx/data-set/shadow_removal_dataset/ISTD_Dataset_arg
7 | DEBUG: 0
8 | DIS_NORM: spectral
9 | FORCE_EXIT: 1
10 | GAN_LOSS: nsgan
11 | GAN_NORM: batch_
12 | GPU:
13 | - '0'
14 | INIT_TYPE: identity
15 | INPUT_SIZE_H: 480
16 | INPUT_SIZE_W: 640
17 | LOG_INTERVAL: 10
18 | LOSS: L1Loss
19 | LR: 0.0001
20 | LR_D: 5.0e-05
21 | MIDDLE_RES_NUM: 4
22 | MODEL_NAME: ShadowRemoval
23 | NETWORK: DMTN
24 | POOL_SIZE: 50
25 | PRE_TRAIN_EVAL_LEN: 10
26 | RESULTS: ./checkpoints/ISTD/results
27 | RESULTS_SAMPLE: 0
28 | SAMPLE_SIZE: 1
29 | SEED: 10
30 | SPLIT: 1
31 | SUBJECT_WORD: ISTD
32 | TRAIN_EVAL_LEN: 10
33 | VERBOSE: 1
34 |
35 | TEST_FLIST: /data_val/ISTD_shadow_test.flist
36 | TEST_GT_FLIST: /data_val/ISTD_shadow_free_test.flist
37 | TEST_MASK_FLIST: /data_val/ISTD_mask_test.flist
38 |
39 | TRAIN_FLIST: /data_val/ISTD_shadow_train_train.flist
40 | TRAIN_GT_FLIST: /data_val/ISTD_shadow_free_train_train.flist
41 | TRAIN_MASK_FLIST: /data_val/ISTD_mask_train_train.flist
42 |
43 | TRAIN_FLIST_PRE: /data_val/ISTD_shadow_pre_train.flist
44 | TRAIN_GT_FLIST_PRE: /data_val/ISTD_shadow_free_pre_train.flist
45 | TRAIN_MASK_FLIST_PRE: /data_val/ISTD_mask_pre_train.flist
46 |
47 | VAL_FLIST: /data_val/ISTD_shadow_train_val.flist
48 | VAL_GT_FLIST: /data_val/ISTD_shadow_free_train_val.flist
49 | VAL_MASK_FLIST: /data_val/ISTD_mask_train_val.flist
50 |
51 | VAL_FLIST_PRE: /data_val/ISTD_shadow_pre_val.flist
52 | VAL_GT_FLIST_PRE: /data_val/ISTD_shadow_free_pre_val.flist
53 | VAL_MASK_FLIST_PRE: /data_val/ISTD_mask_pre_val.flist
--------------------------------------------------------------------------------
/config/config_SRD.yml:
--------------------------------------------------------------------------------
1 | ADV: 1
2 | BACKUP: 0
3 | BATCH_SIZE: 1
4 | BETA1: 0.9
5 | BETA2: 0.999
6 | DATA_ROOT: /home/xxx/data-set/shadow_removal_dataset/SRD_Dataset_arg
7 | DEBUG: 0
8 | DIS_NORM: spectral
9 | FORCE_EXIT: 1
10 | GAN_LOSS: nsgan
11 | GAN_NORM: batch_
12 | GPU:
13 | - '0'
14 | INIT_TYPE: identity
15 | INPUT_SIZE_H: 640
16 | INPUT_SIZE_W: 840
17 | LOG_INTERVAL: 10
18 | LOSS: L1Loss
19 | LR: 0.0001
20 | LR_D: 5.0e-05
21 | MIDDLE_RES_NUM: 4
22 | MODEL_NAME: ShadowRemoval
23 | NETWORK: DMTN
24 | POOL_SIZE: 50
25 | PRE_TRAIN_EVAL_LEN: 10
26 | RESULTS: ./checkpoints/SRD/results
27 | RESULTS_SAMPLE: 0
28 | SAMPLE_SIZE: 1
29 | SEED: 10
30 | SPLIT: 1
31 | SUBJECT_WORD: SRD
32 | TRAIN_EVAL_LEN: 10
33 | VERBOSE: 1
34 |
35 | TEST_FLIST: /data_val/SRD_shadow_test.flist
36 | TEST_GT_FLIST: /data_val/SRD_shadow_free_test.flist
37 | TEST_MASK_FLIST: /data_val/SRD_mask_test.flist
38 |
39 | TRAIN_FLIST: /data_val/SRD_shadow_train_train.flist
40 | TRAIN_GT_FLIST: /data_val/SRD_shadow_free_train_train.flist
41 | TRAIN_MASK_FLIST: /data_val/SRD_mask_train_train.flist
42 |
43 | TRAIN_FLIST_PRE: /data_val/SRD_shadow_train_train.flist
44 | TRAIN_GT_FLIST_PRE: /data_val/SRD_shadow_free_train_train.flist
45 | TRAIN_MASK_FLIST_PRE: /data_val/SRD_mask_train_train.flist
46 |
47 | VAL_FLIST: /data_val/SRD_shadow_train_val.flist
48 | VAL_GT_FLIST: /data_val/SRD_shadow_free_train_val.flist
49 | VAL_MASK_FLIST: /data_val/SRD_mask_train_val.flist
50 |
51 | VAL_FLIST_PRE: /data_val/SRD_shadow_train_val.flist
52 | VAL_GT_FLIST_PRE: /data_val/SRD_shadow_free_train_val.flist
53 | VAL_MASK_FLIST_PRE: /data_val/SRD_mask_train_val.flist
54 |
--------------------------------------------------------------------------------
/config/config_ISTD.yml:
--------------------------------------------------------------------------------
1 | ADV: 1
2 | BACKUP: 0
3 | BATCH_SIZE: 1
4 | BETA1: 0.9
5 | BETA2: 0.999
6 | DATA_ROOT: /home/xxx/data-set/shadow_removal_dataset/ISTD_Dataset_arg
7 | DEBUG: 0
8 | DIS_NORM: spectral
9 | FORCE_EXIT: 1
10 | GAN_LOSS: nsgan
11 | GAN_NORM: batch_
12 | GPU:
13 | - '0'
14 | INIT_TYPE: identity
15 | INPUT_SIZE_H: 480
16 | INPUT_SIZE_W: 640
17 | LOG_INTERVAL: 10
18 | LOSS: L1Loss
19 | LR: 0.0001
20 | LR_D: 5.0e-05
21 | MIDDLE_RES_NUM: 4
22 | MODEL_NAME: ShadowRemoval
23 | NETWORK: DMTN
24 | POOL_SIZE: 50
25 | PRE_TRAIN_EVAL_LEN: 10
26 | RESULTS: ./checkpoints/ISTD/results
27 | RESULTS_SAMPLE: 0
28 | SAMPLE_SIZE: 1
29 | SEED: 10
30 | SPLIT: 1
31 | SUBJECT_WORD: ISTD
32 | TRAIN_EVAL_LEN: 10
33 | VERBOSE: 1
34 |
35 | TEST_FLIST: /data_val/ISTD_shadow_test.flist
36 | TEST_GT_FLIST: /data_val/ISTD_shadow_free_test.flist
37 | TEST_MASK_FLIST: /data_val/ISTD_mask_test.flist
38 |
39 | TRAIN_FLIST: /data_val/ISTD_shadow_train_train.flist
40 | TRAIN_GT_FLIST: /data_val/ISTD_shadow_free_train_train.flist
41 | TRAIN_MASK_FLIST: /data_val/ISTD_mask_train_train.flist
42 |
43 | TRAIN_FLIST_PRE: /data_val/ISTD_shadow_train_train.flist
44 | TRAIN_GT_FLIST_PRE: /data_val/ISTD_shadow_free_train_train.flist
45 | TRAIN_MASK_FLIST_PRE: /data_val/ISTD_mask_train_train.flist
46 |
47 | VAL_FLIST: /data_val/ISTD_shadow_train_val.flist
48 | VAL_GT_FLIST: /data_val/ISTD_shadow_free_train_val.flist
49 | VAL_MASK_FLIST: /data_val/ISTD_mask_train_val.flist
50 |
51 | VAL_FLIST_PRE: /data_val/ISTD_shadow_train_val.flist
52 | VAL_GT_FLIST_PRE: /data_val/ISTD_shadow_free_train_val.flist
53 | VAL_MASK_FLIST_PRE: /data_val/ISTD_mask_train_val.flist
54 |
--------------------------------------------------------------------------------
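Note on the four configs: for each dataset, the _da variant differs from the plain one only in the *_PRE flists. The _da configs point pre-training at the *_pre_* splits, which (per script/generate_flist_srd.py and generate_flist_istd.py) also contain the synthetic synC images, while the plain configs reuse the real train/val splits for pre-training.
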
/script/flist_train_val_test.py:
--------------------------------------------------------------------------------
1 | """divide the data of all folders in a directory into: training set, verification set, test set"""
2 | import os
3 | import numpy as np
4 | import random
5 |
6 |
7 | def gen_flist_train_val_test(flist, train_val_test_path, train_val_test_ratio, SEED, id_list):
8 | random.seed(SEED)
9 | # get flist
10 | # extension filtering happens in flist.py; this function consumes an existing flist
11 | images = np.genfromtxt(flist, dtype=str, encoding='utf-8')
12 | # shuffle
13 | files_num = len(images)
14 | if len(id_list) == 0:
15 | id_list = list(range(files_num))
16 | shuffle = True
17 | else:
18 | shuffle = False
19 | if shuffle:
20 | random.shuffle(id_list)
21 | images = np.array(images)[id_list]
22 | # save
23 | images_train_val_test = [[], [], []]
24 | i_list = [0]
25 | sum_ = 0
26 | if sum(train_val_test_ratio) == 10:
27 | for i in range(3):
28 | if train_val_test_ratio[i] > 0:
29 | sum_ = sum_ + \
30 | int(np.floor(train_val_test_ratio[i]*files_num/10))
31 | if i == 2:
32 | i_list.append(files_num)
33 | else:
34 | i_list.append(sum_)
35 | images_train_val_test[i] = images[i_list[i]:i_list[i+1]]
36 | # save
37 | np.savetxt(train_val_test_path[i],
38 | images_train_val_test[i], fmt='%s')
39 | else:
40 | sum_ = sum_ + \
41 | int(np.floor(train_val_test_ratio[i]*files_num/10))
42 | i_list.append(sum_)
43 | else:
44 | print('input train_val_test_ratio error!')
45 | return id_list
46 |
--------------------------------------------------------------------------------
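A usage sketch mirroring generate_flist_srd.py below; the id_list returned by the first call is passed to later calls so the shadow, mask, and shadow-free splits stay index-aligned (file names are illustrative):

    seed = 10
    # first call: empty id_list -> shuffle once and remember the permutation
    id_list = gen_flist_train_val_test(
        'SRD_shadow_train.flist',
        ['SRD_shadow_train_train.flist', 'SRD_shadow_train_val.flist', ''],
        [8, 2, 0], seed, [])
    # subsequent calls: reuse id_list so every modality gets the same 8/2/0 split
    gen_flist_train_val_test(
        'SRD_mask_train.flist',
        ['SRD_mask_train_train.flist', 'SRD_mask_train_val.flist', ''],
        [8, 2, 0], seed, id_list)
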
/script/process_syn_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 |
4 | def prepare_image_syn(val_path, stage='train_shadow_free'):
5 |
6 | iminput = val_path
7 | val_mask_name = val_path.split('/')[-1].split('_')[-1]
8 | gtmask = val_path.replace(stage,'train_B_ISTD').replace(val_path.split('/')[-1],val_mask_name)
9 |
10 | val_im_name = '_'.join(val_path.split('/')[-1].split('_')[0:-1])+'.jpg'
11 | imtarget = val_path.replace(stage,'shadow_free').replace(val_path.split('/')[-1],val_im_name)
12 |
13 | return iminput,imtarget,gtmask
14 |
15 | def prepare_data(train_path, stage=['train_A']):
16 | input_names=[]
17 | for dirname in train_path:
18 | for subfolder in stage:
19 | train_b = dirname + "/"+ subfolder+"/"
20 | for root, _, fnames in sorted(os.walk(train_b)):
21 | for fname in fnames:
22 | if is_image_file(fname):
23 | input_names.append(os.path.join(train_b, fname))
24 | return input_names
25 |
26 | IMG_EXTENSIONS = [
27 | '.jpg', '.JPG', '.jpeg', '.JPEG',
28 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
29 | ]
30 |
31 | def is_image_file(filename):
32 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
33 |
34 | def gen_syn_images_flist(train_real_root):
35 | train_real_root = [train_real_root]
36 | syn_images=prepare_data(train_real_root,stage=['synC'])
37 |
38 | shadow_syn_list= []
39 | shadow_free_syn_list= []
40 | mask_syn_list= []
41 | for i in range(len(syn_images)):
42 | shadow, shadow_free, mask = prepare_image_syn(syn_images[i], stage='synC')
43 | shadow_syn_list.append(shadow)
44 | shadow_free_syn_list.append(shadow_free)
45 | mask_syn_list.append(mask)
46 | return shadow_syn_list,shadow_free_syn_list,mask_syn_list
47 |
48 |
--------------------------------------------------------------------------------
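For orientation, the filename convention that prepare_image_syn decodes, on a hypothetical synthetic image path:

    # input:  /data/train/synC/IMG_1_5.jpg        the synthetic shadow image itself
    # mask:   /data/train/train_B_ISTD/5.jpg      the suffix after the last '_'
    # target: /data/train/shadow_free/IMG_1.jpg   the prefix before the last '_' + '.jpg'
    iminput, imtarget, gtmask = prepare_image_syn(
        '/data/train/synC/IMG_1_5.jpg', stage='synC')
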
/src/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | from shutil import copyfile
5 |
6 | import cv2
7 | import numpy as np
8 | import torch
9 |
10 | from src.config import Config
11 | from src.model_top import ModelTop
12 | from src.utils import create_dir
13 |
14 |
15 | def main(mode, config_path):
16 | r"""
17 | Args:
18 | mode (int): 1: train, 2: test, 3: eval, reads from config file if not specified
19 | """
20 |
21 | config = load_config(mode, config_path)
22 | config.CONFIG_PATH = config_path
23 |
24 | # init device
25 | if torch.cuda.is_available():
26 | config.DEVICE = torch.device("cuda")
27 | torch.backends.cudnn.benchmark = True # cudnn auto-tuner
28 | else:
29 | config.DEVICE = torch.device("cpu")
30 |
31 | # set cv2 running threads to 1 (prevents deadlocks with pytorch dataloader)
32 | cv2.setNumThreads(0)
33 |
34 | # initialize random seed
35 | torch.manual_seed(config.SEED)
36 | torch.cuda.manual_seed_all(config.SEED)
37 | np.random.seed(config.SEED)
38 | random.seed(config.SEED)
39 |
40 | # build the model and initialize
41 | model = ModelTop(config)
42 | model.load()
43 |
44 | # model pre training (data augmentation)
45 | if config.MODE == 0:
46 | config.print()
47 | model.train()
48 |
49 | # model pre training (no data augmentation)
50 | if config.MODE == 1:
51 | config.print()
52 | model.train()
53 |
54 | # model training
55 | if config.MODE == 2:
56 | config.print()
57 | model.train()
58 |
59 | # model test
60 | elif config.MODE == 3:
61 | model.test()
62 |
63 | # model eval on val set
64 | elif config.MODE ==4:
65 | model.eval()
66 |
67 | # model eval on test set
68 | elif config.MODE == 5:
69 | model.eval()
70 |
71 |
72 | def load_config(mode, config_path):
73 | r"""loads model config
74 |
75 | Args:
76 | mode (int): 0/1: pre-train (with/without data augmentation), 2: train, 3: test, 4: eval on val set, 5: eval on test set
77 | """
78 |
79 | # load config file
80 | config = Config(config_path)
81 | config.MODE = mode
82 |
83 | return config
84 |
85 |
86 | if __name__ == "__main__":
87 | # called as: python -m src.main <mode> <config_path>
88 | import sys
89 | main(int(sys.argv[1]), sys.argv[2])
90 |
--------------------------------------------------------------------------------
/config/run_ISTD.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import yaml
7 |
8 |
9 | from src.main import main
10 | from src.utils import create_dir, init_config, copypth
11 | import torch
12 |
13 | if __name__ == '__main__':
14 | # initial
15 | multiprocessing.set_start_method('spawn')
16 | config = yaml.load(open("config.yml", 'r'), Loader=yaml.FullLoader)
17 | dest_path = Path('checkpoints/') / \
18 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'.pth')
19 | # cuda visible devices
20 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
21 | str(e) for e in config["GPU"])
22 | torch.autograd.set_detect_anomaly(True)
23 | checkpoints_path = Path('./checkpoints') / \
24 | config["SUBJECT_WORD"] # model checkpoints path
25 | create_dir(checkpoints_path)
26 | create_dir('./pre_train_model')
27 | config_path = os.path.join(checkpoints_path, 'config.yml')
28 |
29 | # pre_train (data augmentation)
30 | MODE = 0
31 | print('\nmode-'+str(MODE)+': start pre_training(data augmentation)...\n')
32 | for i in range(1):
33 | skip_train = init_config(checkpoints_path, MODE=MODE,
34 | EVAL_INTERVAL_EPOCH=1, EPOCH=[90,i])
35 | if not skip_train:
36 | main(MODE, config_path)
37 | src_path = Path('./pre_train_model') / \
38 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth')
39 | copypth(dest_path, src_path)
40 |
41 | # train
42 | MODE = 2
43 | print('\nmode-'+str(MODE)+': start training...\n')
44 | for i in range(1):
45 | skip_train = init_config(checkpoints_path, MODE=MODE,
46 | EVAL_INTERVAL_EPOCH=0.1, EPOCH=[60,i])
47 | if not skip_train:
48 | main(MODE, config_path)
49 | src_path = Path('./pre_train_model') / \
50 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth')
51 | copypth(dest_path, src_path)
52 |
53 | # test
54 | MODE = 3
55 | print('\nmode-'+str(MODE)+': start testing...\n')
56 | main(MODE, config_path)
57 |
58 | # eval on val set
59 | # MODE = 4
60 | # print('\nmode-'+str(MODE)+': start eval...\n')
61 | # main(MODE,config_path)
62 |
63 | # eval on test set
64 | MODE = 5
65 | print('\nmode-'+str(MODE)+': start eval...\n')
66 | main(MODE, config_path)
67 |
--------------------------------------------------------------------------------
/config/run_SRD.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import yaml
7 |
8 |
9 | from src.main import main
10 | from src.utils import create_dir, init_config, copypth
11 | import torch
12 |
13 | if __name__ == '__main__':
14 | # initial
15 | multiprocessing.set_start_method('spawn')
16 | config = yaml.load(open("config.yml", 'r'), Loader=yaml.FullLoader)
17 | dest_path = Path('checkpoints/') / \
18 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'.pth')
19 | # cuda visible devices
20 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
21 | str(e) for e in config["GPU"])
22 | torch.autograd.set_detect_anomaly(True)
23 | checkpoints_path = Path('./checkpoints') / \
24 | config["SUBJECT_WORD"] # model checkpoints path
25 | create_dir(checkpoints_path)
26 | create_dir('./pre_train_model')
27 | config_path = os.path.join(checkpoints_path, 'config.yml')
28 |
29 | # pre_train (data augmentation)
30 | MODE = 0
31 | print('\nmode-'+str(MODE)+': start pre_training(data augmentation)...\n')
32 | for i in range(1):
33 | skip_train = init_config(checkpoints_path, MODE=MODE,
34 | EVAL_INTERVAL_EPOCH=1, EPOCH=[90,i])
35 | if not skip_train:
36 | main(MODE, config_path)
37 | src_path = Path('./pre_train_model') / \
38 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth')
39 | copypth(dest_path, src_path)
40 |
41 | # train
42 | MODE = 2
43 | print('\nmode-'+str(MODE)+': start training...\n')
44 | for i in range(6):
45 | skip_train = init_config(checkpoints_path, MODE=MODE,
46 | EVAL_INTERVAL_EPOCH=1, EPOCH=[10,20,30,40,50,60,i])
47 | if not skip_train:
48 | main(MODE, config_path)
49 | src_path = Path('./pre_train_model') / \
50 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth')
51 | copypth(dest_path, src_path)
52 |
53 | # test
54 | MODE = 3
55 | print('\nmode-'+str(MODE)+': start testing...\n')
56 | main(MODE, config_path)
57 |
58 | # eval on val set
59 | # MODE = 4
60 | # print('\nmode-'+str(MODE)+': start eval...\n')
61 | # main(MODE,config_path)
62 |
63 | # eval on test set
64 | MODE = 5
65 | print('\nmode-'+str(MODE)+': start eval...\n')
66 | main(MODE, config_path)
67 |
--------------------------------------------------------------------------------
/src/image_pool.py:
--------------------------------------------------------------------------------
1 | """
2 | This part of the code is built based on the project:
3 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
4 | """
5 | import random
6 | import torch
7 |
8 |
9 | class ImagePool():
10 | """This class implements an image buffer that stores previously generated images.
11 |
12 | This buffer enables us to update discriminators using a history of generated images
13 | rather than the ones produced by the latest generators.
14 | """
15 |
16 | def __init__(self, pool_size):
17 | """Initialize the ImagePool class
18 |
19 | Parameters:
20 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
21 | """
22 | self.pool_size = pool_size
23 | if self.pool_size > 0: # create an empty pool
24 | self.num_imgs = 0
25 | self.images = []
26 |
27 | def query(self, images):
28 | """Return an image from the pool.
29 |
30 | Parameters:
31 | images: the latest generated images from the generator
32 |
33 | Returns images from the buffer.
34 |
35 | With probability 1/2, the buffer returns the current input images.
36 | With probability 1/2, the buffer returns images previously stored in the buffer
37 | and inserts the current images into the buffer.
38 | """
39 | if self.pool_size == 0: # if the buffer size is 0, do nothing
40 | return images
41 | return_images = []
42 | for image in images:
43 | image = torch.unsqueeze(image.data, 0)
44 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer
45 | self.num_imgs = self.num_imgs + 1
46 | self.images.append(image)
47 | return_images.append(image)
48 | else:
49 | p = random.uniform(0, 1)
50 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
51 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
52 | tmp = self.images[random_id].clone()
53 | self.images[random_id] = image
54 | return_images.append(tmp)
55 | else: # by another 50% chance, the buffer will return the current image
56 | return_images.append(image)
57 | return_images = torch.cat(return_images, 0) # collect all the images and return
58 | return return_images
59 |
--------------------------------------------------------------------------------
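A minimal usage sketch of ImagePool; tensor shapes are illustrative, and pool_size matches the POOL_SIZE: 50 setting in the configs above:

    import torch
    from src.image_pool import ImagePool

    pool = ImagePool(pool_size=50)
    fake = torch.rand(4, 3, 64, 64)   # stand-in for a batch of generator outputs
    history = pool.query(fake)        # mix of current and previously buffered fakes
    # a discriminator would then be trained on history.detach() rather than on fake
    assert history.shape == fake.shape
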
/install.yaml:
--------------------------------------------------------------------------------
1 | name: DMTN
2 | dependencies:
3 | - _libgcc_mutex=0.1=main
4 | - autopep8=1.5.4=py_0
5 | - blas=1.0=mkl
6 | - bzip2=1.0.8=h7b6447c_0
7 | - ca-certificates=2020.12.8=h06a4308_0
8 | - cairo=1.14.12=h8948797_3
9 | - certifi=2020.12.5=py37h06a4308_0
10 | - cudatoolkit=10.2.89=hfd86e86_1
11 | - ffmpeg=4.0=hcdf2ecd_0
12 | - fontconfig=2.13.0=h9420a91_0
13 | - freeglut=3.0.0=hf484d3e_5
14 | - freetype=2.10.4=h5ab3b9f_0
15 | - glib=2.66.1=h92f7085_0
16 | - graphite2=1.3.14=h23475e2_0
17 | - harfbuzz=1.8.8=hffaf4a1_0
18 | - hdf5=1.10.2=hba1933b_1
19 | - icu=58.2=he6710b0_3
20 | - intel-openmp=2020.2=254
21 | - jasper=2.0.14=h07fcdf6_1
22 | - jpeg=9b=h024ee3a_2
23 | - lcms2=2.11=h396b838_0
24 | - ld_impl_linux-64=2.33.1=h53a641e_7
25 | - libedit=3.1.20191231=h14c3975_1
26 | - libffi=3.3=he6710b0_2
27 | - libgcc-ng=9.1.0=hdf63c60_0
28 | - libgfortran-ng=7.3.0=hdf63c60_0
29 | - libglu=9.0.0=hf484d3e_1
30 | - libopencv=3.4.2=hb342d67_1
31 | - libopus=1.3.1=h7b6447c_0
32 | - libpng=1.6.37=hbc83047_0
33 | - libstdcxx-ng=9.1.0=hdf63c60_0
34 | - libtiff=4.1.0=h2733197_1
35 | - libuuid=1.0.3=h1bed415_2
36 | - libuv=1.40.0=h7b6447c_0
37 | - libvpx=1.7.0=h439df22_0
38 | - libxcb=1.14=h7b6447c_0
39 | - libxml2=2.9.10=hb55368b_3
40 | - lz4-c=1.9.2=heb0550a_3
41 | - mkl=2020.2=256
42 | - mkl-service=2.3.0=py37he8ac12f_0
43 | - mkl_fft=1.2.0=py37h23d657b_0
44 | - mkl_random=1.1.1=py37h0573a6f_0
45 | - ncurses=6.2=he6710b0_1
46 | - ninja=1.10.2=py37hff7bd54_0
47 | - numpy=1.19.2=py37h54aff64_0
48 | - numpy-base=1.19.2=py37hfa32c7d_0
49 | - olefile=0.46=py37_0
50 | - openssl=1.1.1i=h27cfd23_0
51 | - pcre=8.44=he6710b0_0
52 | - pillow=8.0.1=py37he98fc37_0
53 | - pip=20.3.1=py37h06a4308_0
54 | - pixman=0.40.0=h7b6447c_0
55 | - py-opencv=3.4.2=py37hb342d67_1
56 | - pycodestyle=2.6.0=py_0
57 | - python=3.7.9=h7579374_0
58 | - pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0
59 | - pyyaml=5.3.1=py37h7b6447c_1
60 | - readline=8.0=h7b6447c_0
61 | - scipy=1.5.2=py37h0b6359f_0
62 | - setuptools=51.0.0=py37h06a4308_2
63 | - six=1.15.0=py37h06a4308_0
64 | - sqlite=3.33.0=h62c20be_0
65 | - tk=8.6.10=hbc83047_0
66 | - toml=0.10.1=py_0
67 | - torchaudio=0.7.2=py37
68 | - torchvision=0.2.1=py37_0
69 | - typing_extensions=3.7.4.3=py_0
70 | - wheel=0.36.1=pyhd3eb1b0_0
71 | - xz=5.2.5=h7b6447c_0
72 | - yaml=0.2.5=h7b6447c_0
73 | - zlib=1.2.11=h7b6447c_3
74 | - zstd=1.4.5=h9ceee32_0
75 | - pip:
76 | - augmentor==0.2.8
77 | - cycler==0.10.0
78 | - decorator==4.4.2
79 | - future==0.18.2
80 | - imageio==2.9.0
81 | - kiwisolver==1.3.1
82 | - matplotlib==3.3.3
83 | - networkx==2.5
84 | - pyparsing==2.4.7
85 | - python-dateutil==2.8.1
86 | - pywavelets==1.1.1
87 | - scikit-image==0.17.2
88 | - tifffile==2020.12.8
89 | - tqdm==4.54.1
90 |
--------------------------------------------------------------------------------
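Setup note: install.yaml is a conda environment specification. Assuming a working conda install, "conda env create -f install.yaml" followed by "conda activate DMTN" reproduces the pinned Python 3.7.9 / PyTorch 1.7.1 / CUDA 10.2 environment.
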
/config/run_SRD_da.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import yaml
7 |
8 |
9 | from src.main import main
10 | from src.utils import create_dir, init_config, copypth
11 | import torch
12 |
13 | if __name__ == '__main__':
14 | # initial
15 | multiprocessing.set_start_method('spawn')
16 | config = yaml.load(open("config.yml", 'r'), Loader=yaml.FullLoader)
17 | dest_path = Path('checkpoints/') / \
18 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'.pth')
19 | # cuda visible devices
20 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
21 | str(e) for e in config["GPU"])
22 | torch.autograd.set_detect_anomaly(True)
23 | checkpoints_path = Path('./checkpoints') / \
24 | config["SUBJECT_WORD"] # model checkpoints path
25 | create_dir(checkpoints_path)
26 | create_dir('./pre_train_model')
27 | config_path = os.path.join(checkpoints_path, 'config.yml')
28 |
29 | # pre_train (data augmentation)
30 | MODE = 0
31 | print('\nmode-'+str(MODE)+': start pre_training(data augmentation)...\n')
32 | for i in range(1):
33 | skip_train = init_config(checkpoints_path, MODE=MODE,
34 | EVAL_INTERVAL_EPOCH=1, EPOCH=[90,i])
35 | if not skip_train:
36 | main(MODE, config_path)
37 | src_path = Path('./pre_train_model') / \
38 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth')
39 | copypth(dest_path, src_path)
40 |
41 | # pre_train (no data augmentation)
42 | MODE = 1
43 | print('\nmode-'+str(MODE)+': start pre_training(no data augmentation)...\n')
44 | for i in range(1):
45 | skip_train = init_config(checkpoints_path, MODE=MODE,
46 | EVAL_INTERVAL_EPOCH=1, EPOCH=[30,i])
47 | if not skip_train:
48 | main(MODE, config_path)
49 | src_path = Path('./pre_train_model') / \
50 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_no_da.pth')
51 | copypth(dest_path, src_path)
52 |
53 | # train
54 | MODE = 2
55 | print('\nmode-'+str(MODE)+': start training...\n')
56 | for i in range(1):
57 | skip_train = init_config(checkpoints_path, MODE=MODE,
58 | EVAL_INTERVAL_EPOCH=1, EPOCH=[30,i])
59 | if not skip_train:
60 | main(MODE, config_path)
61 | src_path = Path('./pre_train_model') / \
62 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth')
63 | copypth(dest_path, src_path)
64 |
65 | # test
66 | MODE = 3
67 | print('\nmode-'+str(MODE)+': start testing...\n')
68 | main(MODE, config_path)
69 |
70 | # eval on val set
71 | # MODE = 4
72 | # print('\nmode-'+str(MODE)+': start eval...\n')
73 | # main(MODE,config_path)
74 |
75 | # eval on test set
76 | MODE = 5
77 | print('\nmode-'+str(MODE)+': start eval...\n')
78 | main(MODE, config_path)
79 |
--------------------------------------------------------------------------------
/config/run_ISTD_da.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | import os
3 | from pathlib import Path
4 |
5 | import numpy as np
6 | import yaml
7 |
8 |
9 | from src.main import main
10 | from src.utils import create_dir, init_config, copypth
11 | import torch
12 |
13 | if __name__ == '__main__':
14 | # initial
15 | multiprocessing.set_start_method('spawn')
16 | config = yaml.load(open("config.yml", 'r'), Loader=yaml.FullLoader)
17 | dest_path = Path('checkpoints/') / \
18 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'.pth')
19 | # cuda visible devices
20 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
21 | str(e) for e in config["GPU"])
22 | torch.autograd.set_detect_anomaly(True)
23 | checkpoints_path = Path('./checkpoints') / \
24 | config["SUBJECT_WORD"] # model checkpoints path
25 | create_dir(checkpoints_path)
26 | create_dir('./pre_train_model')
27 | config_path = os.path.join(checkpoints_path, 'config.yml')
28 |
29 | # pre_train (data augmentation)
30 | MODE = 0
31 | print('\nmode-'+str(MODE)+': start pre_training(data augmentation)...\n')
32 | for i in range(1):
33 | skip_train = init_config(checkpoints_path, MODE=MODE,
34 | EVAL_INTERVAL_EPOCH=1, EPOCH=[90,i])
35 | if not skip_train:
36 | main(MODE, config_path)
37 | src_path = Path('./pre_train_model') / \
38 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth')
39 | copypth(dest_path, src_path)
40 |
41 | # pre_train (no data augmentation)
42 | MODE = 1
43 | print('\nmode-'+str(MODE)+': start pre_training(no data augmentation)...\n')
44 | for i in range(1):
45 | skip_train = init_config(checkpoints_path, MODE=MODE,
46 | EVAL_INTERVAL_EPOCH=1, EPOCH=[30,i])
47 | if not skip_train:
48 | main(MODE, config_path)
49 | src_path = Path('./pre_train_model') / \
50 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_no_da.pth')
51 | copypth(dest_path, src_path)
52 |
53 | # train
54 | MODE = 2
55 | print('\nmode-'+str(MODE)+': start training...\n')
56 | for i in range(1):
57 | skip_train = init_config(checkpoints_path, MODE=MODE,
58 | EVAL_INTERVAL_EPOCH=0.1, EPOCH=[30,i])
59 | if not skip_train:
60 | main(MODE, config_path)
61 | src_path = Path('./pre_train_model') / \
62 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth')
63 | copypth(dest_path, src_path)
64 |
65 | # test
66 | MODE = 3
67 | print('\nmode-'+str(MODE)+': start testing...\n')
68 | main(MODE, config_path)
69 |
70 | # eval on val set
71 | # MODE = 4
72 | # print('\nmode-'+str(MODE)+': start eval...\n')
73 | # main(MODE,config_path)
74 |
75 | # eval on test set
76 | MODE = 5
77 | print('\nmode-'+str(MODE)+': start eval...\n')
78 | main(MODE, config_path)
79 |
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.optim as optim
6 | from src.network.network_DMTN import DMTN
7 |
8 | class BaseModel(nn.Module):
9 | def __init__(self, name, config):
10 | super(BaseModel, self).__init__()
11 |
12 | self.name = name
13 | self.config = config
14 | self.iteration = 0
15 |
16 |
17 | self.weights_path = os.path.join(config.PATH, name + '.pth')
18 |
19 | def load(self):
20 | if os.path.exists(self.weights_path):
21 | print('Loading %s model...' % self.name)
22 |
23 | if torch.cuda.is_available():
24 | data = torch.load(self.weights_path)
25 | print(self.weights_path)
26 | else:
27 | data = torch.load(self.weights_path,
28 | map_location=lambda storage, loc: storage)
29 | self.network_instance.load_state_dict(data['model'])
30 | self.iteration = data['iteration']
31 | else:
32 | vgg19_weights_path = self.config.DATA_ROOT+"/vgg19-dcbb9e9d.pth"
33 | self.network_instance.network.vgg19.load_pretrained(
34 | vgg19_weights_path, self.config.DEVICE)
35 | for discriminator in self.network_instance.discriminator:
36 | discriminator.perceptual_loss.vgg19.load_pretrained(
37 | vgg19_weights_path, self.config.DEVICE)
38 |
39 | def save(self):
40 | print('\nsaving %s...\n' % self.weights_path)
41 | torch.save({
42 | 'iteration': self.iteration,
43 | 'model': self.network_instance.state_dict()
44 | }, self.weights_path)
45 |
46 | if self.config.BACKUP:
47 | INTERVAL_ = 4
48 | if self.config.SAVE_INTERVAL and self.iteration % (self.config.SAVE_INTERVAL*INTERVAL_) == 0:
49 | print('\nsaving %s...\n' % (self.name + '_backup'))
50 | torch.save({
51 | 'iteration': self.iteration,
52 | 'model': self.network_instance.state_dict()
53 | }, os.path.join(self.config.PATH, 'backups/' + self.name + '_' + str(self.iteration // (self.config.SAVE_INTERVAL*INTERVAL_)) + '.pth'))
54 |
55 |
56 | class Model(BaseModel):
57 | def __init__(self, config):
58 | super(Model, self).__init__(config.MODEL_NAME, config)
59 | self.INNER_OPTIMIZER = config.INNER_OPTIMIZER
60 | # networks choose
61 | if config.NETWORK == "DMTN":
62 | network_instance = DMTN(config, in_channels=3)
63 | else:
64 | network_instance = None
65 | self.add_module('network_instance', network_instance)
66 |
67 | def process(self, images, mask, GT, eval_mode=False):
68 | if not eval_mode:
69 | self.iteration += 1
70 | outputs, loss, logs = self.network_instance.process(
71 | images, mask, GT)
72 |
73 | return outputs, loss, logs
74 |
75 | def forward(self, images):
76 | inputs = images
77 | outputs = self.network_instance(inputs)
78 | return outputs
79 |
80 | def backward(self, loss=None):
81 | self.network_instance.backward(loss)
82 |
--------------------------------------------------------------------------------
/script/generate_flist_srd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from pathlib import Path
4 | from flist import gen_flist
5 | from flist_train_val_test import gen_flist_train_val_test
6 | from process_syn_data import gen_syn_images_flist
7 |
8 | if __name__ == '__main__':
9 | # SRD
10 | SRD_path = "/gpfs/home/liujiawei/data-set/shadow_removal_dataset/SRD_Dataset_arg"
11 |
12 | save_path = SRD_path+"/data_val"
13 |
14 | if not Path(save_path).exists():
15 | os.mkdir(save_path)
16 |
17 | seed = 10
18 |
19 | # synthetic images (synC)
20 | shadow_syn_list,shadow_free_syn_list,mask_syn_list = gen_syn_images_flist(SRD_path+'/train')
21 |
22 | # train and val (data augmentation)
23 | data_path = Path(SRD_path)/"train/train_A"
24 | flist_main_name = "SRD_shadow.flist"
25 | flist_save_path = Path(save_path)/flist_main_name
26 | images = gen_flist(data_path)
27 | images+=shadow_syn_list
28 | np.savetxt(flist_save_path, images, fmt='%s')
29 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'),
30 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""]
31 | id_list = gen_flist_train_val_test(
32 | flist_save_path, png_val_test_PATH, [8, 2, 0], seed, [])
33 |
34 | data_path = Path(SRD_path)/"train/train_B"
35 | flist_main_name = "SRD_mask.flist"
36 | flist_save_path = Path(save_path)/flist_main_name
37 | images = gen_flist(data_path)
38 | images+=mask_syn_list
39 | np.savetxt(flist_save_path, images, fmt='%s')
40 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'),
41 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""]
42 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [
43 | 8, 2, 0], seed, id_list)
44 |
45 | data_path = Path(SRD_path)/"train/train_C"
46 | flist_main_name = "SRD_shadow_free.flist"
47 | flist_save_path = Path(save_path)/flist_main_name
48 | images = gen_flist(data_path)
49 | images+=shadow_free_syn_list
50 | np.savetxt(flist_save_path, images, fmt='%s')
51 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'),
52 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""]
53 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [
54 | 8, 2, 0], seed, id_list)
55 |
56 | # train and val
57 | data_path = Path(SRD_path)/"train/train_A"
58 | flist_save_path = Path(save_path)/"SRD_shadow_train.flist"
59 | images = gen_flist(data_path)
60 | np.savetxt(flist_save_path, images, fmt='%s')
61 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'),
62 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""]
63 | id_list = gen_flist_train_val_test(
64 | flist_save_path, png_val_test_PATH, [8, 2, 0], seed, [])
65 |
66 | data_path = Path(SRD_path)/"train/train_B"
67 | flist_save_path = Path(save_path)/"SRD_mask_train.flist"
68 | images = gen_flist(data_path)
69 | np.savetxt(flist_save_path, images, fmt='%s')
70 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'),
71 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""]
72 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [
73 | 8, 2, 0], seed, id_list)
74 |
75 | data_path = Path(SRD_path)/"train/train_C"
76 | flist_save_path = Path(save_path)/"SRD_shadow_free_train.flist"
77 | images = gen_flist(data_path)
78 | np.savetxt(flist_save_path, images, fmt='%s')
79 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'),
80 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""]
81 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [
82 | 8, 2, 0], seed, id_list)
83 |
84 | # test
85 | data_path = Path(SRD_path)/"test/test_A"
86 | flist_save_path = Path(save_path)/"SRD_shadow_test.flist"
87 | images = gen_flist(data_path)
88 | np.savetxt(flist_save_path, images, fmt='%s')
89 |
90 | data_path = Path(SRD_path)/"test/test_B"
91 | flist_save_path = Path(save_path)/"SRD_mask_test.flist"
92 | images = gen_flist(data_path)
93 | np.savetxt(flist_save_path, images, fmt='%s')
94 |
95 | data_path = Path(SRD_path)/"test/test_C"
96 | flist_save_path = Path(save_path)/"SRD_shadow_free_test.flist"
97 | images = gen_flist(data_path)
98 | np.savetxt(flist_save_path, images, fmt='%s')
99 |
100 |
--------------------------------------------------------------------------------
/script/generate_flist_istd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from pathlib import Path
4 | from flist import gen_flist
5 | from flist_train_val_test import gen_flist_train_val_test
6 | from process_syn_data import gen_syn_images_flist
7 |
8 | if __name__ == '__main__':
9 | # ISTD
10 | ISTD_path = "/gpfs/home/liujiawei/data-set/shadow_removal_dataset/ISTD_Dataset_arg"
11 |
12 | save_path = ISTD_path+"/data_val"
13 |
14 | if not Path(save_path).exists():
15 | os.mkdir(save_path)
16 |
17 | seed = 10
18 |
19 | # synthetic images (synC)
20 | shadow_syn_list,shadow_free_syn_list,mask_syn_list = gen_syn_images_flist(ISTD_path+'/train')
21 |
22 | # train and val (data augmentation)
23 | data_path = Path(ISTD_path)/"train/train_A"
24 | flist_main_name = "ISTD_shadow.flist"
25 | flist_save_path = Path(save_path)/flist_main_name
26 | images = gen_flist(data_path)
27 | images+=shadow_syn_list
28 | np.savetxt(flist_save_path, images, fmt='%s')
29 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'),
30 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""]
31 | id_list = gen_flist_train_val_test(
32 | flist_save_path, png_val_test_PATH, [8, 2, 0], seed, [])
33 |
34 | data_path = Path(ISTD_path)/"train/train_B"
35 | flist_main_name = "ISTD_mask.flist"
36 | flist_save_path = Path(save_path)/flist_main_name
37 | images = gen_flist(data_path)
38 | images+=mask_syn_list
39 | np.savetxt(flist_save_path, images, fmt='%s')
40 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'),
41 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""]
42 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [
43 | 8, 2, 0], seed, id_list)
44 |
45 | data_path = Path(ISTD_path)/"train/train_C"
46 | flist_main_name = "ISTD_shadow_free.flist"
47 | flist_save_path = Path(save_path)/flist_main_name
48 | images = gen_flist(data_path)
49 | images+=shadow_free_syn_list
50 | np.savetxt(flist_save_path, images, fmt='%s')
51 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'),
52 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""]
53 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [
54 | 8, 2, 0], seed, id_list)
55 |
56 | # train and val
57 | data_path = Path(ISTD_path)/"train/train_A"
58 | flist_save_path = Path(save_path)/"ISTD_shadow_train.flist"
59 | images = gen_flist(data_path)
60 | np.savetxt(flist_save_path, images, fmt='%s')
61 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'),
62 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""]
63 | id_list = gen_flist_train_val_test(
64 | flist_save_path, png_val_test_PATH, [8, 2, 0], seed, [])
65 |
66 | data_path = Path(ISTD_path)/"train/train_B"
67 | flist_save_path = Path(save_path)/"ISTD_mask_train.flist"
68 | images = gen_flist(data_path)
69 | np.savetxt(flist_save_path, images, fmt='%s')
70 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'),
71 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""]
72 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [
73 | 8, 2, 0], seed, id_list)
74 |
75 | data_path = Path(ISTD_path)/"train/train_C"
76 | flist_save_path = Path(save_path)/"ISTD_shadow_free_train.flist"
77 | images = gen_flist(data_path)
78 | np.savetxt(flist_save_path, images, fmt='%s')
79 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'),
80 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""]
81 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [
82 | 8, 2, 0], seed, id_list)
83 |
84 | # test
85 | data_path = Path(ISTD_path)/"test/test_A"
86 | flist_save_path = Path(save_path)/"ISTD_shadow_test.flist"
87 | images = gen_flist(data_path)
88 | np.savetxt(flist_save_path, images, fmt='%s')
89 |
90 | data_path = Path(ISTD_path)/"test/test_B"
91 | flist_save_path = Path(save_path)/"ISTD_mask_test.flist"
92 | images = gen_flist(data_path)
93 | np.savetxt(flist_save_path, images, fmt='%s')
94 |
95 | data_path = Path(ISTD_path)/"test/test_C"
96 | flist_save_path = Path(save_path)/"ISTD_shadow_free_test.flist"
97 | images = gen_flist(data_path)
98 | np.savetxt(flist_save_path, images, fmt='%s')
99 |
100 |
--------------------------------------------------------------------------------
/src/metrics.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | from decimal import *
3 |
4 | import numpy as np
5 | import skimage.color as skcolor
6 | import torch
7 | import torch.nn as nn
8 |
9 |
10 | class Metrics(nn.Module):
11 | def __init__(self):
12 | super(Metrics, self).__init__()
13 | nThreads = multiprocessing.cpu_count()
14 | self.multiprocessingi_utils = MultiprocessingiUtils(max(1, nThreads // 4))  # at least one worker
15 |
16 | def rgb2lab_all(self, outputs, GT):
17 | outputs = self.multiprocessingi_utils.rgb2lab(outputs)
18 | GT = self.multiprocessingi_utils.rgb2lab(GT)
19 |
20 | return outputs, GT
21 |
22 | def rmse(self, outputs, mask, GT, dataset_mode=0):
23 | outputs, GT = self.rgb2lab_all(outputs, GT)
24 |
25 | outputs[outputs > 1.0] = 1.0
26 | outputs[outputs < 0] = 0
27 |
28 | mask[mask > 0] = 1
29 | mask[mask == 0] = 0
30 | mask = mask.expand(-1, 3, -1, -1)
31 | mask_inverse = 1 - mask
32 |
33 | error_map = (outputs - GT).abs()*255
34 |
35 | rmse_all = error_map.sum(dim=(1, 2, 3))
36 | n_pxl_all = torch.from_numpy(np.array(
37 | [error_map.shape[2]*error_map.shape[3]])).cuda().expand(rmse_all.shape[0]).type(torch.float32)
38 |
39 | rmse_shadow = (error_map*mask).sum(dim=(1, 2, 3))
40 | n_pxl_shadow = mask.sum(dim=(1, 2, 3)) / mask.shape[1]
41 |
42 | rmse_non_shadow = (error_map*mask_inverse).sum(dim=(1, 2, 3))
43 | n_pxl_non_shadow = mask_inverse.sum(
44 | dim=(1, 2, 3)) / mask_inverse.shape[1]
45 |
46 | if dataset_mode != 0:
47 | return rmse_shadow, n_pxl_shadow, rmse_non_shadow, n_pxl_non_shadow, rmse_all, n_pxl_all
48 | else:
49 | rmse_shadow_eval = (
50 | rmse_shadow / (n_pxl_shadow+(n_pxl_shadow == 0.0).type(torch.float32))).mean()
51 | rmse_non_shadow_eval = (
52 | rmse_non_shadow / (n_pxl_non_shadow+(n_pxl_non_shadow == 0.0).type(torch.float32))).mean()
53 | rmse_all_eval = (
54 | rmse_all/(n_pxl_all+(n_pxl_all == 0.0).type(torch.float32))).mean()
55 | # print('running rmse-shadow: %.4f, rmse-non-shadow: %.4f, rmse-all: %.4f'
56 | # % (rmse_shadow_eval, rmse_non_shadow_eval, rmse_all_eval))
57 | return rmse_shadow_eval, rmse_non_shadow_eval, rmse_all_eval
58 |
59 | def collect_rmse(self, rmse_shadow, n_pxl_shadow, rmse_non_shadow, n_pxl_non_shadow, rmse_all, n_pxl_all):
60 | # GPU->CPU
61 | rmse_shadow = rmse_shadow.cpu().numpy()
62 | n_pxl_shadow = n_pxl_shadow.cpu().numpy()
63 | rmse_non_shadow = rmse_non_shadow.cpu().numpy()
64 | n_pxl_non_shadow = n_pxl_non_shadow.cpu().numpy()
65 | rmse_all = rmse_all.cpu().numpy()
66 | n_pxl_all = n_pxl_all.cpu().numpy()
67 |
68 | # decimal
69 | getcontext().prec = 50
70 |
71 | # sum
72 | rmse_shadow_ = Decimal(0)
73 | for add in rmse_shadow:
74 | rmse_shadow_ += Decimal(float(add))
75 | n_pxl_shadow_ = Decimal(0)
76 | for add in n_pxl_shadow:
77 | n_pxl_shadow_ += Decimal(float(add))
78 | rmse_non_shadow_ = Decimal(0)
79 | for add in rmse_non_shadow:
80 | rmse_non_shadow_ += Decimal(float(add))
81 | n_pxl_non_shadow_ = Decimal(0)
82 | for add in n_pxl_non_shadow:
83 | n_pxl_non_shadow_ += Decimal(float(add))
84 | rmse_all_ = Decimal(0)
85 | for add in rmse_all:
86 | rmse_all_ += Decimal(float(add))
87 | n_pxl_all_ = Decimal(0)
88 | for add in n_pxl_all:
89 | n_pxl_all_ += Decimal(float(add))
90 |
91 | # compute
92 | rmse_shadow_eval = rmse_shadow_ / n_pxl_shadow_
93 | rmse_non_shadow_eval = rmse_non_shadow_ / n_pxl_non_shadow_
94 | rmse_all_eval = rmse_all_ / n_pxl_all_
95 |
96 | self.multiprocessingi_utils.close()
97 |
98 | return float(rmse_shadow_eval), float(rmse_non_shadow_eval), float(rmse_all_eval)
99 |
100 |
101 | def rgb2lab(img):
102 | img = img.transpose([1, 2, 0])
103 | return skcolor.rgb2lab(img)
104 |
105 |
106 | class MultiprocessingiUtils():
107 | def __init__(self, nThreads=4):
108 | self.nThreads = nThreads
109 | self.pool = multiprocessing.Pool(processes=self.nThreads)
110 |
111 | def rgb2lab(self, inputs):
112 | inputs = inputs.cpu().numpy()
113 | inputs_ = self.pool.map(rgb2lab, inputs)
114 | i = 0
115 | for input_ in inputs_:
116 | inputs[i, :, :, :] = input_.transpose([2, 0, 1])/255
117 | i += 1
118 | output = torch.from_numpy(inputs).cuda()
119 | return output
120 |
121 | def close(self):
122 | self.pool.close()
123 | self.pool.join()
124 |
125 | def creat_pool(self):
126 | self.pool = multiprocessing.Pool(processes=self.nThreads)
127 |
--------------------------------------------------------------------------------
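A usage sketch with random stand-in tensors (note that error_map above is an absolute difference, so the "rmse" values are mean absolute errors in Lab space, consistent with evaluation/evaluate_MAE_PSNR_SSIM.m):

    import torch
    from src.metrics import Metrics

    metrics = Metrics()
    outputs = torch.rand(1, 3, 480, 640).cuda()               # prediction in [0, 1]
    GT = torch.rand(1, 3, 480, 640).cuda()                    # ground truth in [0, 1]
    mask = (torch.rand(1, 1, 480, 640) > 0.5).float().cuda()  # one-channel shadow mask
    rmse_s, rmse_ns, rmse_all = metrics.rmse(outputs, mask, GT, dataset_mode=0)
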
/src/dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import random
4 |
5 | import Augmentor
6 | import cv2
7 | import numpy as np
8 | import torch
9 | import torchvision.transforms.functional as F
10 | from PIL import Image
11 | from torch.utils.data import DataLoader
12 | from torchvision import transforms as T
13 | from skimage.color import rgb2gray
14 |
15 | from .utils import imshow, create_mask
16 |
17 |
18 | class Dataset(torch.utils.data.Dataset):
19 | def __init__(self, config, data_flist, mask_flist, GT_flist, additional_mask=[], augment=True):
20 | super(Dataset, self).__init__()
21 | self.augment = augment
22 | self.data = self.load_flist(data_flist)
23 | self.mask = self.load_flist(mask_flist)
24 | self.GT = self.load_flist(GT_flist)
25 | if len(additional_mask) != 0:
26 | self.additional_mask = self.load_flist(additional_mask)
27 | else:
28 | self.additional_mask = []
29 |
30 | self.input_size_h = config.INPUT_SIZE_H
31 | self.input_size_w = config.INPUT_SIZE_W
32 | self.mask_type = config.MASK
33 | self.dataset_name = config.SUBJECT_WORD
34 |
35 | def __len__(self):
36 | return len(self.data)
37 |
38 | def __getitem__(self, index):
39 | try:
40 | item = self.load_item(index)
41 | except:
42 | print('loading error: ' + self.data[index])
43 | item = self.load_item(0)
44 |
45 | return item
46 |
47 | def load_name(self, index):
48 | name = self.data[index]
49 | return os.path.basename(name)
50 |
51 | def load_item(self, index):
52 |
53 | size_h = self.input_size_h
54 | size_w = self.input_size_w
55 |
56 | # load image
57 | img = cv2.imread(self.data[index], -1)
58 | mask_GT = cv2.imread(self.mask[index], -1)
59 | GT = cv2.imread(self.GT[index], -1)
60 |
61 | # augment data
62 | if self.augment:
63 | img, mask_GT, GT = self.data_augment(
64 | img, mask_GT, GT)
65 | else:
66 | imgh, imgw = img.shape[0:2]
67 | if not (size_h == imgh and size_w == imgw):
68 | img = self.resize(img, size_h, size_w)
69 | mask_GT = self.resize(mask_GT, size_h, size_w)
70 | GT = self.resize(GT, size_h, size_w)
71 |
72 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
73 | GT = cv2.cvtColor(GT, cv2.COLOR_BGR2RGB)
74 |
75 | # imshow(Image.fromarray(edge_truth))
76 | return self.to_tensor(img), self.to_tensor(mask_GT), self.to_tensor(GT)
77 |
78 | def data_augment(self, img, mask_GT, GT):
79 |
80 | # https://github.com/mdbloice/Augmentor
81 | images = [[img, GT, mask_GT]]
82 | p = Augmentor.DataPipeline(images)
83 | p.flip_random(1)
84 | width=random.randint(256,640)
85 | height = round((width/self.input_size_w)*self.input_size_h)
86 | p.resize(1, width, height)
87 | g = p.generator(batch_size=1)
88 | augmented_images = next(g)
89 | images = augmented_images[0][0]
90 | GT = augmented_images[0][1]
91 | mask_GT = augmented_images[0][2]
92 |
93 | if len(self.additional_mask) != 0:
94 | additional_mask = self.load_mask(
95 | self.input_size_h, self.input_size_w, random.randint(0, len(self.additional_mask) - 1))  # randint is inclusive
96 | additional_mask = additional_mask.astype(
97 | np.single)/np.max(additional_mask)*np.random.uniform(0, 0.8)
98 | additional_mask = additional_mask[:, :, np.newaxis]
99 | images = images*(1-additional_mask)
100 | images = images.astype(np.uint8)
101 | return images, mask_GT, GT
102 |
103 | def load_mask(self, imgh, imgw, index):
104 | mask_type = self.mask_type
105 |
106 | # external + random block
107 | if mask_type == 4:
108 | mask_type = 1 if np.random.binomial(1, 0.5) == 1 else 3
109 |
110 | # external + random block + half
111 | elif mask_type == 5:
112 | mask_type = np.random.randint(1, 4)
113 |
114 | # random block
115 | if mask_type == 1:
116 | return create_mask(imgw, imgh, imgw // 2, imgh // 2)
117 |
118 | # half
119 | if mask_type == 2:
120 | # randomly choose right or left
121 | return create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0)
122 |
123 | # external
124 | if mask_type == 3:
125 | mask_index = random.randint(0, len(self.additional_mask) - 1)
126 | mask = cv2.imread(self.additional_mask[mask_index])
127 | mask = 255-mask
128 | mask = self.resize(mask, imgh, imgw)
129 | # threshold due to interpolation
130 | mask = (mask > 0).astype(np.uint8) * 255
131 | return mask
132 |
133 | # test mode: load mask non random
134 | if mask_type == 6:
135 | mask = cv2.imread(self.additional_mask[index])
136 | mask = self.resize(mask, imgh, imgw)  # resize() takes no centerCrop argument
137 | mask = rgb2gray(mask)
138 | mask = (mask > 0).astype(np.uint8) * 255
139 | return mask
140 |
141 | def to_tensor(self, img):
142 | img = Image.fromarray(img) # returns an image object.
143 | img_t = F.to_tensor(img).float()
144 | return img_t
145 |
146 | def resize(self, img, height, width):
147 | img = cv2.resize(img, (width, height), interpolation=cv2.INTER_CUBIC)
148 | return img
149 |
150 | def load_flist(self, flist):
151 | if isinstance(flist, list):
152 | return flist
153 |
154 | # flist: image file path, image directory path, text file flist path
155 | if isinstance(flist, str):
156 | if os.path.isdir(flist):
157 | flist = list(glob.glob(flist + '/*.jpg')) + \
158 | list(glob.glob(flist + '/*.png'))
159 | flist.sort()
160 | return flist
161 |
162 | if os.path.isfile(flist):
163 | try:
164 | return np.genfromtxt(flist, dtype=str, encoding='utf-8')
165 | except:
166 | return [flist]
167 |
168 | return []
169 |
170 | def create_iterator(self, batch_size):
171 | while True:
172 | sample_loader = DataLoader(
173 | dataset=self,
174 | batch_size=batch_size,
175 | drop_last=True
176 | )
177 |
178 | for item in sample_loader:
179 | yield item
180 |
--------------------------------------------------------------------------------
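A construction sketch. How flist paths are joined with DATA_ROOT lives in model_top.py (not shown in this excerpt), so the string concatenation below is an assumption:

    from src.config import Config
    from src.dataset import Dataset

    config = Config('checkpoints/ISTD/config.yml')
    train_set = Dataset(config,
                        config.DATA_ROOT + config.TRAIN_FLIST,      # assumed join
                        config.DATA_ROOT + config.TRAIN_MASK_FLIST,
                        config.DATA_ROOT + config.TRAIN_GT_FLIST,
                        augment=True)
    img, mask_GT, GT = train_set[0]   # three float tensors scaled to [0, 1]
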
/src/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | This part of the code is built based on the project:
3 | https://github.com/knazeri/edge-connect
4 | """
5 | import torch
6 | import torch.nn as nn
7 | import torchvision.models as models
8 | from .utils import blur, gauss_kernel, sobel, sobel_kernel
9 | from torch.autograd import Variable
10 | import numpy as np
11 | import os
12 |
13 |
14 | class AdversarialLoss(nn.Module):
15 | r"""
16 | Adversarial loss
17 | https://arxiv.org/abs/1711.10337
18 | """
19 |
20 | def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0):
21 | r"""
22 | type = nsgan | lsgan | hinge
23 | """
24 | super(AdversarialLoss, self).__init__()
25 |
26 | self.type = type
27 | self.register_buffer('real_label', torch.tensor(target_real_label))
28 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
29 |
30 | if type == 'nsgan':
31 | self.criterion = nn.BCELoss(reduction="mean")
32 |
33 | elif type == 'lsgan':
34 | self.criterion = nn.MSELoss(reduction="mean")
35 |
36 | elif type == 'hinge':
37 | self.criterion = nn.ReLU()
38 |
39 | def __call__(self, outputs, is_real, is_disc=None):
40 | if self.type == 'hinge':
41 | if is_disc:
42 | if is_real:
43 | outputs = -outputs
44 | return self.criterion(1 + outputs).mean()
45 | else:
46 | return (-outputs).mean()
47 |
48 | else:
49 | labels = (self.real_label if is_real else self.fake_label).expand_as(
50 | outputs)
51 | loss = self.criterion(outputs, labels)
52 | return loss
53 |
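54 | # Usage sketch (tensor names are hypothetical): given discriminator outputs
55 | # d_real / d_fake,
56 | #   d_loss = adv(d_real, True, is_disc=True) + adv(d_fake, False, is_disc=True)
57 | #   g_loss = adv(d_fake, True, is_disc=False)
58 | # which mirrors Discriminator.cal_loss in src/network/networks.py.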
54 |
55 | class StyleLoss(nn.Module):
56 | r"""
57 |     Style loss, VGG-based (Gram matrix matching)
58 | https://arxiv.org/abs/1603.08155
59 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
60 | """
61 |
62 | def __init__(self):
63 | super(StyleLoss, self).__init__()
64 | self.add_module('vgg19', VGG19())
65 | self.criterion = torch.nn.L1Loss(reduction="mean")
66 |
67 | def compute_gram(self, x):
68 | b, ch, h, w = x.size()
69 | f = x.view(b, ch, w * h)
70 | f_T = f.transpose(1, 2)
71 | G = f.bmm(f_T) / (h * w * ch)
72 |
73 | return G
74 |
75 | def __call__(self, x, y):
76 | # Compute features
77 | x_vgg, y_vgg = self.vgg19(x), self.vgg19(y)
78 |
79 | # Compute loss
80 | style_loss = 0.0
81 | style_loss += self.criterion(self.compute_gram(
82 | x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
83 | style_loss += self.criterion(self.compute_gram(
84 | x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
85 | style_loss += self.criterion(self.compute_gram(
86 | x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
87 | style_loss += self.criterion(self.compute_gram(
88 | x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
89 |
90 | return style_loss
91 |
92 |
93 | class PerceptualLoss(nn.Module):
94 | r"""
95 | Perceptual loss, VGG-based
96 | https://arxiv.org/abs/1603.08155
97 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
98 | """
99 |
100 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
101 | super(PerceptualLoss, self).__init__()
102 | self.add_module('vgg19', VGG19())
103 | self.criterion = torch.nn.L1Loss(reduction="mean")
104 | self.weights = weights
105 |
106 | def __call__(self, x, y):
107 | # Compute features
108 | x_vgg, y_vgg = self.vgg19(x), self.vgg19(y)
109 |
110 | p1 = self.weights[0] * \
111 | self.criterion(x_vgg['relu1_2'], y_vgg['relu1_2'])
112 | p2 = self.weights[1] * \
113 | self.criterion(x_vgg['relu2_2'], y_vgg['relu2_2'])
114 | p3 = self.weights[2] * \
115 | self.criterion(x_vgg['relu3_2'], y_vgg['relu3_2'])
116 | p4 = self.weights[3] * \
117 | self.criterion(x_vgg['relu4_2'], y_vgg['relu4_2'])
118 | p5 = self.weights[4] * \
119 | self.criterion(x_vgg['relu5_2'], y_vgg['relu5_2'])
120 |
121 | return p1+p2+p3+p4+p5
122 |
123 |
124 | class VGG19(torch.nn.Module):
125 | def __init__(self):
126 | super(VGG19, self).__init__()
127 | # https://pytorch.org/hub/pytorch_vision_vgg/
128 |         # ImageNet statistics, registered as buffers so they follow .to()/.cuda()
129 |         # with the module instead of being pinned to the GPU at construction
130 |         mean = np.array(
131 |             [0.485, 0.456, 0.406], dtype=np.float32).reshape((1, 3, 1, 1))
132 |         self.register_buffer('mean', torch.from_numpy(mean))
133 |         std = np.array(
134 |             [0.229, 0.224, 0.225], dtype=np.float32).reshape((1, 3, 1, 1))
135 |         self.register_buffer('std', torch.from_numpy(std))
136 | self.initial_model()
137 |
138 | def forward(self, x):
139 | relu1_1 = self.relu1_1((x-self.mean)/self.std)
140 | relu1_2 = self.relu1_2(relu1_1)
141 |
142 | relu2_1 = self.relu2_1(relu1_2)
143 | relu2_2 = self.relu2_2(relu2_1)
144 |
145 | relu3_1 = self.relu3_1(relu2_2)
146 | relu3_2 = self.relu3_2(relu3_1)
147 | relu3_3 = self.relu3_3(relu3_2)
148 | relu3_4 = self.relu3_4(relu3_3)
149 |
150 | relu4_1 = self.relu4_1(relu3_4)
151 | relu4_2 = self.relu4_2(relu4_1)
152 | relu4_3 = self.relu4_3(relu4_2)
153 | relu4_4 = self.relu4_4(relu4_3)
154 |
155 | relu5_1 = self.relu5_1(relu4_4)
156 | relu5_2 = self.relu5_2(relu5_1)
157 | relu5_3 = self.relu5_3(relu5_2)
158 | relu5_4 = self.relu5_4(relu5_3)
159 |
160 | out = {
161 | 'relu1_1': relu1_1,
162 | 'relu1_2': relu1_2,
163 |
164 | 'relu2_1': relu2_1,
165 | 'relu2_2': relu2_2,
166 |
167 | 'relu3_1': relu3_1,
168 | 'relu3_2': relu3_2,
169 | 'relu3_3': relu3_3,
170 | 'relu3_4': relu3_4,
171 |
172 | 'relu4_1': relu4_1,
173 | 'relu4_2': relu4_2,
174 | 'relu4_3': relu4_3,
175 | 'relu4_4': relu4_4,
176 |
177 | 'relu5_1': relu5_1,
178 | 'relu5_2': relu5_2,
179 | 'relu5_3': relu5_3,
180 | 'relu5_4': relu5_4,
181 | }
182 | return out
183 |
184 | def load_pretrained(self, vgg19_weights_path, gpu):
185 | if os.path.exists(vgg19_weights_path):
186 | if torch.cuda.is_available():
187 | data = torch.load(vgg19_weights_path)
188 | print("load vgg_pretrained_model:"+vgg19_weights_path)
189 | else:
190 | data = torch.load(vgg19_weights_path,
191 | map_location=lambda storage, loc: storage)
192 | self.initial_model(data)
193 | self.to(gpu)
194 | else:
195 | print("you need download vgg_pretrained_model in the directory of "+str(self.config.DATA_ROOT) +
196 | "\n'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'")
197 | raise Exception("Don't load vgg_pretrained_model")
198 |
199 | def initial_model(self,data=None):
200 | vgg19 = models.vgg19()
201 | if data is not None:
202 | vgg19.load_state_dict(data)
203 | features = vgg19.features
204 | self.relu1_1 = torch.nn.Sequential()
205 | self.relu1_2 = torch.nn.Sequential()
206 |
207 | self.relu2_1 = torch.nn.Sequential()
208 | self.relu2_2 = torch.nn.Sequential()
209 |
210 | self.relu3_1 = torch.nn.Sequential()
211 | self.relu3_2 = torch.nn.Sequential()
212 | self.relu3_3 = torch.nn.Sequential()
213 | self.relu3_4 = torch.nn.Sequential()
214 |
215 | self.relu4_1 = torch.nn.Sequential()
216 | self.relu4_2 = torch.nn.Sequential()
217 | self.relu4_3 = torch.nn.Sequential()
218 | self.relu4_4 = torch.nn.Sequential()
219 |
220 | self.relu5_1 = torch.nn.Sequential()
221 | self.relu5_2 = torch.nn.Sequential()
222 | self.relu5_3 = torch.nn.Sequential()
223 | self.relu5_4 = torch.nn.Sequential()
224 |
225 | for x in range(2):
226 | self.relu1_1.add_module(str(x), features[x])
227 |
228 | for x in range(2, 4):
229 | self.relu1_2.add_module(str(x), features[x])
230 |
231 | for x in range(4, 7):
232 | self.relu2_1.add_module(str(x), features[x])
233 |
234 | for x in range(7, 9):
235 | self.relu2_2.add_module(str(x), features[x])
236 |
237 | for x in range(9, 12):
238 | self.relu3_1.add_module(str(x), features[x])
239 |
240 | for x in range(12, 14):
241 | self.relu3_2.add_module(str(x), features[x])
242 |
243 | for x in range(14, 16):
244 | self.relu3_3.add_module(str(x), features[x])
245 |
246 | for x in range(16, 18):
247 | self.relu3_4.add_module(str(x), features[x])
248 |
249 | for x in range(18, 21):
250 | self.relu4_1.add_module(str(x), features[x])
251 |
252 | for x in range(21, 23):
253 | self.relu4_2.add_module(str(x), features[x])
254 |
255 | for x in range(23, 25):
256 | self.relu4_3.add_module(str(x), features[x])
257 |
258 | for x in range(25, 27):
259 | self.relu4_4.add_module(str(x), features[x])
260 |
261 | for x in range(27, 30):
262 | self.relu5_1.add_module(str(x), features[x])
263 |
264 | for x in range(30, 32):
265 | self.relu5_2.add_module(str(x), features[x])
266 |
267 | for x in range(32, 34):
268 | self.relu5_3.add_module(str(x), features[x])
269 |
270 | for x in range(34, 36):
271 | self.relu5_4.add_module(str(x), features[x])
272 |
273 | # don't need the gradients, just want the features
274 | # for param in self.parameters():
275 | # param.requires_grad = False
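276 |
277 | if __name__ == "__main__":
278 |     # Minimal smoke test, not part of the training pipeline (run from the repo
279 |     # root, e.g. `python -m src.loss`). With the buffer-based normalization
280 |     # above it runs on CPU, using randomly initialized VGG19 weights; call
281 |     # load_pretrained() for real use.
282 |     x, y = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
283 |     print("perceptual:", PerceptualLoss()(x, y).item())
284 |     print("style:", StyleLoss()(x, y).item())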
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DMTN
2 |
3 |
4 |
5 | # 1. Resources
6 | [国内资源链接(密码:e2ww)](https://rec.ustc.edu.cn/share/16db9c80-15d5-11ef-9ca4-7b154f4fe8b0)
7 |
8 | ## 1.1 Dataset
9 | * SRD ([github](https://github.com/Liangqiong/DeShadowNet) | [paper](https://openaccess.thecvf.com/content_cvpr_2017/papers/Qu_DeshadowNet_A_Multi-Context_CVPR_2017_paper.pdf))
10 | * ISTD ([github](https://github.com/DeepInsight-PCALab/ST-CGAN) | [paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Wang_Stacked_Conditional_Generative_CVPR_2018_paper.pdf))
11 | * ISTD+DA, SRD+DA ([github](https://github.com/vinthony/ghost-free-shadow-removal) | [paper](https://arxiv.org/abs/1911.08718))
12 | * [SSRD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EckYI_84wMdJpVgf5EqkmSABmnOD_-53YZ6v2KIQsiLeXA?e=67Gj47)
13 |
14 | The SSRD dataset does not contain ground-truth shadow-free images, because of the self shadows present in its images.
15 |
16 | ## 1.2 Results | Model Weight
17 |
18 | **TEST RESULTS ON SRD:**
19 | * Results on SRD: [DMTN_SRD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EQt3ZoJAbZ5Cq_mhHxzrUVYBcsiaPLjnsN-SmhotYz-UOg?e=hS2RkJ) | Weight: [DMTN_SRD.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EWc4B9PP-rtGp4LxPWGOkfoB6oi6Coh1tu-qG5qxBk-7Cg?e=sdlSiA)
20 | * Results on SRD: [DMTN+Mask_SRD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EU0JfEPuOUNNlSQDc0TfYgQBOAtMpXjK5yRoa3q2H_bcnQ?e=MgsEu5) | Weight: [DMTN+Mask_SRD.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EU4NJ0CPbwpBrzyXH5FLlMEBzqwKhcXlxe8k4vQiXRrJUw?e=FPKlfF)
21 |
22 | **TEST RESULTS ON ISTD:**
23 | * Results on ISTD: [DMTN_ISTD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EbyzaV72N2FElC5nOsp3-ZYBuUoVLiy29rmXBMXVXXY6Lg?e=wKA55D) | Weight: [DMTN_ISTD.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EROUwnLgz9BGi3OtJa5SIs8BwdgYBZTXeMJ1NLcGfHAwCg?e=v1f51U)
24 | * Results on ISTD: [DMTN+Mask_ISTD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EcRgA4y1UAZIkRIabbm71iIBNhRH-JIugQDbInyWE3rpNQ?e=D8BiGD) | Weight: [DMTN+Mask_ISTD.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EcnDQNKeoRdBtUYjQdirl34BR73n--qRnFIo6RnxPvk-KQ?e=9V0LIR)
25 | * Results on ISTD+DA: [DMTN_ISTD_DA](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/ERGASEyFybBDm9rYZv4a3I4B6FwMmrhZMImk_-b7Lo-YeQ?e=MbzrMk) | Weight: [DMTN_ISTD_DA.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EWwqrUr7Qh9KugvJ2S5KsdMBYz6aiR-ufiX3kn3zB626lg?e=7QJRra)
26 |
27 | **TEST RESULTS ON ISTD+:**
28 | * Results on ISTD+: [DMTN_ISTD+](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EZnB81g7L3VPuGo2zhVclVEBPhsO6MBJYPtbOnqxmEDHuw?e=MZLmUM) | Weight: [DMTN_ISTD+.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/ETVno1MtDsdLknqDNKq60VwB9Bq-oq8kZ8B8aiwQBZXbQQ?e=B0S37N)
29 | * Results on ISTD+: [DMTN+Mask_ISTD+](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EZEQr_hD7XdGgPiesl0L8aABSugt0z5U6V9q2Wv-fEr-VA?e=zq5A7s) | Weight: [DMTN+Mask_ISTD+.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/ETo6UMeCGNhFjJ20o0RedaQBG7XIDcfbqucJ3A-hK6IQAQ?e=uN6sTs)
30 |
31 | **TEST RESULTS ON SSRD:** (DHAN and DMTN are pretrained on the SRD dataset, input size 420x320)
32 |
33 | * Results of DMTN on SSRD: [DMTN_SSRD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/ET7vtW6b-RNFiK7hJe8coXoBjMMUj2vZ4nEj1SitH8wuKA?e=ZDnfYV) | Weight: [DMTN_SRD_420_320.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EQgZbEFJCLZGiAM8rnbE-ZUBHXw3zyTrhdet7JDSCrYiuA?e=6PcofV)
34 | * Results of DHAN on SSRD: [DHAN_SSRD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EROyGJwa2C5JkO1bLDVV_AsBbXRPKoZbBy5EsjAsz6xujg?e=nw53O6)
35 |
36 |
37 | ## 1.3 Visual results
38 |
39 |
40 | *Visual comparison results of **penumbra** removal on the SRD dataset - (Powered by [MulimgViewer](https://github.com/nachifur/MulimgViewer))*
41 |
42 |
43 |
44 | *Visual comparison results of **self shadow** removal on the SSRD dataset - (Powered by [MulimgViewer](https://github.com/nachifur/MulimgViewer))*
45 |
46 |
47 |
48 |
49 | ## 1.4 Evaluation Code
50 | Currently, most state-of-the-art shadow removal works evaluate with MATLAB code.
51 |
52 | [Our evaluation code](https://github.com/nachifur/DMTN/blob/main/evaluation/evaluate_MAE_PSNR_SSIM.m) combines 1 and 2 below:
53 | 1. MAE (i.e., RMSE in paper): https://github.com/tsingqguo/exposure-fusion-shadow-removal
54 | 2. PSNR+SSIM: https://github.com/zhuyr97/AAAI2022_Unfolding_Network_Shadow_Removal/tree/master/codes
55 |
56 | Notably, there are slight differences between the evaluation codes used across papers:
57 | * [wang_cvpr2018](https://github.com/DeepInsight-PCALab/ST-CGAN), [le_iccv2019](https://github.com/cvlab-stonybrook/SID): no imresize;
58 | * [fu_cvpr2021](https://github.com/tsingqguo/exposure-fusion-shadow-removal): first imresize, then double;
59 | * [zhu_aaai2022](https://github.com/zhuyr97/AAAI2022_Unfolding_Network_Shadow_Removal): first double, then imresize;
60 | * Our evaluation code: MAE->fu_cvpr2021, psnr+ssim->zhu_aaai2022
61 |
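62 | For intuition: the "RMSE" reported in this area is a per-pixel mean absolute error computed in CIELAB space, which is why it is listed as MAE above. A minimal Python sketch of the metric (illustrative only; the reported numbers come from the MATLAB script above, and `scikit-image` is an assumption of this sketch):
63 | ```
64 | import numpy as np
65 | from skimage.color import rgb2lab
66 |
67 | def mae_lab(pred_rgb, gt_rgb):
68 |     """Per-pixel MAE in LAB space; inputs are HxWx3 uint8 RGB arrays."""
69 |     return np.abs(rgb2lab(pred_rgb) - rgb2lab(gt_rgb)).mean()
70 | ```
71 |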
62 | # 2. Environments
63 | **ubuntu18.04+cuda10.2+pytorch1.7.1**
64 | 1. Create the environment:
65 | ```
66 | conda env create -f install.yaml
67 | ```
68 | 2. Activate the environment:
69 | ```
70 | conda activate DMTN
71 | ```
72 |
73 | # 3. Data Processing
74 | For example, generate the dataset list of ISTD:
75 | 1. Download:
76 | * ISTD and SRD
77 | * [USR shadowfree images](https://github.com/xw-hu/Mask-ShadowGAN)
78 | * [Syn. Shadow](https://github.com/vinthony/ghost-free-shadow-removal)
79 | * [SRD shadow mask](https://github.com/vinthony/ghost-free-shadow-removal)
80 | * train_B_ISTD:
81 | ```
82 | cp -r ISTD_Dataset_arg/train_B ISTD_Dataset_arg/train_B_ISTD
83 | cp -r ISTD_Dataset_arg/train_B SRD_Dataset_arg/train_B_ISTD
84 | ```
85 | * [VGG19](https://download.pytorch.org/models/vgg19-dcbb9e9d.pth)
86 | ```
87 | cp vgg19-dcbb9e9d.pth ISTD_Dataset_arg/
88 | cp vgg19-dcbb9e9d.pth SRD_Dataset_arg/
89 | ```
90 | 2. The data folders should be:
91 | ```
92 | ISTD_Dataset_arg
93 | * train
94 | - train_A # ISTD shadow image
95 | - train_B # ISTD shadow mask
96 | - train_C # ISTD shadowfree image
97 | - shadow_free # USR shadowfree images
98 | - synC # Syn. shadow
99 | - train_B_ISTD # ISTD shadow mask
100 | * test
101 | - test_A # ISTD shadow image
102 | - test_B # ISTD shadow mask
103 | - test_C # ISTD shadowfree image
104 | * vgg19-dcbb9e9d.pth
105 |
106 | SRD_Dataset_arg
107 | * train # renamed from the original `Train` folder in `SRD`.
108 |   - train_A # SRD shadow image, renamed from the original `shadow` folder in `SRD`.
109 |   - train_B # SRD shadow mask
110 |   - train_C # SRD shadowfree image, renamed from the original `shadow_free` folder in `SRD`.
111 | - shadow_free # USR shadowfree images
112 | - synC # Syn. shadow
113 | - train_B_ISTD # ISTD shadow mask
114 | * test # renamed from the original `test_data` folder in `SRD`.
115 |   - train_A # SRD shadow image, renamed from the original `shadow` folder in `SRD`.
116 |   - train_B # SRD shadow mask
117 |   - train_C # SRD shadowfree image, renamed from the original `shadow_free` folder in `SRD`.
118 | * vgg19-dcbb9e9d.pth
119 | ```
120 | 3. Edit `generate_flist_istd.py`: (Replace path)
121 |
122 | ```
123 | ISTD_path = "/Your_data_storage_path/ISTD_Dataset_arg"
124 | ```
125 | 4. Generate the dataset lists. (These already contain ISTD+DA.)
126 | ```
127 | conda activate DMTN
128 | cd script/
129 | python generate_flist_istd.py
130 | ```
131 | 5. Edit `config_ISTD.yml`: (Replace path)
132 | ```
133 | DATA_ROOT: /Your_data_storage_path/ISTD_Dataset_arg
134 | ```
135 |
136 | # 4. Training+Test+Evaluation
137 | ## 4.1 Training+Test+Evaluation
138 | For example, training+test+evaluation on the ISTD dataset:
139 | ```
140 | cp config/config_ISTD.yml config.yml
141 | cp config/run_ISTD.py run.py
142 | conda activate DMTN
143 | python run.py
144 | ```
145 | ## 4.2 Only Test and Evaluation
146 | For example, test+evaluation on the ISTD dataset:
147 | 1. Download the weight file (`DMTN_ISTD.pth`) to `pre_train_model/ISTD`
148 | 2. Copy the files
149 | ```
150 | cp config/config_ISTD.yml config.yml
151 | cp config/run_ISTD.py run.py
152 | mkdir -p checkpoints/ISTD/
153 | cp config.yml checkpoints/ISTD/config.yml
154 | cp pre_train_model/ISTD/DMTN_ISTD.pth checkpoints/ISTD/ShadowRemoval.pth
155 | ```
156 |
157 | 3. Edit `run.py`. Comment out the training code:
158 |
159 | ```
160 | # # pre_train (data augmentation)
161 | # MODE = 0
162 | # print('\nmode-'+str(MODE)+': start pre_training(data augmentation)...\n')
163 | # for i in range(1):
164 | # skip_train = init_config(checkpoints_path, MODE=MODE,
165 | # EVAL_INTERVAL_EPOCH=1, EPOCH=[90,i])
166 | # if not skip_train:
167 | # main(MODE, config_path)
168 | # src_path = Path('./pre_train_model') / \
169 | # config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth')
170 | # copypth(dest_path, src_path)
171 |
172 | # # train
173 | # MODE = 2
174 | # print('\nmode-'+str(MODE)+': start training...\n')
175 | # for i in range(1):
176 | # skip_train = init_config(checkpoints_path, MODE=MODE,
177 | # EVAL_INTERVAL_EPOCH=0.1, EPOCH=[60,i])
178 | # if not skip_train:
179 | # main(MODE, config_path)
180 | # src_path = Path('./pre_train_model') / \
181 | # config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth')
182 | # copypth(dest_path, src_path)
183 | ```
184 | 4. Run
185 |
186 | ```
187 | conda activate DMTN
188 | python run.py
189 | ```
190 | ## 4.3 Show Results
191 | After evaluation, execute the following code to display the final RMSE.
192 | ```
193 | python show_eval_result.py
194 | ```
195 | Output:
196 | ```
197 | running rmse-shadow: xxx, rmse-non-shadow: xxx, rmse-all: xxx # ISTD
198 | ```
199 | This is the Python+PyTorch evaluation result, which is only used during training. To get the evaluation results reported in the paper, you need to run the [MATLAB code](#14-evaluation-code).
200 |
201 | ## 4.4 Test on SSRD
202 | 1. Edit `src/network/network_DMTN.py`. Modify the line (https://github.com/nachifur/DMTN/blob/main/src/network/network_DMTN.py#L339).
203 | ```
204 | SSRD = True
205 | ```
206 | 2. Test like the section `4.2 Only Test and Evaluation`.
207 |
208 | # 5. Acknowledgements
209 | Part of the code is based upon:
210 | * https://github.com/nachifur/LLPC
211 | * https://github.com/vinthony/ghost-free-shadow-removal
212 | * https://github.com/knazeri/edge-connect
213 |
214 | # 6. Citation
215 | ```
216 | @ARTICLE{liu2023decoupled,
217 | author={Liu, Jiawei and Wang, Qiang and Fan, Huijie and Li, Wentao and Qu, Liangqiong and Tang, Yandong},
218 | journal={IEEE Transactions on Multimedia},
219 | title={A Decoupled Multi-Task Network for Shadow Removal},
220 | year={2023},
221 | volume={},
222 | number={},
223 | pages={1-14},
224 | doi={10.1109/TMM.2023.3252271}}
225 | ```
226 | # 7. Contact
227 | Please contact Jiawei Liu if there is any question (liujiawei18@mails.ucas.ac.cn).
228 |
229 | # 8. Revised Errors in the Paper
230 | Sorry! Here are the corrections:
231 | 1. In Section III-C-2)-`Fig. 7 (or Fig. 5(b)) shows...`, "we can achieve feature decoupling, i.e., some channels of F represent shadow images (~~`I_m`~~ `I_s`)".
232 |
--------------------------------------------------------------------------------
/src/network/networks.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | import torch.nn.functional as F
6 |
7 | from src.image_pool import ImagePool
8 |
9 | from src.loss import AdversarialLoss, PerceptualLoss, StyleLoss
10 |
11 |
12 | class BaseNetwork(nn.Module):
13 | def __init__(self):
14 | super(BaseNetwork, self).__init__()
15 |
16 | def init_weights(self, init_type='normal', gain=0.02):
17 | '''
18 | initialize network's weights
19 | init_type: normal | xavier | kaiming | orthogonal
20 | '''
21 |
22 | def init_func(m):
23 | classname = m.__class__.__name__
24 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
25 | if init_type == 'normal':
26 | nn.init.normal_(m.weight.data, 0.0, gain)
27 | elif init_type == 'xavier':
28 | nn.init.xavier_normal_(m.weight.data, gain=gain)
29 | elif init_type == 'kaiming':
30 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
31 | elif init_type == 'orthogonal':
32 | nn.init.orthogonal_(m.weight.data, gain=gain)
33 |
34 | if hasattr(m, 'bias') and m.bias is not None:
35 | nn.init.constant_(m.bias.data, 0.0)
36 |
37 | elif classname.find('BatchNorm2d') != -1:
38 | nn.init.normal_(m.weight.data, 1.0, gain)
39 | nn.init.constant_(m.bias.data, 0.0)
40 |
41 | self.apply(init_func)
42 |
43 |
44 | class Discriminator(BaseNetwork):
45 | def __init__(self, config, in_channels, init_weights=True):
46 | super(Discriminator, self).__init__()
47 | # config
48 | self.config = config
49 | self.use_sigmoid = self.config.GAN_LOSS != 'hinge'
50 | norm = self.config.DIS_NORM
51 | # network
52 | self.conv1 = conv2d_layer(
53 | in_channels, 64, kernel_size=4, stride=2, norm=norm)
54 | self.conv2 = conv2d_layer(64, 128, kernel_size=4, stride=2, norm=norm)
55 | self.conv3 = conv2d_layer(128, 256, kernel_size=4, stride=2, norm=norm)
56 | self.conv4 = conv2d_layer(256, 512, kernel_size=4, stride=1, norm=norm)
57 | self.conv5 = conv2d_layer(512, 1, kernel_size=4, stride=1, norm=norm)
58 | # loss
59 | self.add_module('adversarial_loss',
60 | AdversarialLoss(type=self.config.GAN_LOSS))
61 | self.add_module('perceptual_loss',
62 | PerceptualLoss())
63 | self.fake_pool = ImagePool(self.config.POOL_SIZE)
64 | # optimizer
65 | self.optimizer = optim.Adam(
66 | params=self.parameters(),
67 | lr=float(config.LR_D),
68 | betas=(config.BETA1, config.BETA2)
69 | )
70 | if init_weights:
71 | self.init_weights()
72 |
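73 |     # PatchGAN-style discriminator (three stride-2 and two stride-1 4x4 convs):
74 |     # each output score rates a local patch of the input, not the whole image.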
73 | def forward(self, x):
74 | conv1 = self.conv1(x)
75 | conv2 = self.conv2(conv1)
76 | conv3 = self.conv3(conv2)
77 | conv4 = self.conv4(conv3)
78 | conv5 = self.conv5(conv4)
79 |
80 | outputs = conv5
81 | if self.use_sigmoid:
82 | outputs = torch.sigmoid(conv5)
83 |
84 | return outputs, [conv1, conv2, conv3, conv4, conv5]
85 |
86 | def cal_loss(self, images, outputs, GT, cat_img=True):
87 | # discriminator loss
88 | loss = []
89 | dis_loss = 0
90 | if cat_img:
91 | dis_input_real = torch.cat((images, GT), dim=1)
92 | dis_input_fake = torch.cat((images, F.interpolate(self.fake_pool.query(outputs.detach()), outputs.shape[2:], mode="bilinear",align_corners=True)), dim=1)
93 | else:
94 | dis_input_real = GT
95 | dis_input_fake = F.interpolate(self.fake_pool.query(outputs.detach()), outputs.shape[2:], mode="bilinear",align_corners=True)
96 |
97 | dis_real, dis_real_feat = self(dis_input_real)
98 | dis_fake, dis_fake_feat = self(dis_input_fake)
99 | dis_real_loss = self.adversarial_loss(dis_real, True, True)
100 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True)
101 | dis_loss += (dis_real_loss + dis_fake_loss) / 2
102 | loss.append(dis_loss)
103 |
104 | # generator adversarial loss
105 | if cat_img:
106 | gen_input_fake = torch.cat((images, outputs), dim=1)
107 | else:
108 | gen_input_fake = outputs
109 | gen_fake, gen_fake_feat = self(gen_input_fake)
110 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False)
111 |
112 | # generator perceptual loss
113 | gen_perceptual_loss = self.perceptual_loss(outputs, GT)
114 |
115 | loss.append(gen_gan_loss)
116 | loss.append(gen_perceptual_loss)
117 | return loss
118 |
119 | def backward(self, loss):
120 | self.optimizer.zero_grad()
121 | loss.backward()
122 | self.optimizer.step()
123 |
124 |
125 | def conv2d_layer(in_channels, channels, kernel_size=3, stride=1, padding=1, dilation=1, norm="batch", activation_fn="LeakyReLU", conv_mode="none", pad_mode="ReflectionPad2d"):
126 | """
127 | norm: batch, spectral, instance, spectral_instance, none
128 |
129 | activation_fn: Sigmoid, ReLU, LeakyReLU, none
130 |
131 | conv_mode: transpose, upsample, none
132 |
133 | pad_mode: ReflectionPad2d, ReplicationPad2d, ZeroPad2d
134 | """
135 | layer = []
136 | # padding
137 | if conv_mode == "transpose":
138 | pass
139 | else:
140 | if pad_mode == "ReflectionPad2d":
141 | layer.append(nn.ReflectionPad2d(padding))
142 | elif pad_mode == "ReplicationPad2d":
143 |             layer.append(nn.ReplicationPad2d(padding))
144 | else:
145 | layer.append(nn.ZeroPad2d(padding))
146 | padding = 0
147 |
148 | # conv layer
149 | if norm == "spectral" or norm == "spectral_instance":
150 | bias = False
151 | # conv
152 | if conv_mode == "transpose":
153 | conv_ = nn.ConvTranspose2d
154 | elif conv_mode == "upsample":
155 | layer.append(nn.Upsample(mode='bilinear', scale_factor=stride))
156 | conv_ = nn.Conv2d
157 | else:
158 | conv_ = nn.Conv2d
159 | else:
160 | bias = True
161 | # conv
162 | if conv_mode == "transpose":
163 | layer.append(nn.ConvTranspose2d(in_channels, channels, kernel_size,
164 | bias=bias, stride=stride, padding=padding, dilation=dilation))
165 | elif conv_mode == "upsample":
166 | layer.append(nn.Upsample(mode='bilinear', scale_factor=stride))
167 |             # the Upsample above realizes the stride; the conv itself stays stride 1
168 |             layer.append(nn.Conv2d(in_channels, channels, kernel_size,
169 |                                    bias=bias, stride=1, padding=padding, dilation=dilation))
169 | else:
170 | layer.append(nn.Conv2d(in_channels, channels, kernel_size,
171 | bias=bias, stride=stride, padding=padding, dilation=dilation))
172 |
173 | # norm
174 | if norm == "spectral":
175 | layer.append(spectral_norm(conv_(in_channels, channels, kernel_size,
176 | stride=stride, bias=bias, padding=padding, dilation=dilation), True))
177 | elif norm == "instance":
178 | layer.append(nn.InstanceNorm2d(
179 | channels, affine=True, track_running_stats=False))
180 | elif norm == "batch":
181 | layer.append(nn.BatchNorm2d(
182 | channels, affine=True, track_running_stats=False))
183 | elif norm == "spectral_instance":
184 | layer.append(spectral_norm(conv_(in_channels, channels, kernel_size,
185 | stride=stride, bias=bias, padding=padding, dilation=dilation), True))
186 | layer.append(nn.InstanceNorm2d(
187 | channels, affine=True, track_running_stats=False))
188 | elif norm == "batch_":
189 | layer.append(BatchNorm_(channels))
190 | else:
191 | pass
192 |
193 | # activation_fn
194 | if activation_fn == "Sigmoid":
195 | layer.append(nn.Sigmoid())
196 | elif activation_fn == "ReLU":
197 | layer.append(nn.ReLU(True))
198 | elif activation_fn == "none":
199 | pass
200 | else:
201 | layer.append(nn.LeakyReLU(0.2,inplace=True))
202 |
203 | return nn.Sequential(*layer)
204 |
205 |
206 | class BatchNorm_(nn.Module):
207 | def __init__(self, channels):
208 | super(BatchNorm_, self).__init__()
209 | self.w0 = torch.nn.Parameter(
210 | torch.FloatTensor([1.0]), requires_grad=True)
211 | self.w1 = torch.nn.Parameter(
212 | torch.FloatTensor([0.0]), requires_grad=True)
213 | self.BatchNorm2d = nn.BatchNorm2d(
214 | channels, affine=True, track_running_stats=False)
215 |
216 | def forward(self, x):
217 | outputs = self.w0*x+self.w1*self.BatchNorm2d(x)
218 | return outputs
219 |
220 | def avgcov2d_layer(pool_kernel_size, pool_stride, in_channels, channels, conv_kernel_size=3, conv_stride=1, padding=1, dilation=1, norm="batch", activation_fn="LeakyReLU"):
221 | layer = []
222 | layer.append(nn.AvgPool2d(pool_kernel_size, pool_stride))
223 | layer.append(conv2d_layer(in_channels, channels, kernel_size=conv_kernel_size, stride=conv_stride,
224 | padding=padding, dilation=dilation, norm=norm, activation_fn=activation_fn))
225 | return nn.Sequential(*layer)
226 |
227 |
228 | def spectral_norm(module, mode=True):
229 | if mode:
230 | return nn.utils.spectral_norm(module)
231 |
232 | return module
233 |
234 |
235 | def get_encoder(encoder_param, norm):
236 | encoder = []
237 | index = 0
238 | for param in encoder_param:
239 | if index == 0:
240 | encoder.append(conv2d_layer(param[0], param[1], kernel_size=param[2], stride=param[3], padding=param[4], dilation=1, norm=norm,
241 | activation_fn="ReLU", conv_mode="none", pad_mode="ReflectionPad2d"))
242 | else:
243 | encoder.append(conv2d_layer(param[0], param[1], kernel_size=param[2], stride=param[3], padding=param[4], dilation=1, norm=norm,
244 | activation_fn="ReLU", conv_mode="none", pad_mode="ZeroPad2d"))
245 | index += 1
246 | return encoder
247 |
248 |
249 | def get_middle(middle_param, norm):
250 | blocks = []
251 | for _ in range(middle_param[0]):
252 | block = ResnetBlock(
253 | middle_param[1], norm)
254 | blocks.append(block)
255 | return blocks
256 |
257 |
258 | def get_decoder(decoder_param, norm, Sigmoid=True):
259 | if Sigmoid:
260 | activation_fn = "Sigmoid"
261 | else:
262 | activation_fn = "none"
263 | decoder = []
264 | index = 0
265 | for param in decoder_param:
266 | if index == len(decoder_param)-1:
267 | decoder.append(conv2d_layer(param[0], param[1], kernel_size=param[2], stride=param[3], padding=param[4], dilation=1, norm="none",
268 | activation_fn=activation_fn, conv_mode="none", pad_mode="ReflectionPad2d"))
269 | else:
270 | decoder.append(conv2d_layer(param[0], param[1], kernel_size=param[2], stride=param[3], padding=param[4], dilation=1, norm=norm,
271 | activation_fn="ReLU", conv_mode="transpose", pad_mode="ZeroPad2d"))
272 | index += 1
273 | return decoder
274 |
275 |
276 | def get_encoder_decoder(in_channels, ResnetBlockNum, Sigmoid=True, norm="batch"):
277 | encoder_param = [
278 | [in_channels, 64, 7, 1, 3],
279 | [64, 128, 4, 2, 1],
280 | [128, 256, 4, 2, 1]]
281 | encoder = nn.Sequential(
282 | *get_encoder(encoder_param, norm))
283 | middle_param = [ResnetBlockNum, 256]
284 | middle = nn.Sequential(
285 | *get_middle(middle_param, norm))
286 | decoder_param = [
287 | [256, 128, 4, 2, 1],
288 | [128, 64, 4, 2, 1],
289 | [64, 3, 7, 1, 3]]
290 | decoder = nn.Sequential(
291 | *get_decoder(decoder_param, norm, Sigmoid=Sigmoid))
292 | return nn.Sequential(*[encoder, middle, decoder])
293 |
294 |
295 | class ResnetBlock(nn.Module):
296 | def __init__(self, dim, norm):
297 | super(ResnetBlock, self).__init__()
298 | self.conv_block = nn.Sequential(
299 | conv2d_layer(dim, dim, kernel_size=3, stride=1, padding=1, dilation=1, norm=norm,
300 | activation_fn="ReLU", conv_mode="none", pad_mode="ReflectionPad2d"),
301 |
302 | conv2d_layer(dim, dim, kernel_size=3, stride=1, padding=2, dilation=2, norm=norm,
303 | activation_fn="ReLU", conv_mode="none", pad_mode="ReflectionPad2d"),
304 |
305 | conv2d_layer(dim, dim, kernel_size=3, stride=1, padding=1, dilation=1, norm=norm,
306 | activation_fn="ReLU", conv_mode="none", pad_mode="ReflectionPad2d"),
307 | )
308 |
309 | def forward(self, x):
310 | out = x + self.conv_block(x)
311 | return out
312 |
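313 | if __name__ == "__main__":
314 |     # Minimal shape check, not part of the pipeline (run from the repo root,
315 |     # e.g. `python -m src.network.networks`): the encoder halves the resolution
316 |     # twice and the transposed-conv decoder restores it.
317 |     net = get_encoder_decoder(in_channels=3, ResnetBlockNum=2)
318 |     print(net(torch.rand(1, 3, 64, 64)).shape)  # torch.Size([1, 3, 64, 64])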
--------------------------------------------------------------------------------
/src/network/network_DMTN.py:
--------------------------------------------------------------------------------
1 | import numpy as np
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 |
8 | from src.loss import VGG19
9 | from src.network.networks import Discriminator, avgcov2d_layer, conv2d_layer
10 | from torch.autograd import Variable
11 | from src.utils import solve_factor, imshow
12 | import torchvision.transforms.functional as TF
13 | import scipy.linalg
14 |
15 |
16 | class BaseNetwork(nn.Module):
17 | def __init__(self):
18 | super(BaseNetwork, self).__init__()
19 |
20 | def init_weights(self, init_type='identity', gain=0.02):
21 | '''
22 | initialize network's weights
23 | init_type: normal | xavier | kaiming | orthogonal
24 | '''
25 |
26 | def init_func(m):
27 | classname = m.__class__.__name__
28 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
29 | if init_type == 'normal':
30 | nn.init.normal_(m.weight.data, 0.0, gain)
31 | elif init_type == 'xavier':
32 | nn.init.xavier_normal_(m.weight.data, gain=gain)
33 | elif init_type == 'kaiming':
34 | nn.init.kaiming_normal_(
35 | m.weight.data, a=0, mode='fan_in')
36 | elif init_type == 'orthogonal':
37 | nn.init.orthogonal_(m.weight.data, gain=gain)
38 | elif init_type == "identity":
39 | if isinstance(m, nn.Linear):
40 | nn.init.normal_(m.weight, 0, gain)
41 |                     else:
42 |                         # the helper returns a new tensor, so copy it into the weights
43 |                         m.weight.data.copy_(identity_initializer(m.weight.data))
43 |
44 | if hasattr(m, 'bias') and m.bias is not None:
45 | nn.init.constant_(m.bias.data, 0.0)
46 |
47 | elif classname.find('BatchNorm2d') != -1:
48 | nn.init.normal_(m.weight.data, 1.0, gain)
49 | nn.init.constant_(m.bias.data, 0.0)
50 |
51 | self.apply(init_func)
52 |
53 |
54 | def identity_initializer(data):
55 | shape = data.shape
56 | array = np.zeros(shape, dtype=float)
57 | cx, cy = shape[2]//2, shape[3]//2
58 | for i in range(np.minimum(shape[0], shape[1])):
59 | array[i, i, cx, cy] = 1
60 | return torch.tensor(array, dtype=torch.float32)
61 |
62 |
63 | class DMTN(BaseNetwork):
64 | """DMTN"""
65 |
66 | def __init__(self, config, in_channels=3, init_weights=True):
67 | super(DMTN, self).__init__()
68 |
69 | # gan
70 | channels = 64
71 | stage_num = [12, 2]
72 | self.network = DMTNSOURCE(
73 | in_channels, channels, norm=config.GAN_NORM, stage_num=stage_num)
74 | # gan loss
75 | if config.LOSS == "MSELoss":
76 | self.add_module('loss', nn.MSELoss(reduction="mean"))
77 | elif config.LOSS == "L1Loss":
78 | self.add_module('loss', nn.L1Loss(reduction="mean"))
79 |
80 | # gan optimizer
81 | self.optimizer = optim.Adam(
82 | params=self.network.parameters(),
83 | lr=float(config.LR),
84 | betas=(config.BETA1, config.BETA2)
85 | )
86 |
87 | # dis
88 | self.ADV = config.ADV
89 | if self.ADV:
90 | discriminator = []
91 | discriminator.append(Discriminator(config, in_channels=6))
92 | self.discriminator = nn.Sequential(*discriminator)
93 |
94 | if init_weights:
95 | self.init_weights(config.INIT_TYPE)
96 |
97 | def process(self, images, mask, GT):
98 | loss = []
99 | logs = []
100 | inputs = images
101 | img, net_matte, outputs = self(
102 | inputs)
103 |
104 | matte_gt = GT-inputs
105 | matte_gt = matte_gt - \
106 | (matte_gt.min(dim=2, keepdim=True).values).min(
107 | dim=3, keepdim=True).values
108 | matte_gt = matte_gt / \
109 | (matte_gt.max(dim=2, keepdim=True).values).max(
110 | dim=3, keepdim=True).values
111 | match_loss_1 = 0
112 | match_loss_2 = 0
113 | matte_loss = 0
114 |
115 | match_loss_1 += self.cal_loss(outputs, GT)*255*10
116 | match_loss_2 += self.cal_loss(img, images)*255*10
117 | matte_loss += self.cal_loss(net_matte, matte_gt)*255
118 |
119 | if self.ADV:
120 | dis_loss_1, gen_gan_loss_1, perceptual_loss_1 = self.discriminator[0].cal_loss(
121 | images, outputs, GT)
122 |
123 | perceptual_loss_1 = perceptual_loss_1*1000
124 |
125 | gen_loss = perceptual_loss_1 + match_loss_1+match_loss_2 +\
126 | matte_loss+gen_gan_loss_1
127 |
128 | loss.append(gen_loss)
129 | loss.append(dis_loss_1)
130 |
131 | logs.append(("l_match1", match_loss_1.item()))
132 | logs.append(("l_match2", match_loss_2.item()))
133 | logs.append(("l_matte", matte_loss.item()))
134 | logs.append(("l_perceptual_1", perceptual_loss_1.item()))
135 | logs.append(("l_adv1", gen_gan_loss_1.item()))
136 | logs.append(("l_gen", gen_loss.item()))
137 | logs.append(("l_dis1", dis_loss_1.item()))
138 | else:
139 |             gen_loss = match_loss_1 + match_loss_2 + matte_loss
141 | loss.append(gen_loss)
142 |
143 | logs.append(("l_match1", match_loss_1.item()))
144 | logs.append(("l_match2", match_loss_2.item()))
145 | logs.append(("l_matte", matte_loss.item()))
146 | logs.append(("l_gen", gen_loss.item()))
147 |
148 | return [net_matte, outputs], loss, logs
149 |
150 | def forward(self, x):
151 | outputs = self.network(x)
152 | return outputs
153 |
154 | def cal_loss(self, outputs, GT):
155 | matching_loss = self.loss(outputs, GT)
156 | return matching_loss
157 |
158 | def backward(self, loss):
159 | self.optimizer.zero_grad()
160 | loss[0].backward()
161 | self.optimizer.step()
162 | if self.ADV:
163 | i = 0
164 | for discriminator in self.discriminator:
165 | discriminator.backward(loss[1+i])
166 | i += 1
167 |
168 |
169 | class FeatureDecouplingModule(nn.Module):
170 | '''Shadow feature decoupling module'''
171 |
172 | def __init__(self, in_channels=64, channels=3):
173 | super(FeatureDecouplingModule, self).__init__()
174 | kernel_size = 1
175 |
176 | w = torch.randn(channels, in_channels, kernel_size, kernel_size)
177 | self.w0 = torch.nn.Parameter(torch.FloatTensor(
178 | self.normalize_to_0_1(w)), requires_grad=True)
179 | w = torch.randn(channels, in_channels, kernel_size, kernel_size)
180 | self.w1 = torch.nn.Parameter(torch.FloatTensor(
181 | self.normalize_to_0_1(w)), requires_grad=True)
182 | w = torch.zeros(channels, in_channels, kernel_size, kernel_size)
183 | self.w2 = torch.nn.Parameter(torch.FloatTensor(
184 | w), requires_grad=True)
185 |
186 | self.bias_proportion = torch.nn.Parameter(torch.zeros(
187 | (channels, 1, 1, 1)), requires_grad=True)
188 |
189 | self.alpha_0 = torch.nn.Parameter(torch.ones(
190 | (1, channels, 1, 1)), requires_grad=True)
191 | self.alpha_1 = torch.nn.Parameter(torch.ones(
192 | (1, channels, 1, 1)), requires_grad=True)
193 | self.alpha_2 = torch.nn.Parameter(torch.ones(
194 | (1, channels, 1, 1)), requires_grad=True)
195 |
196 | self.bias_0 = torch.nn.Parameter(torch.zeros(
197 | (1, channels, 1, 1)), requires_grad=True)
198 | self.bias_1 = torch.nn.Parameter(torch.zeros(
199 | (1, channels, 1, 1)), requires_grad=True)
200 | self.bias_2 = torch.nn.Parameter(torch.zeros(
201 | (1, channels, 1, 1)), requires_grad=True)
202 |
203 |     def forward(self, x):
204 |         # decouple the shared features with corrected 1x1 kernels: w0 -> shadow
205 |         # image, w1 -> shadow matte, w0+w1 -> shadow-free image; w2 is a soft
206 |         # attention that pushes w0 and w1 toward disjoint channel subsets
204 | w0 = self.w0
205 | w1 = self.w1
206 | o, c, k_w, k_h = w0.shape
207 |
208 | w0_attention = 1+self.w2
209 | w1_attention = 1-self.w2
210 |
211 | w = w0*w0_attention
212 | median_w = torch.median(w, dim=1, keepdim=True)
213 | w0_correct = F.relu(w-median_w.values+self.bias_proportion)
214 |
215 | w0_correct = self.normalize_to_0_1(w0_correct)
216 |
217 | w = w1*w1_attention
218 | median_w = torch.median(w, dim=1, keepdim=True)
219 | w1_correct = F.relu(w-median_w.values-self.bias_proportion)
220 | w1_correct = self.normalize_to_0_1(w1_correct)
221 |
222 | w2_correct = w0_correct+w1_correct
223 |
224 | img = torch.sigmoid(self.alpha_0*F.conv2d(x, w0_correct)+self.bias_0)
225 | matte = torch.sigmoid(self.alpha_1*F.conv2d(x, w1_correct)+self.bias_1)
226 | img_free = torch.sigmoid(
227 | self.alpha_2*F.conv2d(x, w2_correct)+self.bias_2)
228 |
229 | return img, matte, img_free
230 |
231 | def normalize_to_0_1(self, w):
232 | w = w-w.min()
233 | w = w/w.max()
234 | return w
235 |
236 |
237 | class DMTNSOURCE(nn.Module):
238 | def __init__(self, in_channels=3, channels=64, norm="batch", stage_num=[6, 4]):
239 | super(DMTNSOURCE, self).__init__()
240 | self.stage_num = stage_num
241 |
242 | # Pre-trained VGG19
243 | self.add_module('vgg19', VGG19())
244 |
245 | # SE
246 | cat_channels = in_channels+64+128+256+512+512
247 | self.se = nn.Sequential(SELayer(cat_channels),
248 | conv2d_layer(cat_channels, channels, kernel_size=3, padding=1, dilation=1, norm=norm))
249 |
250 | self.down_sample = conv2d_layer(
251 | channels, 2*channels, kernel_size=4, stride=2, padding=1, dilation=1, norm=norm)
252 |
253 | # coarse
254 | coarse_list = []
255 | for i in range(self.stage_num[0]):
256 | coarse_list.append(SemiConvModule(
257 | 2*channels, norm, mid_dilation=2**(i % 6)))
258 | self.coarse_list = nn.Sequential(*coarse_list)
259 |
260 | self.up_conv = conv2d_layer(
261 | 2*channels, channels, kernel_size=3, stride=1, padding=1, dilation=1, norm=norm)
262 |
263 | # fine
264 | fine_list = []
265 | for i in range(self.stage_num[1]):
266 | fine_list.append(SemiConvModule(
267 | channels, norm, mid_dilation=2**(i % 6)))
268 | self.fine_list = nn.Sequential(*fine_list)
269 |
270 | self.se_coarse = nn.Sequential(SELayer(2*channels),
271 | conv2d_layer(2*channels, channels, kernel_size=3, padding=1, dilation=1, norm=norm))
272 |
273 | # SPP
274 | self.spp = SPP(channels, norm=norm)
275 |
276 |         # shadow feature decoupling module
277 | self.FDM = FeatureDecouplingModule(in_channels=channels, channels=3)
278 |
279 | def forward(self, x):
280 | size = (x.shape[2], x.shape[3])
281 |
282 | # vgg
283 | x_vgg = self.vgg19(x)
284 |
285 | # hyper-column features
286 | x_cat = torch.cat((
287 | x,
288 | F.interpolate(x_vgg['relu1_2'], size,
289 | mode="bilinear", align_corners=True),
290 | F.interpolate(x_vgg['relu2_2'], size,
291 | mode="bilinear", align_corners=True),
292 | F.interpolate(x_vgg['relu3_2'], size,
293 | mode="bilinear", align_corners=True),
294 | F.interpolate(x_vgg['relu4_2'], size,
295 | mode="bilinear", align_corners=True),
296 | F.interpolate(x_vgg['relu5_2'], size, mode="bilinear", align_corners=True)), dim=1)
297 |
298 | # SE
299 | x = self.se(x_cat)
300 |
301 | # coarse
302 | x_ = x
303 | x = self.down_sample(x)
304 | for i in range(self.stage_num[0]):
305 | x = self.coarse_list[i](x)
306 |
307 | size = (x_.shape[2], x_.shape[3])
308 | x = F.interpolate(x, size, mode="bilinear", align_corners=True)
309 | x = self.up_conv(x)
310 |
311 | # fine
312 | x = self.se_coarse(torch.cat((x_, x), dim=1))
313 | for i in range(self.stage_num[1]):
314 | x = self.fine_list[i](x)
315 |
316 | # spp
317 | x = self.spp(x)
318 |
319 | # output
320 | img, matte_out, img_free = self.FDM(x)
321 |
322 | return [img, matte_out, img_free]
323 |
324 |
325 | class SemiConvModule(nn.Module):
326 | def __init__(self, channels=64, norm="batch", mid_dilation=2):
327 | super(SemiConvModule, self).__init__()
328 | list_factor = solve_factor(channels)
329 | self.group = list_factor[int(len(list_factor)/2)-1]
330 | self.split_channels = int(channels/2)
331 |
332 | # Conv
333 | self.conv_dilation = conv2d_layer(
334 | self.split_channels, self.split_channels, kernel_size=3, padding=mid_dilation, dilation=mid_dilation, norm=norm)
335 | self.conv_3x3 = conv2d_layer(
336 | self.split_channels, self.split_channels, kernel_size=3, padding=1, dilation=1, norm=norm)
337 |
338 | def forward(self, x):
339 | SSRD=False
340 | if SSRD:
341 | x_conv = x[:, self.split_channels:, :, :]
342 | x_identity = x[:, 0:self.split_channels, :, :]
343 | else:
344 | x_conv = x[:, 0:self.split_channels, :, :]
345 | x_identity = x[:, self.split_channels:, :, :]
346 |
347 | x_conv = x_conv+self.conv_dilation(x_conv)+self.conv_3x3(x_conv)
348 |
349 | x = torch.cat((x_identity, x_conv), dim=1)
350 | x = self.channel_shuffle(x)
351 | return x
352 |
353 | def channel_shuffle(self, x):
354 | batchsize, num_channels, height, width = x.data.size()
355 | assert num_channels % self.group == 0
356 | group_channels = num_channels // self.group
357 |
358 | x = x.reshape(batchsize, group_channels, self.group, height, width)
359 | x = x.permute(0, 2, 1, 3, 4)
360 | x = x.reshape(batchsize, num_channels, height, width)
361 |
362 | return x
363 |
364 | def identity(self, x):
365 | return x
366 |
367 |
368 | class SPP(nn.Module):
369 | def __init__(self, channels=64, norm="batch"):
370 | super(SPP, self).__init__()
371 | self.net2 = avgcov2d_layer(
372 | 4, 4, channels, channels, 1, padding=0, norm=norm)
373 | self.net8 = avgcov2d_layer(
374 | 8, 8, channels, channels, 1, padding=0, norm=norm)
375 | self.net16 = avgcov2d_layer(
376 | 16, 16, channels, channels, 1, padding=0, norm=norm)
377 | self.net32 = avgcov2d_layer(
378 | 32, 32, channels, channels, 1, padding=0, norm=norm)
379 | self.output = conv2d_layer(channels*5, channels, 3, norm=norm)
380 |
381 | def forward(self, x):
382 | size = (x.shape[2], x.shape[3])
383 | x = torch.cat((
384 | F.interpolate(self.net2(x), size, mode="bilinear",
385 | align_corners=True),
386 | F.interpolate(self.net8(x), size, mode="bilinear",
387 | align_corners=True),
388 | F.interpolate(self.net16(x), size,
389 | mode="bilinear", align_corners=True),
390 | F.interpolate(self.net32(x), size,
391 | mode="bilinear", align_corners=True),
392 | x), dim=1)
393 | x = self.output(x)
394 | return x
395 |
396 |
397 | class SELayer(nn.Module):
398 | def __init__(self, channel, reduction=8):
399 | super(SELayer, self).__init__()
400 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
401 | self.fc = nn.Sequential(
402 | nn.Linear(channel, channel // reduction, bias=True),
403 | nn.ReLU(inplace=True),
404 | nn.Linear(channel // reduction, channel, bias=True),
405 | nn.Sigmoid()
406 | )
407 |
408 | def forward(self, x):
409 | b, c, _, _ = x.size()
410 | y = self.avg_pool(x).view(b, c)
411 | y = self.fc(y).view(b, c, 1, 1)
412 | return x * y
413 |
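414 | if __name__ == "__main__":
415 |     # Minimal shape check, not part of the pipeline (run from the repo root,
416 |     # e.g. `python -m src.network.network_DMTN`): a SemiConvModule convolves
417 |     # half of the channels, keeps the other half as identity, then shuffles
418 |     # channels, so the tensor shape is preserved.
419 |     m = SemiConvModule(channels=64, norm="batch", mid_dilation=2)
420 |     print(m(torch.rand(2, 64, 32, 32)).shape)  # torch.Size([2, 64, 32, 32])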
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import shutil
4 | import sys
5 | import time
6 | from pathlib import Path
7 | from shutil import copyfile
8 |
9 | import cv2
10 | import matplotlib.pyplot as plt
11 | import numpy as np
12 | import scipy.stats as st
13 | import torch
14 | import torch.nn.functional as F
15 | import yaml
16 | from PIL import Image
17 |
18 | def solve_factor(num):
19 | # solve factors for a number
20 | list_factor = []
21 | i = 1
22 | if num > 2:
23 | while i <= num:
24 | i += 1
25 | if num % i == 0:
26 | list_factor.append(i)
27 | else:
28 | pass
29 | else:
30 | pass
31 |
32 | list_factor = list(set(list_factor))
33 | list_factor = np.sort(list_factor)
34 | return list_factor
35 |
36 | def copypth(dest_path, src_path):
37 | if (src_path).is_file():
38 | copyfile(src_path, dest_path)
39 | print(str(src_path)+" copy to "+str(dest_path))
40 |
41 |
42 | def gauss_kernel(kernlen=21, nsig=3, channels=1):
43 | # https://github.com/dojure/FPIE/blob/master/utils.py
44 | interval = (2 * nsig + 1.) / (kernlen)
45 | x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
46 | kern1d = np.diff(st.norm.cdf(x))
47 | kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
48 | kernel = kernel_raw / kernel_raw.sum()
49 | out_filter = np.array(kernel, dtype=np.float32)
50 | out_filter = out_filter.reshape((1, 1, kernlen, kernlen))
51 | out_filter = np.repeat(out_filter, channels, axis=1)
52 | return out_filter
53 |
54 |
55 | def blur(x, kernel_var):
56 |     # note: padding=1 preserves spatial size only for 3x3 kernels
57 |     return F.conv2d(x, kernel_var, padding=1)
57 |
58 |
59 | def sobel_kernel(kernlen=3, channels=1):
60 | out_filter = np.array(
61 | [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=np.float32)
62 | out_filter = out_filter.reshape((1, 1, kernlen, kernlen))
63 | out_filter_x = np.repeat(out_filter, channels, axis=1)
64 |
65 | out_filter = np.array(
66 | [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=np.float32)
67 | out_filter = out_filter.reshape((1, 1, kernlen, kernlen))
68 | out_filter_y = np.repeat(out_filter, channels, axis=1)
69 | return [out_filter_x, out_filter_y]
70 |
71 |
72 | def sobel(x, kernel_var):
73 | sobel_kernel_x = kernel_var[0]
74 | sobel_kernel_y = kernel_var[1]
75 | sobel_x = F.conv2d(x, sobel_kernel_x, padding=1)
76 | sobel_y = F.conv2d(x, sobel_kernel_y, padding=1)
77 |
78 | return sobel_x.abs()+sobel_y.abs()
79 |
80 |
81 | def create_dir(dir):
82 | if not os.path.exists(dir):
83 | os.makedirs(dir)
84 |
85 |
86 | def create_config(config_path, cover=False):
87 | if cover:
88 | copyfile('./config.yml', config_path)
89 | else:
90 | if not os.path.exists(config_path):
91 | copyfile('./config.yml', config_path)
92 |
93 |
94 | def init_config(checkpoints_path, MODE=0, EVAL_INTERVAL_EPOCH=1, EPOCH=[30, 0]):
95 | if EVAL_INTERVAL_EPOCH < 1:
96 | APPEND = 0
97 | else:
98 | APPEND = 1
99 | if len(EPOCH) > 2:
100 | lr_restart = True
101 | else:
102 | lr_restart = False
103 | # Re-training after training is interrupted abnormally
104 | skip_train = restart_train(
105 | checkpoints_path, MODE, APPEND, EPOCH, lr_restart)
106 | if skip_train:
107 | return skip_train
108 |
109 | # edit config
110 | config_path = os.path.join(checkpoints_path, 'config.yml')
111 | fr = open(config_path, 'r')
112 | config = yaml.load(fr, Loader=yaml.FullLoader)
113 | fr.close()
114 |
115 | EPOCHLIST = EPOCH
116 | ALL_EPOCH = EPOCH[-2]
117 | EPOCH = EPOCH[EPOCH[-1]]
118 |
119 | if MODE == 0 or MODE == 1:
120 | flist = config['TRAIN_FLIST_PRE']
121 | else:
122 | flist = config['TRAIN_FLIST']
123 | TRAIN_DATA_NUM = len(np.genfromtxt(
124 | config["DATA_ROOT"]+flist, dtype=np.str, encoding='utf-8'))
125 | print('train data number is:{}'.format(TRAIN_DATA_NUM))
126 |
127 | if torch.cuda.is_available():
128 | BATCH_SIZE = config['BATCH_SIZE']
129 | MAX_ITERS = EPOCH * (TRAIN_DATA_NUM // BATCH_SIZE) # drop_last=True
130 | if config["DEBUG"]:
131 | INTERVAL = 10
132 | config['MAX_ITERS'] = 80
133 | config['EVAL_INTERVAL'] = INTERVAL
134 | config['DEBUG'] = 1
135 | config['SAMPLE_INTERVAL'] = INTERVAL
136 | config['SAVE_INTERVAL'] = INTERVAL
137 | else:
138 | INTERVAL = ((EVAL_INTERVAL_EPOCH * (TRAIN_DATA_NUM // BATCH_SIZE))//10)*10
139 | config['MAX_ITERS'] = MAX_ITERS
140 | config['EVAL_INTERVAL'] = INTERVAL
141 | config['DEBUG'] = 0
142 | config['SAMPLE_INTERVAL'] = 500
143 | config['SAVE_INTERVAL'] = INTERVAL
144 |
145 | if EVAL_INTERVAL_EPOCH < 1:
146 | config['FORCE_EXIT'] = 0
147 | else:
148 | config['FORCE_EXIT'] = 1
149 |
150 | config["MODE"] = MODE
151 | config['ALL_EPOCH'] = ALL_EPOCH
152 | config['EPOCH'] = EPOCH
153 | config['EPOCHLIST'] = EPOCHLIST
154 | config['APPEND'] = APPEND
155 | config['EVAL_INTERVAL_EPOCH'] = EVAL_INTERVAL_EPOCH
156 | save_config(config, config_path)
157 | else:
158 | print("cuda is unavailable")
159 |
160 | return skip_train
161 |
162 |
163 | def restart_train(checkpoints_path, MODE, APPEND, EPOCH, lr_restart):
164 | config_path = os.path.join(checkpoints_path, 'config.yml')
165 | fr = open('./config.yml', 'r')
166 | config = yaml.load(fr, Loader=yaml.FullLoader)
167 | fr.close()
168 |
169 | if MODE == 0:
170 | src_path = Path('./pre_train_model') / \
171 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth')
172 | elif MODE == 1:
173 | src_path = Path('./pre_train_model') / \
174 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_no_da.pth')
175 | elif MODE == 2:
176 | src_path = Path('./pre_train_model') / \
177 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth')
178 |
179 | if lr_restart:
180 | src_path = Path(checkpoints_path)/("model_save_mode_" +
181 | str(MODE))/(str(EPOCH[EPOCH[-1]]-1)+".0.pth")
182 |
183 | log_eval_val_ap_path = Path(checkpoints_path) / \
184 | ('log_eval_val_ap_'+str(MODE)+'.txt')
185 |
186 |
187 |
188 | if Path(src_path).is_file():
189 | skip_train = True
190 | cover = False # retrain
191 | else:
192 | skip_train = False
193 | if APPEND == 1 and log_eval_val_ap_path.is_file():
194 | cover = False
195 | else:
196 | cover = True # new train stage
197 |
198 | # append
199 | if (not Path(src_path).is_file()) and APPEND == 1 and log_eval_val_ap_path.is_file():
200 | eval_val_ap = np.genfromtxt(
201 |             log_eval_val_ap_path, dtype=str, encoding='utf-8').astype(float)
202 | src_path = Path(checkpoints_path)/("model_save_mode_" +
203 | str(MODE))/(str(eval_val_ap[0, 1]) + '.pth')
204 |
205 |         if eval_val_ap[0, 1] == EPOCH[EPOCH[-1]-1]-1:
206 | cover = True
207 |
208 | # copy .pth
209 | dest_path = Path('checkpoints/') / \
210 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'.pth')
211 | copypth(dest_path, src_path)
212 |
213 | # create config
214 | create_config(config_path, cover=cover)
215 | print("cover config file-"+str(cover))
216 |
217 | if skip_train:
218 | print("skip train stage of mode-"+str(MODE))
219 | return skip_train
220 |
221 |
222 | def save_config(config, config_path):
223 | with open(config_path, 'w') as f_obj:
224 | yaml.dump(config, f_obj)
225 |
226 |
227 | def stitch_images(*outputs, img_per_row=2):
228 | inputs_all = [*outputs]
229 | inputs = inputs_all[0]
230 | gap = 5
231 | images = [inputs_all[0], *inputs_all[1], *inputs_all[2:]]
232 |
233 | columns = len(images)
234 |
235 | height, width = inputs[0][:, :, 0].shape
236 | img = Image.new('RGB', (width * img_per_row * columns + gap *
237 | (img_per_row - 1), height * int(len(inputs) / img_per_row)))
238 |
239 | for ix in range(len(inputs)):
240 | xoffset = int(ix % img_per_row) * width * \
241 | columns + int(ix % img_per_row) * gap
242 | yoffset = int(ix / img_per_row) * height
243 |
244 | for cat in range(len(images)):
245 | im = np.array((images[cat][ix]).cpu()).astype(np.uint8).squeeze()
246 | im = Image.fromarray(im)
247 | img.paste(im, (xoffset + cat * width, yoffset))
248 |
249 | return img
250 |
251 |
252 | def imshow(img, title=''):
253 |     fig = plt.gcf()
254 |     # set_window_title moved to the canvas manager in matplotlib >= 3.4
255 |     fig.canvas.manager.set_window_title(title)
256 |     plt.axis('off')
257 |     # img.size is a 2-tuple for PIL images and an int for ndarrays, so use ndim
258 |     arr = np.asarray(img)
259 |     if arr.ndim == 3:
260 |         plt.imshow(arr, interpolation='none')
261 |     else:
262 |         plt.imshow(arr, cmap='Greys_r')
263 |     plt.show()
262 |
263 |
264 | def imsave(img, path):
265 | im = Image.fromarray(img.cpu().numpy().astype(np.uint8).squeeze())
266 | im.save(path)
267 |
268 |
269 | def create_mask(width, height, mask_width, mask_height, x=None, y=None):
270 | # # https://github.com/knazeri/edge-connect
271 | mask = np.zeros((height, width))
272 | mask_x = x if x is not None else random.randint(0, width - mask_width)
273 | mask_y = y if y is not None else random.randint(0, height - mask_height)
274 | mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1
275 | return mask
276 |
277 |
278 | class Progbar(object):
279 | # https://github.com/knazeri/edge-connect
280 | """Displays a progress bar.
281 |
282 | Arguments:
283 | target: Total number of steps expected, None if unknown.
284 | width: Progress bar width on screen.
285 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
286 | stateful_metrics: Iterable of string names of metrics that
287 | should *not* be averaged over time. Metrics in this list
288 | will be displayed as-is. All others will be averaged
289 | by the progbar before display.
290 | interval: Minimum visual progress update interval (in seconds).
291 | """
292 |
293 | def __init__(self, target, width=25, verbose=1, interval=0.05,
294 | stateful_metrics=None):
295 | self.target = target
296 | self.width = width
297 | self.verbose = verbose
298 | self.interval = interval
299 | if stateful_metrics:
300 | self.stateful_metrics = set(stateful_metrics)
301 | else:
302 | self.stateful_metrics = set()
303 |
304 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
305 | sys.stdout.isatty()) or
306 | 'ipykernel' in sys.modules or
307 | 'posix' in sys.modules)
308 | self._total_width = 0
309 | self._seen_so_far = 0
310 | # We use a dict + list to avoid garbage collection
311 | # issues found in OrderedDict
312 | self._values = {}
313 | self._values_order = []
314 | self._start = time.time()
315 | self._last_update = 0
316 |
317 | def update(self, current, values=None):
318 | """Updates the progress bar.
319 |
320 | Arguments:
321 | current: Index of current step.
322 | values: List of tuples:
323 | `(name, value_for_last_step)`.
324 | If `name` is in `stateful_metrics`,
325 | `value_for_last_step` will be displayed as-is.
326 | Else, an average of the metric over time will be displayed.
327 | """
328 | values = values or []
329 | for k, v in values:
330 | if k not in self._values_order:
331 | self._values_order.append(k)
332 | if k not in self.stateful_metrics:
333 | if k not in self._values:
334 | self._values[k] = [v * (current - self._seen_so_far),
335 | current - self._seen_so_far]
336 | else:
337 | self._values[k][0] += v * (current - self._seen_so_far)
338 | self._values[k][1] += (current - self._seen_so_far)
339 | else:
340 | self._values[k] = v
341 | self._seen_so_far = current
342 |
343 | now = time.time()
344 | info = ' - %.0fs' % (now - self._start)
345 | if self.verbose == 1:
346 | if (now - self._last_update < self.interval and
347 | self.target is not None and current < self.target):
348 | return
349 |
350 | prev_total_width = self._total_width
351 | if self._dynamic_display:
352 | sys.stdout.write('\b' * prev_total_width)
353 | sys.stdout.write('\r')
354 | else:
355 | sys.stdout.write('\n')
356 |
357 | if self.target is not None:
358 | numdigits = int(np.floor(np.log10(self.target))) + 1
359 | barstr = '%%%dd/%d [' % (numdigits, self.target)
360 | bar = barstr % current
361 | prog = float(current) / self.target
362 | prog_width = int(self.width * prog)
363 | if prog_width > 0:
364 | bar += ('=' * (prog_width - 1))
365 | if current < self.target:
366 | bar += '>'
367 | else:
368 | bar += '='
369 | bar += ('.' * (self.width - prog_width))
370 | bar += ']'
371 | else:
372 | bar = '%7d/Unknown' % current
373 |
374 | self._total_width = len(bar)
375 | sys.stdout.write(bar)
376 |
377 | if current:
378 | time_per_unit = (now - self._start) / current
379 | else:
380 | time_per_unit = 0
381 | if self.target is not None and current < self.target:
382 | eta = time_per_unit * (self.target - current)
383 | if eta > 3600:
384 | eta_format = '%d:%02d:%02d' % (eta // 3600,
385 | (eta % 3600) // 60,
386 | eta % 60)
387 | elif eta > 60:
388 | eta_format = '%d:%02d' % (eta // 60, eta % 60)
389 | else:
390 | eta_format = '%ds' % eta
391 |
392 | info = ' - ETA: %s' % eta_format
393 | else:
394 | if time_per_unit >= 1:
395 | info += ' %.0fs/step' % time_per_unit
396 | elif time_per_unit >= 1e-3:
397 | info += ' %.0fms/step' % (time_per_unit * 1e3)
398 | else:
399 | info += ' %.0fus/step' % (time_per_unit * 1e6)
400 |
401 | for k in self._values_order:
402 | info += ' - %s:' % k
403 | if isinstance(self._values[k], list):
404 | avg = np.mean(
405 | self._values[k][0] / max(1, self._values[k][1]))
406 | if abs(avg) > 1e-3:
407 | info += ' %.4f' % avg
408 | else:
409 | info += ' %.4e' % avg
410 | else:
411 | info += ' %s' % self._values[k]
412 |
413 | self._total_width += len(info)
414 | if prev_total_width > self._total_width:
415 | info += (' ' * (prev_total_width - self._total_width))
416 |
417 | if self.target is not None and current >= self.target:
418 | info += '\n'
419 |
420 | sys.stdout.write(info)
421 | sys.stdout.flush()
422 |
423 | elif self.verbose == 2:
424 | if self.target is None or current >= self.target:
425 | for k in self._values_order:
426 | info += ' - %s:' % k
427 |                     v = self._values[k]
428 |                     avg = np.mean(v[0] / max(1, v[1])) if isinstance(v, list) else v
429 |                     if abs(avg) > 1e-3:
430 | info += ' %.4f' % avg
431 | else:
432 | info += ' %.4e' % avg
433 | info += '\n'
434 |
435 | sys.stdout.write(info)
436 | sys.stdout.flush()
437 |
438 | self._last_update = now
439 |
440 | def add(self, n, values=None):
441 | self.update(self._seen_so_far + n, values)
442 |
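443 | # Usage sketch (illustrative, not part of the original file): metrics named
444 | # in `stateful_metrics` are displayed as-is, all others are averaged over
445 | # the steps seen so far, mirroring how ModelTop.train() drives this class.
446 | #
447 | # progbar = Progbar(100, width=20, stateful_metrics=['epoch'])
448 | # for step in range(100):
449 | #     progbar.add(1, values=[('epoch', 1), ('loss', 0.25)])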
--------------------------------------------------------------------------------
/src/model_top.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from shutil import copyfile
4 |
5 | import numpy as np
6 | import torch
7 | import torchvision.transforms.functional as F
8 | import yaml
9 | from torch.utils.data import DataLoader
10 |
11 | from .dataset import Dataset
12 | from .metrics import Metrics
13 | from .models import Model
14 | from .utils import (Progbar, create_dir, imsave, imshow, save_config,
15 | stitch_images)
16 |
17 |
18 | class ModelTop():
19 | def __init__(self, config):
20 | # config
21 | self.config = config
22 | if config.DEBUG == 1:
23 | self.debug = True
24 | else:
25 | self.debug = False
26 | self.model_name = config.MODEL_NAME
27 | self.RESULTS_SAMPLE = self.config.RESULTS_SAMPLE
28 | # model
29 | self.model = Model(config).to(config.DEVICE)
30 | # eval
31 | self.metrics = Metrics().to(config.DEVICE)
32 | # dataset
33 | if config.MODE == 3: # test
34 | self.test_dataset = Dataset(
35 | config, config.DATA_ROOT+config.TEST_FLIST, config.DATA_ROOT+config.TEST_MASK_FLIST, config.DATA_ROOT+config.TEST_GT_FLIST, augment=False)
36 | elif config.MODE == 4: # eval
37 | self.val_dataset = Dataset(
38 | config, config.DATA_ROOT+config.VAL_FLIST, config.DATA_ROOT+config.VAL_MASK_FLIST, config.DATA_ROOT+config.VAL_GT_FLIST, augment=False)
39 | elif config.MODE == 5: # eval
40 | self.val_dataset = Dataset(
41 | config, config.DATA_ROOT+config.TEST_FLIST, config.DATA_ROOT+config.TEST_MASK_FLIST, config.DATA_ROOT+config.TEST_GT_FLIST, augment=False)
42 | else:
43 | if config.MODE == 0:
44 | self.train_dataset = Dataset(
45 | config, config.DATA_ROOT+config.TRAIN_FLIST_PRE, config.DATA_ROOT+config.TRAIN_MASK_FLIST_PRE, config.DATA_ROOT+config.TRAIN_GT_FLIST_PRE, augment=True)
46 | self.val_dataset = Dataset(
47 | config, config.DATA_ROOT+config.VAL_FLIST_PRE, config.DATA_ROOT+config.VAL_MASK_FLIST_PRE, config.DATA_ROOT+config.VAL_GT_FLIST_PRE, augment=True)
48 | elif config.MODE == 1:
49 | self.train_dataset = Dataset(
50 | config, config.DATA_ROOT+config.TRAIN_FLIST_PRE, config.DATA_ROOT+config.TRAIN_MASK_FLIST_PRE, config.DATA_ROOT+config.TRAIN_GT_FLIST_PRE, augment=False)
51 | self.val_dataset = Dataset(
52 | config, config.DATA_ROOT+config.VAL_FLIST_PRE, config.DATA_ROOT+config.VAL_MASK_FLIST_PRE, config.DATA_ROOT+config.VAL_GT_FLIST_PRE, augment=False)
53 | elif config.MODE == 2:
54 | self.train_dataset = Dataset(
55 | config, config.DATA_ROOT+config.TRAIN_FLIST, config.DATA_ROOT+config.TRAIN_MASK_FLIST, config.DATA_ROOT+config.TRAIN_GT_FLIST, augment=False)
56 | self.val_dataset = Dataset(
57 | config, config.DATA_ROOT+config.VAL_FLIST, config.DATA_ROOT+config.VAL_MASK_FLIST, config.DATA_ROOT+config.VAL_GT_FLIST, augment=False)
58 | self.sample_iterator = self.val_dataset.create_iterator(
59 | config.SAMPLE_SIZE)
60 |
61 | # path
62 | self.samples_path = os.path.join(config.PATH, 'samples')
63 | self.results_path = os.path.join(config.PATH, 'results')
64 | self.backups_path = os.path.join(config.PATH, 'backups')
65 | self.results_samples_path = os.path.join(self.results_path, 'samples')
66 | if self.config.BACKUP:
67 | create_dir(self.backups_path)
68 | if config.RESULTS is not None:
69 |             self.results_path = config.RESULTS
70 | create_dir("./pre_train_model/"+self.config.SUBJECT_WORD)
71 | # load file
72 | self.log_file = os.path.join(
73 | config.PATH, 'log_' + self.model_name + '.dat')
74 |
75 |         # early-stopping bookkeeping: track validation RMSE across evaluations
76 | if config.MODE < 3:
77 | data_save_path = os.path.join(
78 | self.config.PATH, 'log_eval_val_ap_id.txt')
79 | if os.path.exists(data_save_path):
80 |                 self.eval_val_ap_id = np.genfromtxt(
81 |                     data_save_path, dtype=str, encoding='utf-8').astype(float)
82 | if config.APPEND == 0:
83 | self.eval_val_ap_id[config.MODE] = 0
84 | else:
85 | self.eval_val_ap_id = [0.0, 0.0, 0.0]
86 |
87 | data_save_path = os.path.join(
88 | self.config.PATH, 'final_model_epoch.txt')
89 | if os.path.exists(data_save_path):
90 |                 self.epoch = np.genfromtxt(
91 |                     data_save_path, dtype=str, encoding='utf-8').astype(float).astype(int)
92 | if config.APPEND == 0:
93 | self.epoch[config.MODE, 0] = 0
94 | else:
95 | self.epoch = np.zeros((3, 2)).astype(int)
96 |
97 | data_save_path = os.path.join(
98 | self.config.PATH, 'log_eval_val_ap_'+str(self.config.MODE)+'.txt')
99 | if os.path.exists(data_save_path):
100 |                 self.eval_val_ap = np.genfromtxt(
101 |                     data_save_path, dtype=str, encoding='utf-8').astype(float)
102 | if config.APPEND == 0:
103 | if config.FORCE_EXIT == 0:
104 | if config.MODE == 0 or config.MODE == 1:
105 | self.eval_val_ap = np.ones(
106 | (config.PRE_TRAIN_EVAL_LEN, 2))*1e6
107 | elif config.MODE == 2:
108 | self.eval_val_ap = np.ones(
109 | (config.TRAIN_EVAL_LEN, 2))*1e6
110 | else:
111 | self.eval_val_ap = np.ones((config.ALL_EPOCH, 2))*1e6
112 | else:
113 | if config.FORCE_EXIT == 0:
114 | if config.MODE == 0 or config.MODE == 1:
115 | self.eval_val_ap = np.ones(
116 | (config.PRE_TRAIN_EVAL_LEN, 2))*1e6
117 | elif config.MODE == 2:
118 | self.eval_val_ap = np.ones(
119 | (config.TRAIN_EVAL_LEN, 2))*1e6
120 | else:
121 | self.eval_val_ap = np.ones((config.ALL_EPOCH, 2))*1e6
122 |
123 | # lr scheduler
124 | self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
125 | self.model.network_instance.optimizer, 'min', factor=0.5, patience=0, min_lr=1e-5)
126 | if config.ADV:
127 | self.scheduler_dis = [torch.optim.lr_scheduler.ReduceLROnPlateau(
128 | discriminator.optimizer, 'min', factor=0.5, patience=0, min_lr=5e-6) for discriminator in self.model.network_instance.discriminator]
129 |
130 | def load(self):
131 | self.model.load()
132 |
133 | def save(self):
134 | self.model.save()
135 |
136 | def train(self):
137 | # initial
138 | self.restart_train_check_lr_scheduler()
139 | train_loader = DataLoader(
140 | dataset=self.train_dataset,
141 | batch_size=self.config.BATCH_SIZE,
142 | num_workers=4,
143 | drop_last=True,
144 | shuffle=True
145 | )
146 | keep_training = True
147 | mode = self.config.MODE
148 |         max_iteration = int(float(self.config.MAX_ITERS))  # kept for reference; the exit check below is epoch-based
149 | max_epoch = self.config.EPOCHLIST[self.config.EPOCHLIST[-1]]+1
150 | total = len(self.train_dataset)
151 | self.TRAIN_DATA_NUM = total
152 | if total == 0:
153 | print(
154 | 'No training data was provided! Check \'TRAIN_FLIST\' value in the configuration file.')
155 | return
156 | # train
157 |         while keep_training:
158 | # epoch
159 | epoch = self.epoch[self.config.MODE, 0]
160 | epoch += 1
161 | self.epoch[self.config.MODE, 0] = epoch
162 | print('\n\nTraining epoch: %d' % epoch)
163 | # progbar
164 | progbar = Progbar(
165 | total, width=20, stateful_metrics=['epoch', 'iter'])
166 |
167 | for items in train_loader:
168 | # initial
169 | self.model.train()
170 |
171 | # get data
172 | images, mask, GT = self.cuda(
173 | *items)
174 | # imshow(F.to_pil_image((outputs)[0,:,:,:].cpu()))
175 | # train
176 | outputs, loss, logs = self.model.process(
177 | images, mask, GT)
178 | # backward
179 | self.model.backward(loss)
180 | iteration = self.model.iteration
181 |
182 | # log-epoch, iteration
183 | logs = [
184 | ("epoch", epoch),
185 | ("iter", iteration),
186 | ] + logs
187 | # progbar
188 | progbar.add(len(images), values=logs if self.config.VERBOSE else [
189 | x for x in logs if not x[0].startswith('l_')])
190 |
191 | # log model at checkpoints
192 | if self.config.LOG_INTERVAL and iteration % self.config.LOG_INTERVAL == 0:
193 | self.log(logs)
194 | # sample model at checkpoints
195 | if self.config.SAMPLE_INTERVAL and iteration % self.config.SAMPLE_INTERVAL == 0:
196 | self.sample()
197 | # save model at checkpoints
198 | if self.config.SAVE_INTERVAL and iteration % self.config.SAVE_INTERVAL == 0:
199 | self.save()
200 | # force_exit
201 | if epoch >= max_epoch: # if iteration >= max_iteration:
202 | force_exit = True
203 | else:
204 | force_exit = False
205 | # end condition
206 | if force_exit:
207 | keep_training = False
208 | self.force_exit()
209 | print('\n ***force_exit: max iteration***')
210 | break
211 | # evaluate model at checkpoints
212 | if (self.config.EVAL_INTERVAL and iteration % self.config.EVAL_INTERVAL == 0):
213 | print('\nstart eval...\n')
214 | pre_exit, ap = self.eval()
215 | self.scheduler.step(ap)
216 | if self.config.ADV:
217 | for scheduler_dis in self.scheduler_dis:
218 | scheduler_dis.step(ap)
219 |
220 | with open(self.config.CONFIG_PATH, 'r') as f_obj:
221 | config = yaml.load(f_obj, Loader=yaml.FullLoader)
222 | config['LR'] = self.scheduler.optimizer.param_groups[0]['lr']
223 |                         if self.config.ADV: config['LR_D'] = self.scheduler_dis[0].optimizer.param_groups[0]['lr']
224 | save_config(config, self.config.CONFIG_PATH)
225 | else:
226 | pre_exit = False
227 | # debug
228 | if self.debug:
229 | if iteration >= 40:
230 | if self.config.MODE == 0:
231 | force_exit = True
232 | copyfile(Path('checkpoints')/self.config.SUBJECT_WORD/(self.config.MODEL_NAME+'.pth'),
233 | Path('./pre_train_model')/self.config.SUBJECT_WORD/(self.config.MODEL_NAME+'_pre_da.pth'))
234 | if self.config.MODE == 1:
235 | force_exit = True
236 | copyfile(Path('checkpoints')/self.config.SUBJECT_WORD/(self.config.MODEL_NAME+'.pth'),
237 | Path('./pre_train_model')/self.config.SUBJECT_WORD/(self.config.MODEL_NAME+'_pre_no_da.pth'))
238 | # end condition
239 | if pre_exit:
240 | keep_training = False
241 | break
242 |
243 | print('\nEnd training....')
244 |
245 | def eval(self):
246 | # torch.cuda.empty_cache()
247 | if self.config.MODE == 4 or self.config.MODE == 5:
248 | BATCH_SIZE = self.config.BATCH_SIZE # *8
249 | num_workers = 4 # 8
250 | else:
251 | BATCH_SIZE = self.config.BATCH_SIZE
252 | num_workers = 4
253 | val_loader = DataLoader(
254 | dataset=self.val_dataset,
255 | batch_size=BATCH_SIZE,
256 | num_workers=num_workers,
257 | drop_last=False,
258 | shuffle=False
259 | )
260 | total = len(self.val_dataset)
261 | self.metrics.multiprocessingi_utils.creat_pool()
262 |
263 | self.model.eval()
264 | progbar = Progbar(total, width=20, stateful_metrics=['it'])
265 | iteration = 0
266 | log_eval_PR = [[0], [0]]
267 | n_thresh = 99
268 |         # empty accumulators; per-image statistics are concatenated below
269 |         rmse_shadow = torch.empty(0).cuda()
270 |         n_pxl_shadow = torch.empty(0).cuda()
271 |         rmse_non_shadow = torch.empty(0).cuda()
272 |         n_pxl_non_shadow = torch.empty(0).cuda()
273 |         rmse_all = torch.empty(0).cuda()
274 |         n_pxl_all = torch.empty(0).cuda()
275 | # eval each image
276 | with torch.no_grad():
277 | for items in val_loader:
278 | iteration += 1
279 | images, mask, GT = self.cuda(
280 | *items)
281 | # eval
282 | outputs = self.model(images)
283 |
284 | rmse_shadow_, n_pxl_shadow_, rmse_non_shadow_, n_pxl_non_shadow_, rmse_all_, n_pxl_all_ = self.metrics.rmse(
285 | outputs[-1], mask, GT, dataset_mode=1)
286 |
287 | rmse_shadow = torch.cat((rmse_shadow, rmse_shadow_), dim=0)
288 | n_pxl_shadow = torch.cat((n_pxl_shadow, n_pxl_shadow_), dim=0)
289 | rmse_non_shadow = torch.cat(
290 | (rmse_non_shadow, rmse_non_shadow_), dim=0)
291 | n_pxl_non_shadow = torch.cat(
292 | (n_pxl_non_shadow, n_pxl_non_shadow_), dim=0)
293 | rmse_all = torch.cat((rmse_all, rmse_all_), dim=0)
294 | n_pxl_all = torch.cat((n_pxl_all, n_pxl_all_), dim=0)
295 |
296 | if self.debug:
297 | if iteration == 10:
298 | break
299 |
300 | rmse_shadow_eval, rmse_non_shadow_eval, rmse_all_eval = self.metrics.collect_rmse(
301 | rmse_shadow, n_pxl_shadow, rmse_non_shadow, n_pxl_non_shadow, rmse_all, n_pxl_all)
302 |
303 | # print
304 | print('running rmse-shadow: %.4f, rmse-non-shadow: %.4f, rmse-all: %.4f'
305 | % (rmse_shadow_eval, rmse_non_shadow_eval, rmse_all_eval))
306 | data_save_path = os.path.join(
307 | self.config.PATH, 'show_result.txt')
308 | np.savetxt(data_save_path, [
309 | rmse_shadow_eval, rmse_non_shadow_eval, rmse_all_eval], fmt='%s')
310 |
311 |         # early stopping (pre_train and train): exit when validation RMSE stops improving
312 | if self.config.MODE == 0 or self.config.MODE == 1 or self.config.MODE == 2:
313 | if self.config.FORCE_EXIT == 0:
314 | if rmse_all_eval < np.mean(self.eval_val_ap[:, 0]):
315 | exit_ = False
316 | else:
317 | exit_ = True
318 | else:
319 | exit_ = False
320 |
321 | self.eval_val_ap = np.delete(self.eval_val_ap, -1, axis=0)
322 | self.eval_val_ap = np.append(
323 | [[rmse_all_eval, self.eval_val_ap_id[self.config.MODE]]], self.eval_val_ap, axis=0)
324 |
325 | model_save_path = "model_save_mode_"+str(self.config.MODE)
326 | model_save_path = Path(self.config.PATH)/model_save_path
327 | create_dir(model_save_path)
328 | self.model.weights_path = os.path.join(
329 | model_save_path, str(self.eval_val_ap_id[self.config.MODE]) + '.pth')
330 | self.save()
331 | self.model.weights_path = os.path.join(
332 | self.config.PATH, self.model_name + '.pth')
333 |
334 | data_save_path = os.path.join(
335 | self.config.PATH, 'final_model_epoch.txt')
336 | np.savetxt(data_save_path, self.epoch, fmt='%s')
337 | if self.eval_val_ap_id[self.config.MODE] == (len(self.eval_val_ap)-1):
338 | self.eval_val_ap_id[self.config.MODE] = 0.0
339 | else:
340 | self.eval_val_ap_id[self.config.MODE] += 1.0
341 |
342 | data_save_path = os.path.join(
343 | self.config.PATH, 'log_eval_val_ap_'+str(self.config.MODE)+'.txt')
344 | np.savetxt(data_save_path, self.eval_val_ap, fmt='%s')
345 | data_save_path = os.path.join(
346 | self.config.PATH, 'log_eval_val_ap_id.txt')
347 | np.savetxt(data_save_path, self.eval_val_ap_id, fmt='%s')
348 |
349 | if exit_:
350 | idmin = np.array(self.eval_val_ap[:, 0]).argmin()
351 |
352 | data_save_path = os.path.join(
353 | self.config.PATH, 'final_model_epoch.txt')
354 | if self.config.EVAL_INTERVAL_EPOCH < 1:
355 | self.epoch[self.config.MODE,
356 | 1] = self.eval_val_ap[idmin, 1]
357 | else:
358 | self.epoch[self.config.MODE,
359 | 1] = self.epoch[self.config.MODE, 0]-idmin-1
360 | print('final model id:'+str(self.epoch[self.config.MODE, 1]))
361 | np.savetxt(data_save_path, self.epoch, fmt='%s')
362 |
363 | pre_train_save_path = "./pre_train_model/"+self.config.SUBJECT_WORD
364 | if os.path.exists(os.path.join(model_save_path, str(self.eval_val_ap[idmin, 1]) + '.pth')):
365 | if self.config.MODE == 0:
366 | PATH_WEIDHT = os.path.join(
367 | pre_train_save_path, self.model_name + '_pre_da.pth')
368 | elif self.config.MODE == 1:
369 | PATH_WEIDHT = os.path.join(
370 | pre_train_save_path, self.model_name + '_pre_no_da.pth')
371 | elif self.config.MODE == 2:
372 | PATH_WEIDHT = os.path.join(
373 | pre_train_save_path, self.model_name + '_final.pth')
374 | copyfile(os.path.join(model_save_path, str(
375 | self.eval_val_ap[idmin, 1]) + '.pth'), PATH_WEIDHT)
376 | print(os.path.join(model_save_path, str(
377 | self.eval_val_ap[idmin, 1]) + '.pth')+" copy to "+PATH_WEIDHT)
378 | pre_model_save = torch.load(PATH_WEIDHT)
379 | torch.save(
380 | {'iteration': 0, 'model': pre_model_save['model']}, PATH_WEIDHT)
381 |
382 | copyfile(PATH_WEIDHT, os.path.join(
383 | self.config.PATH, self.model_name + '.pth'))
384 |         return (exit_ if self.config.MODE <= 2 else False), rmse_all_eval  # exit_ is only set in the training modes
385 |
386 | def force_exit(self):
387 | model_save_path = "model_save_mode_"+str(self.config.MODE)
388 | model_save_path = Path(self.config.PATH)/model_save_path
389 |
390 | idmin = np.array(self.eval_val_ap[:, 0]).argmin()
391 |
392 | data_save_path = os.path.join(
393 | self.config.PATH, 'final_model_epoch.txt')
394 |
395 |         self.epoch[self.config.MODE, 0] -= 1
396 | if self.config.EVAL_INTERVAL_EPOCH < 1:
397 | self.epoch[self.config.MODE,
398 | 1] = self.eval_val_ap[idmin, 1]
399 | else:
400 | self.epoch[self.config.MODE,
401 | 1] = self.epoch[self.config.MODE, 0]-idmin-1
402 | print('\n\nfinal model id:'+str(self.epoch[self.config.MODE, 1]))
403 | np.savetxt(data_save_path, self.epoch, fmt='%s')
404 |
405 | pre_train_save_path = "./pre_train_model/"+self.config.SUBJECT_WORD
406 | if os.path.exists(os.path.join(model_save_path, str(self.eval_val_ap[idmin, 1]) + '.pth')):
407 | if self.config.MODE == 0:
408 | PATH_WEIDHT = os.path.join(
409 | pre_train_save_path, self.model_name + '_pre_da.pth')
410 | elif self.config.MODE == 1:
411 | PATH_WEIDHT = os.path.join(
412 | pre_train_save_path, self.model_name + '_pre_no_da.pth')
413 | elif self.config.MODE == 2:
414 | PATH_WEIDHT = os.path.join(
415 | pre_train_save_path, self.model_name + '_final.pth')
416 | copyfile(os.path.join(model_save_path, str(
417 | self.eval_val_ap[idmin, 1]) + '.pth'), PATH_WEIDHT)
418 | print(os.path.join(model_save_path, str(
419 | self.eval_val_ap[idmin, 1]) + '.pth')+" copy to "+PATH_WEIDHT)
420 | pre_model_save = torch.load(PATH_WEIDHT)
421 | torch.save(
422 | {'iteration': 0, 'model': pre_model_save['model']}, PATH_WEIDHT)
423 |
424 | copyfile(PATH_WEIDHT, os.path.join(
425 | self.config.PATH, self.model_name + '.pth'))
426 |
427 | def test(self):
428 | # initial
429 | self.model.eval()
430 | if self.RESULTS_SAMPLE:
431 | save_path = os.path.join(
432 | self.results_samples_path, self.model_name)
433 | create_dir(save_path)
434 | else:
435 | save_path = os.path.join(self.results_path, self.model_name)
436 | create_dir(save_path)
437 | if self.debug:
438 | debug_path = os.path.join(save_path, "debug")
439 | create_dir(debug_path)
440 | save_path = debug_path
441 | test_loader = DataLoader(
442 | dataset=self.test_dataset,
443 | batch_size=1,
444 | )
445 | # test
446 | index = 0
447 | with torch.no_grad():
448 | for items in test_loader:
449 | name = self.test_dataset.load_name(index)
450 | index += 1
451 | images, mask, GT = self.cuda(
452 | *items)
453 | if self.RESULTS_SAMPLE:
454 | image_per_row = 2
455 | if self.config.SAMPLE_SIZE <= 6:
456 | image_per_row = 1
457 | outputs = self.model(
458 | images)
459 |                     # postprocess each output head in place
460 |                     for i, output in enumerate(outputs):
461 |                         outputs[i] = self.postprocess(output)
462 |
463 | matte_gt = GT-images
464 | matte_gt = matte_gt - \
465 | (matte_gt.min(dim=2, keepdim=True).values).min(
466 | dim=3, keepdim=True).values
467 | matte_gt = matte_gt / \
468 | (matte_gt.max(dim=2, keepdim=True).values).max(
469 | dim=3, keepdim=True).values
470 | images = stitch_images(
471 | self.postprocess(images),
472 | outputs,
473 | self.postprocess(matte_gt),
474 | self.postprocess(GT),
475 | img_per_row=image_per_row,
476 | )
477 |                     images.save(os.path.join(save_path, name))
478 | else:
479 | outputs = self.model(images)
480 | outputs = self.postprocess(outputs[-1])[0]
481 | path = os.path.join(save_path, name)
482 | imsave(outputs, path)
483 | # debug
484 | if self.debug:
485 | if index == 10:
486 | break
487 | print('\nEnd test....')
488 |
489 | def sample(self, it=None):
490 | # initial, do not sample when validation set is empty
491 | if len(self.val_dataset) == 0:
492 | return
493 | # torch.cuda.empty_cache()
494 | self.model.eval()
495 | iteration = self.model.iteration
496 |
497 | items = next(self.sample_iterator)
498 | images, mask, GT = self.cuda(
499 | *items)
500 | image_per_row = 2
501 | if self.config.SAMPLE_SIZE <= 6:
502 | image_per_row = 1
503 | outputs = self.model(
504 | images)
505 |         # postprocess each output head in place
506 |         for i, output in enumerate(outputs):
507 |             outputs[i] = self.postprocess(output)
508 |
509 | matte_gt = GT-images
510 | matte_gt = matte_gt - \
511 | (matte_gt.min(dim=2, keepdim=True).values).min(
512 | dim=3, keepdim=True).values
513 | matte_gt = matte_gt / \
514 | (matte_gt.max(dim=2, keepdim=True).values).max(
515 | dim=3, keepdim=True).values
516 | images = stitch_images(
517 | self.postprocess(images),
518 | outputs,
519 | self.postprocess(matte_gt),
520 | self.postprocess(GT),
521 | img_per_row=image_per_row,
522 | )
523 |
524 | path = os.path.join(self.samples_path, self.model_name)
525 | name = os.path.join(path, "mode_"+str(self.config.MODE) +
526 | "_"+str(iteration).zfill(5) + ".png")
527 | create_dir(path)
528 | print('\nsaving sample ' + name)
529 | images.save(name)
530 |
531 | def log(self, logs):
532 | with open(self.log_file, 'a') as f:
533 | f.write('%s\n' % ' '.join([str(item[1]) for item in logs]))
534 |
535 | def cuda(self, *args):
536 | return (item.to(self.config.DEVICE) for item in args)
537 |
538 | def postprocess(self, img):
539 | # [0, 1] => [0, 255]
540 | img = img * 255.0
541 | img = img.permute(0, 2, 3, 1)
542 | return img.int()
543 |
544 | def restart_train_check_lr_scheduler(self):
545 | checkpoints_path = Path('./checkpoints') / self.config.SUBJECT_WORD
546 | log_eval_val_ap_path = Path(checkpoints_path) / \
547 | ('log_eval_val_ap_'+str(self.config.MODE)+'.txt')
548 |
549 | if log_eval_val_ap_path.is_file():
550 |             eval_val_ap = np.genfromtxt(
551 |                 log_eval_val_ap_path, dtype=str, encoding='utf-8').astype(float)
552 | EPOCH = self.config.EPOCHLIST
553 |
554 |             if EPOCH[-1] != 0 and eval_val_ap[0, 1] != EPOCH[EPOCH[-1]-1] - 1:
555 |                 ap = float(eval_val_ap[0, 0])
556 | self.scheduler.step(ap)
557 | if self.config.ADV:
558 | for scheduler_dis in self.scheduler_dis:
559 | scheduler_dis.step(ap)
560 |
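561 | # Minimal driver sketch (illustrative; `load_config` is a hypothetical helper,
562 | # the real entry point is src/main.py):
563 | #
564 | # config = load_config('config/config_ISTD.yml')
565 | # model_top = ModelTop(config)  # MODE: 0/1 pre-train, 2 train, 3 test, 4/5 eval
566 | # model_top.load()
567 | # if config.MODE < 3:
568 | #     model_top.train()  # calls eval() every EVAL_INTERVAL and steps the LR scheduler
569 | # elif config.MODE == 3:
570 | #     model_top.test()
571 | # else:
572 | #     model_top.eval()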
--------------------------------------------------------------------------------