├── img ├── fig1.jpg ├── fig2.jpg └── fig3.jpg ├── evaluation └── evaluate_MAE_PSNR_SSIM.m ├── show_eval_result.py ├── script ├── flist.py ├── flist_train_val_test.py ├── process_syn_data.py ├── generate_flist_srd.py └── generate_flist_istd.py ├── debug_ready.py ├── src ├── config.py ├── main.py ├── image_pool.py ├── models.py ├── metrics.py ├── dataset.py ├── loss.py ├── network │   ├── networks.py │   └── network_DMTN.py ├── utils.py └── model_top.py ├── deploy_ready.py ├── config ├── config_SRD_da.yml ├── config_ISTD_da.yml ├── config_SRD.yml ├── config_ISTD.yml ├── run_ISTD.py ├── run_SRD.py ├── run_SRD_da.py └── run_ISTD_da.py ├── install.yaml └── README.md /img/fig1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nachifur/DMTN/HEAD/img/fig1.jpg -------------------------------------------------------------------------------- /img/fig2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nachifur/DMTN/HEAD/img/fig2.jpg -------------------------------------------------------------------------------- /img/fig3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nachifur/DMTN/HEAD/img/fig3.jpg -------------------------------------------------------------------------------- /evaluation/evaluate_MAE_PSNR_SSIM.m: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nachifur/DMTN/HEAD/evaluation/evaluate_MAE_PSNR_SSIM.m -------------------------------------------------------------------------------- /show_eval_result.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | rmse = np.genfromtxt( 3 | 'checkpoints/ISTD/show_result.txt', dtype=str, encoding='utf-8').astype(float) 4 | 5 | print('running rmse-shadow: %.4f, rmse-non-shadow: %.4f, rmse-all: %.4f' 6 | % (rmse[0], rmse[1], rmse[2])) 7 | -------------------------------------------------------------------------------- /script/flist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from pathlib import Path 4 | 5 | def gen_flist(data_path): 6 | ext = {'.JPG', '.JPEG', '.PNG', '.TIF', '.TIFF', '.JSON'} 7 | images = [] 8 | for root, dirs, files in os.walk(data_path): 9 | print('loading ' + root) 10 | for file in files: 11 | if os.path.splitext(file)[1].upper() in ext: 12 | images.append(os.path.join(root, file)) 13 | 14 | images = sorted(images) 15 | return images 16 | -------------------------------------------------------------------------------- /debug_ready.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | 4 | # # clear 5 | # flag = os.system("rm -rf checkpoints") 6 | # if flag == 0: 7 | # print("clear checkpoints success") 8 | # flag = os.system("rm -rf src/__pycache__") 9 | # if flag == 0: 10 | # print("clear src/__pycache__ success") 11 | 12 | fr = open("config.yml", 'r') 13 | config = yaml.load(fr, Loader=yaml.FullLoader) 14 | config["DEBUG"] = 1 15 | config["GPU"] = [0] 16 | with open("config.yml", 'w') as f_obj: 17 | yaml.dump(config, f_obj) 18 | print("in debug mode") -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import 
yaml 4 | 5 | 6 | class Config(dict): 7 | def __init__(self, config_path): 8 | with open(config_path, 'r') as f: 9 | self._yaml = f.read() 10 | self._dict = yaml.load(self._yaml, Loader=yaml.FullLoader) 11 | self._dict['PATH'] = os.path.dirname(config_path) 12 | 13 | def __getattr__(self, name): 14 | if self._dict.get(name) is not None: 15 | return self._dict[name] 16 | 17 | return None 18 | 19 | def print(self): 20 | print('Model configurations:') 21 | print('---------------------------------') 22 | print(self._yaml) 23 | print('') 24 | print('---------------------------------') 25 | print('') 26 | -------------------------------------------------------------------------------- /deploy_ready.py: -------------------------------------------------------------------------------- 1 | import os 2 | import yaml 3 | import argparse 4 | from pathlib import Path 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--gpu', help="gpu") 8 | ARGS = parser.parse_args() 9 | gpu = ARGS.gpu 10 | 11 | 12 | 13 | fr = open("config.yml", 'r') 14 | config = yaml.load(fr, Loader=yaml.FullLoader) 15 | config["DEBUG"] = 0 16 | config["GPU"] = [gpu] 17 | with open("config.yml", 'w') as f_obj: 18 | yaml.dump(config, f_obj) 19 | print("in deploy mode") 20 | 21 | # clear 22 | # model_name = config["MODEL_NAME"] + '.pth' 23 | # checkpoints_path = str(Path('./checkpoints') / \ 24 | # config["SUBJECT_WORD"]/model_name) 25 | # flag = os.system("rm "+checkpoints_path) 26 | # if flag == 0: 27 | # print("clear "+checkpoints_path+" success") 28 | flag = os.system("rm -rf src/__pycache__") 29 | if flag == 0: 30 | print("clear src/__pycache__ success") 31 | 32 | -------------------------------------------------------------------------------- /config/config_SRD_da.yml: -------------------------------------------------------------------------------- 1 | ADV: 1 2 | BACKUP: 0 3 | BATCH_SIZE: 1 4 | BETA1: 0.9 5 | BETA2: 0.999 6 | DATA_ROOT: /home/xxx/data-set/shadow_removal_dataset/SRD_Dataset_arg 7 | DEBUG: 0 8 | DIS_NORM: spectral 9 | FORCE_EXIT: 1 10 | GAN_LOSS: nsgan 11 | GAN_NORM: batch_ 12 | GPU: 13 | - '0' 14 | INIT_TYPE: identity 15 | INPUT_SIZE_H: 640 16 | INPUT_SIZE_W: 840 17 | LOG_INTERVAL: 10 18 | LOSS: L1Loss 19 | LR: 0.0001 20 | LR_D: 5.0e-05 21 | MIDDLE_RES_NUM: 4 22 | MODEL_NAME: ShadowRemoval 23 | NETWORK: DMTN 24 | POOL_SIZE: 50 25 | PRE_TRAIN_EVAL_LEN: 10 26 | RESULTS: ./checkpoints/SRD/results 27 | RESULTS_SAMPLE: 0 28 | SAMPLE_SIZE: 1 29 | SEED: 10 30 | SPLIT: 1 31 | SUBJECT_WORD: SRD 32 | TRAIN_EVAL_LEN: 10 33 | VERBOSE: 1 34 | 35 | TEST_FLIST: /data_val/SRD_shadow_test.flist 36 | TEST_GT_FLIST: /data_val/SRD_shadow_free_test.flist 37 | TEST_MASK_FLIST: /data_val/SRD_mask_test.flist 38 | 39 | TRAIN_FLIST: /data_val/SRD_shadow_train_train.flist 40 | TRAIN_GT_FLIST: /data_val/SRD_shadow_free_train_train.flist 41 | TRAIN_MASK_FLIST: /data_val/SRD_mask_train_train.flist 42 | 43 | TRAIN_FLIST_PRE: /data_val/SRD_shadow_pre_train.flist 44 | TRAIN_GT_FLIST_PRE: /data_val/SRD_shadow_free_pre_train.flist 45 | TRAIN_MASK_FLIST_PRE: /data_val/SRD_mask_pre_train.flist 46 | 47 | VAL_FLIST: /data_val/SRD_shadow_train_val.flist 48 | VAL_GT_FLIST: /data_val/SRD_shadow_free_train_val.flist 49 | VAL_MASK_FLIST: /data_val/SRD_mask_train_val.flist 50 | 51 | VAL_FLIST_PRE: /data_val/SRD_shadow_pre_val.flist 52 | VAL_GT_FLIST_PRE: /data_val/SRD_shadow_free_pre_val.flist 53 | VAL_MASK_FLIST_PRE: /data_val/SRD_mask_pre_val.flist -------------------------------------------------------------------------------- 
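The YAML files under config/ are consumed through the Config wrapper from src/config.py above: every top-level key becomes an attribute, and absent keys silently resolve to None. A minimal usage sketch (the relative path is an assumption; any of the four dataset configs behaves the same way):

from src.config import Config

config = Config("config/config_SRD_da.yml")
print(config.NETWORK)       # "DMTN"
print(config.INPUT_SIZE_H)  # 640 for SRD; the ISTD configs use 480
print(config.PATH)          # directory of the YAML file, injected by Config.__init__
print(config.NOT_A_KEY)     # None -- Config.__getattr__ swallows missing keys

Because missing keys come back as None instead of raising, a typo in a key name fails silently; keep that in mind when editing these files.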
/config/config_ISTD_da.yml: -------------------------------------------------------------------------------- 1 | ADV: 1 2 | BACKUP: 0 3 | BATCH_SIZE: 1 4 | BETA1: 0.9 5 | BETA2: 0.999 6 | DATA_ROOT: /home/xxx/data-set/shadow_removal_dataset/ISTD_Dataset_arg 7 | DEBUG: 0 8 | DIS_NORM: spectral 9 | FORCE_EXIT: 1 10 | GAN_LOSS: nsgan 11 | GAN_NORM: batch_ 12 | GPU: 13 | - '0' 14 | INIT_TYPE: identity 15 | INPUT_SIZE_H: 480 16 | INPUT_SIZE_W: 640 17 | LOG_INTERVAL: 10 18 | LOSS: L1Loss 19 | LR: 0.0001 20 | LR_D: 5.0e-05 21 | MIDDLE_RES_NUM: 4 22 | MODEL_NAME: ShadowRemoval 23 | NETWORK: DMTN 24 | POOL_SIZE: 50 25 | PRE_TRAIN_EVAL_LEN: 10 26 | RESULTS: ./checkpoints/ISTD/results 27 | RESULTS_SAMPLE: 0 28 | SAMPLE_SIZE: 1 29 | SEED: 10 30 | SPLIT: 1 31 | SUBJECT_WORD: ISTD 32 | TRAIN_EVAL_LEN: 10 33 | VERBOSE: 1 34 | 35 | TEST_FLIST: /data_val/ISTD_shadow_test.flist 36 | TEST_GT_FLIST: /data_val/ISTD_shadow_free_test.flist 37 | TEST_MASK_FLIST: /data_val/ISTD_mask_test.flist 38 | 39 | TRAIN_FLIST: /data_val/ISTD_shadow_train_train.flist 40 | TRAIN_GT_FLIST: /data_val/ISTD_shadow_free_train_train.flist 41 | TRAIN_MASK_FLIST: /data_val/ISTD_mask_train_train.flist 42 | 43 | TRAIN_FLIST_PRE: /data_val/ISTD_shadow_pre_train.flist 44 | TRAIN_GT_FLIST_PRE: /data_val/ISTD_shadow_free_pre_train.flist 45 | TRAIN_MASK_FLIST_PRE: /data_val/ISTD_mask_pre_train.flist 46 | 47 | VAL_FLIST: /data_val/ISTD_shadow_train_val.flist 48 | VAL_GT_FLIST: /data_val/ISTD_shadow_free_train_val.flist 49 | VAL_MASK_FLIST: /data_val/ISTD_mask_train_val.flist 50 | 51 | VAL_FLIST_PRE: /data_val/ISTD_shadow_pre_val.flist 52 | VAL_GT_FLIST_PRE: /data_val/ISTD_shadow_free_pre_val.flist 53 | VAL_MASK_FLIST_PRE: /data_val/ISTD_mask_pre_val.flist -------------------------------------------------------------------------------- /config/config_SRD.yml: -------------------------------------------------------------------------------- 1 | ADV: 1 2 | BACKUP: 0 3 | BATCH_SIZE: 1 4 | BETA1: 0.9 5 | BETA2: 0.999 6 | DATA_ROOT: /home/xxx/data-set/shadow_removal_dataset/SRD_Dataset_arg 7 | DEBUG: 0 8 | DIS_NORM: spectral 9 | FORCE_EXIT: 1 10 | GAN_LOSS: nsgan 11 | GAN_NORM: batch_ 12 | GPU: 13 | - '0' 14 | INIT_TYPE: identity 15 | INPUT_SIZE_H: 640 16 | INPUT_SIZE_W: 840 17 | LOG_INTERVAL: 10 18 | LOSS: L1Loss 19 | LR: 0.0001 20 | LR_D: 5.0e-05 21 | MIDDLE_RES_NUM: 4 22 | MODEL_NAME: ShadowRemoval 23 | NETWORK: DMTN 24 | POOL_SIZE: 50 25 | PRE_TRAIN_EVAL_LEN: 10 26 | RESULTS: ./checkpoints/SRD/results 27 | RESULTS_SAMPLE: 0 28 | SAMPLE_SIZE: 1 29 | SEED: 10 30 | SPLIT: 1 31 | SUBJECT_WORD: SRD 32 | TRAIN_EVAL_LEN: 10 33 | VERBOSE: 1 34 | 35 | TEST_FLIST: /data_val/SRD_shadow_test.flist 36 | TEST_GT_FLIST: /data_val/SRD_shadow_free_test.flist 37 | TEST_MASK_FLIST: /data_val/SRD_mask_test.flist 38 | 39 | TRAIN_FLIST: /data_val/SRD_shadow_train_train.flist 40 | TRAIN_GT_FLIST: /data_val/SRD_shadow_free_train_train.flist 41 | TRAIN_MASK_FLIST: /data_val/SRD_mask_train_train.flist 42 | 43 | TRAIN_FLIST_PRE: /data_val/SRD_shadow_train_train.flist 44 | TRAIN_GT_FLIST_PRE: /data_val/SRD_shadow_free_train_train.flist 45 | TRAIN_MASK_FLIST_PRE: /data_val/SRD_mask_train_train.flist 46 | 47 | VAL_FLIST: /data_val/SRD_shadow_train_val.flist 48 | VAL_GT_FLIST: /data_val/SRD_shadow_free_train_val.flist 49 | VAL_MASK_FLIST: /data_val/SRD_mask_train_val.flist 50 | 51 | VAL_FLIST_PRE: /data_val/SRD_shadow_train_val.flist 52 | VAL_GT_FLIST_PRE: /data_val/SRD_shadow_free_train_val.flist 53 | VAL_MASK_FLIST_PRE: /data_val/SRD_mask_train_val.flist 54 | 
-------------------------------------------------------------------------------- /config/config_ISTD.yml: -------------------------------------------------------------------------------- 1 | ADV: 1 2 | BACKUP: 0 3 | BATCH_SIZE: 1 4 | BETA1: 0.9 5 | BETA2: 0.999 6 | DATA_ROOT: /home/xxx/data-set/shadow_removal_dataset/ISTD_Dataset_arg 7 | DEBUG: 0 8 | DIS_NORM: spectral 9 | FORCE_EXIT: 1 10 | GAN_LOSS: nsgan 11 | GAN_NORM: batch_ 12 | GPU: 13 | - '0' 14 | INIT_TYPE: identity 15 | INPUT_SIZE_H: 480 16 | INPUT_SIZE_W: 640 17 | LOG_INTERVAL: 10 18 | LOSS: L1Loss 19 | LR: 0.0001 20 | LR_D: 5.0e-05 21 | MIDDLE_RES_NUM: 4 22 | MODEL_NAME: ShadowRemoval 23 | NETWORK: DMTN 24 | POOL_SIZE: 50 25 | PRE_TRAIN_EVAL_LEN: 10 26 | RESULTS: ./checkpoints/ISTD/results 27 | RESULTS_SAMPLE: 0 28 | SAMPLE_SIZE: 1 29 | SEED: 10 30 | SPLIT: 1 31 | SUBJECT_WORD: ISTD 32 | TRAIN_EVAL_LEN: 10 33 | VERBOSE: 1 34 | 35 | TEST_FLIST: /data_val/ISTD_shadow_test.flist 36 | TEST_GT_FLIST: /data_val/ISTD_shadow_free_test.flist 37 | TEST_MASK_FLIST: /data_val/ISTD_mask_test.flist 38 | 39 | TRAIN_FLIST: /data_val/ISTD_shadow_train_train.flist 40 | TRAIN_GT_FLIST: /data_val/ISTD_shadow_free_train_train.flist 41 | TRAIN_MASK_FLIST: /data_val/ISTD_mask_train_train.flist 42 | 43 | TRAIN_FLIST_PRE: /data_val/ISTD_shadow_train_train.flist 44 | TRAIN_GT_FLIST_PRE: /data_val/ISTD_shadow_free_train_train.flist 45 | TRAIN_MASK_FLIST_PRE: /data_val/ISTD_mask_train_train.flist 46 | 47 | VAL_FLIST: /data_val/ISTD_shadow_train_val.flist 48 | VAL_GT_FLIST: /data_val/ISTD_shadow_free_train_val.flist 49 | VAL_MASK_FLIST: /data_val/ISTD_mask_train_val.flist 50 | 51 | VAL_FLIST_PRE: /data_val/ISTD_shadow_train_val.flist 52 | VAL_GT_FLIST_PRE: /data_val/ISTD_shadow_free_train_val.flist 53 | VAL_MASK_FLIST_PRE: /data_val/ISTD_mask_train_val.flist 54 | -------------------------------------------------------------------------------- /script/flist_train_val_test.py: -------------------------------------------------------------------------------- 1 | """Divide the data of all folders in a directory into training, validation, and test sets.""" 2 | import os 3 | import numpy as np 4 | import random 5 | 6 | 7 | def gen_flist_train_val_test(flist, train_val_test_path, train_val_test_ratio, SEED, id_list): 8 | random.seed(SEED) 9 | # get flist 10 | ext = {'.JPG', '.JPEG', '.PNG', '.TIF', '.TIFF', '.JSON'} 11 | images = np.genfromtxt(flist, dtype=str, encoding='utf-8') 12 | # shuffle 13 | files_num = len(images) 14 | if len(id_list) == 0: 15 | id_list = list(range(files_num)) 16 | shuffle = True 17 | else: 18 | shuffle = False 19 | if shuffle: 20 | random.shuffle(id_list) 21 | images = np.array(images)[id_list] 22 | # save 23 | images_train_val_test = [[], [], []] 24 | i_list = [0] 25 | sum_ = 0 26 | if sum(train_val_test_ratio) == 10: 27 | for i in range(3): 28 | if train_val_test_ratio[i] > 0: 29 | sum_ = sum_ + \ 30 | int(np.floor(train_val_test_ratio[i]*files_num/10)) 31 | if i == 2: 32 | i_list.append(files_num) 33 | else: 34 | i_list.append(sum_) 35 | images_train_val_test[i] = images[i_list[i]:i_list[i+1]] 36 | # save 37 | np.savetxt(train_val_test_path[i], 38 | images_train_val_test[i], fmt='%s') 39 | else: 40 | sum_ = sum_ + \ 41 | int(np.floor(train_val_test_ratio[i]*files_num/10)) 42 | i_list.append(sum_) 43 | else: 44 | print('input train_val_test_ratio error!') 45 | return id_list 46 | --------------------------------------------------------------------------------
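The id_list returned by gen_flist_train_val_test is what keeps the shadow, mask, and shadow-free flists aligned: the first call shuffles with the given seed and returns the permutation, and later calls replay it. A sketch of the pattern the generate_flist_* scripts below rely on (file names assumed for illustration):

from flist_train_val_test import gen_flist_train_val_test

seed = 10
# 8:2:0 split; the empty string disables the unused test split
id_list = gen_flist_train_val_test(
    "SRD_shadow.flist",
    ["SRD_shadow_train.flist", "SRD_shadow_val.flist", ""],
    [8, 2, 0], seed, [])
# replaying id_list gives the mask flist an identical ordering
gen_flist_train_val_test(
    "SRD_mask.flist",
    ["SRD_mask_train.flist", "SRD_mask_val.flist", ""],
    [8, 2, 0], seed, id_list)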
/script/process_syn_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | def prepare_image_syn(val_path,stage='train_shadow_free'): 5 | 6 | iminput = val_path 7 | val_mask_name = val_path.split('/')[-1].split('_')[-1] 8 | gtmask = val_path.replace(stage,'train_B_ISTD').replace(val_path.split('/')[-1],val_mask_name) 9 | 10 | val_im_name = '_'.join(val_path.split('/')[-1].split('_')[0:-1])+'.jpg' 11 | imtarget = val_path.replace(stage,'shadow_free').replace(val_path.split('/')[-1],val_im_name) 12 | 13 | return iminput,imtarget,gtmask 14 | 15 | def prepare_data(train_path, stage=['train_A']): 16 | input_names=[] 17 | for dirname in train_path: 18 | for subfolder in stage: 19 | train_b = dirname + "/"+ subfolder+"/" 20 | for root, _, fnames in sorted(os.walk(train_b)): 21 | for fname in fnames: 22 | if is_image_file(fname): 23 | input_names.append(os.path.join(train_b, fname)) 24 | return input_names 25 | 26 | IMG_EXTENSIONS = [ 27 | '.jpg', '.JPG', '.jpeg', '.JPEG', 28 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 29 | ] 30 | 31 | def is_image_file(filename): 32 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 33 | 34 | def gen_syn_images_flist(train_real_root): 35 | train_real_root = [train_real_root] 36 | syn_images=prepare_data(train_real_root,stage=['synC']) 37 | 38 | shadow_syn_list= [] 39 | shadow_free_syn_list= [] 40 | mask_syn_list= [] 41 | for i in range(len(syn_images)): 42 | shadow,shadow_free,mask = prepare_image_syn(syn_images[i],stage='synC') 43 | shadow_syn_list.append(shadow) 44 | shadow_free_syn_list.append(shadow_free) 45 | mask_syn_list.append(mask) 46 | return shadow_syn_list,shadow_free_syn_list,mask_syn_list 47 | 48 | --------------------------------------------------------------------------------
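prepare_image_syn in the script above recovers the ground-truth pair for a synthetic shadow image purely from its file name: the last underscore-separated token names the mask, and the remaining tokens name the shadow-free base image. A worked example (the file name is hypothetical):

from process_syn_data import prepare_image_syn

shadow = "/data/ISTD_Dataset_arg/train/synC/90-1_95-4.png"
iminput, imtarget, gtmask = prepare_image_syn(shadow, stage='synC')
# iminput  -> the synC image itself
# gtmask   -> /data/ISTD_Dataset_arg/train/train_B_ISTD/95-4.png
# imtarget -> /data/ISTD_Dataset_arg/train/shadow_free/90-1.jpg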
/src/main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | from shutil import copyfile 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | 10 | from src.config import Config 11 | from src.model_top import ModelTop 12 | from src.utils import create_dir 13 | 14 | 15 | def main(mode, config_path): 16 | r""" 17 | Args: 18 | mode (int): 0: pre-train (data augmentation), 1: pre-train (no data augmentation), 2: train, 3: test, 4: eval on val set, 5: eval on test set 19 | """ 20 | 21 | config = load_config(mode, config_path) 22 | config.CONFIG_PATH = config_path 23 | 24 | # init device 25 | if torch.cuda.is_available(): 26 | config.DEVICE = torch.device("cuda") 27 | torch.backends.cudnn.benchmark = True # cudnn auto-tuner 28 | else: 29 | config.DEVICE = torch.device("cpu") 30 | 31 | # set cv2 running threads to 1 (prevents deadlocks with pytorch dataloader) 32 | cv2.setNumThreads(0) 33 | 34 | # initialize random seed 35 | torch.manual_seed(config.SEED) 36 | torch.cuda.manual_seed_all(config.SEED) 37 | np.random.seed(config.SEED) 38 | random.seed(config.SEED) 39 | 40 | # build the model and initialize 41 | model = ModelTop(config) 42 | model.load() 43 | 44 | # model pre training (data augmentation) 45 | if config.MODE == 0: 46 | config.print() 47 | model.train() 48 | 49 | # model pre training (no data augmentation) 50 | if config.MODE == 1: 51 | config.print() 52 | model.train() 53 | 54 | # model training 55 | if config.MODE == 2: 56 | config.print() 57 | model.train() 58 | 59 | # model test 60 | elif config.MODE == 3: 61 | model.test() 62 | 63 | # model eval on val set 64 | elif config.MODE == 4: 65 | model.eval() 66 | 67 | # model eval on test set 68 | elif config.MODE == 5: 69 | model.eval() 70 | 71 | 72 | def load_config(mode, config_path): 73 | r"""loads model config 74 | 75 | Args: 76 | mode (int): 0-2: training stages, 3: test, 4-5: eval (see main above) 77 | """ 78 | 79 | # load config file 80 | config = Config(config_path) 81 | config.MODE = mode 82 | 83 | return config 84 | 85 | 86 | if __name__ == "__main__": 87 | # run a single stage directly, e.g. python -m src.main --mode 2 --config checkpoints/ISTD/config.yml 88 | parser = argparse.ArgumentParser() 89 | parser.add_argument('--mode', type=int, required=True) 90 | parser.add_argument('--config', required=True) 91 | args = parser.parse_args() 92 | main(args.mode, args.config) 93 | -------------------------------------------------------------------------------- /config/run_ISTD.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import yaml 7 | 8 | 9 | from src.main import main 10 | from src.utils import create_dir, init_config, copypth 11 | import torch 12 | 13 | if __name__ == '__main__': 14 | # initial 15 | multiprocessing.set_start_method('spawn') 16 | config = yaml.load(open("config.yml", 'r'), Loader=yaml.FullLoader) 17 | dest_path = Path('checkpoints/') / \ 18 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'.pth') 19 | # cuda visible devices 20 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join( 21 | str(e) for e in config["GPU"]) 22 | torch.autograd.set_detect_anomaly(True) 23 | checkpoints_path = Path('./checkpoints') / \ 24 | config["SUBJECT_WORD"] # model checkpoints path 25 | create_dir(checkpoints_path) 26 | create_dir('./pre_train_model') 27 | config_path = os.path.join(checkpoints_path, 'config.yml') 28 | 29 | # pre_train (data augmentation) 30 | MODE = 0 31 | print('\nmode-'+str(MODE)+': start pre_training(data augmentation)...\n') 32 | for i in range(1): 33 | skip_train = init_config(checkpoints_path, MODE=MODE, 34 | EVAL_INTERVAL_EPOCH=1, EPOCH=[90,i]) 35 | if not skip_train: 36 | main(MODE, config_path) 37 | src_path = Path('./pre_train_model') / \ 38 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth') 39 | copypth(dest_path, src_path) 40 | 41 | # train 42 | MODE = 2 43 | print('\nmode-'+str(MODE)+': start training...\n') 44 | for i in range(1): 45 | skip_train = init_config(checkpoints_path, MODE=MODE, 46 | EVAL_INTERVAL_EPOCH=0.1, EPOCH=[60,i]) 47 | if not skip_train: 48 | main(MODE, config_path) 49 | src_path = Path('./pre_train_model') / \ 50 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth') 51 | copypth(dest_path, src_path) 52 | 53 | # test 54 | MODE = 3 55 | print('\nmode-'+str(MODE)+': start testing...\n') 56 | main(MODE, config_path) 57 | 58 | # eval on val set 59 | # MODE = 4 60 | # print('\nmode-'+str(MODE)+': start eval...\n') 61 | # main(MODE,config_path) 62 | 63 | # eval on test set 64 | MODE = 5 65 | print('\nmode-'+str(MODE)+': start eval...\n') 66 | main(MODE, config_path) 67 | --------------------------------------------------------------------------------
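run_ISTD.py above and the three drivers that follow share one recipe: read config.yml from the repository root, run the training stages through src.main.main, snapshot the resulting .pth into pre_train_model/, then test and evaluate. A hypothetical end-to-end invocation (the copy steps are assumptions inferred from debug_ready.py and deploy_ready.py, which expect config.yml at the root; the truncated README does not confirm the exact commands):

import os
import shutil

shutil.copy("config/config_ISTD.yml", "config.yml")  # choose the dataset configuration
shutil.copy("config/run_ISTD.py", "run_ISTD.py")     # the driver imports src.*, so run it from the root
os.system("python deploy_ready.py --gpu 0")          # writes DEBUG: 0 and GPU: ['0'] into config.yml
os.system("python run_ISTD.py")                      # pre-train -> train -> test -> eval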
/config/run_SRD.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import yaml 7 | 8 | 9 | from src.main import main 10 | from src.utils import create_dir, init_config, copypth 11 | import torch 12 | 13 | if __name__ == '__main__': 14 | # initial 15 | multiprocessing.set_start_method('spawn') 16 | config = yaml.load(open("config.yml", 'r'), Loader=yaml.FullLoader) 17 | dest_path = Path('checkpoints/') / \ 18 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'.pth') 19 | # cuda visible devices 20 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join( 21 | str(e) for e in config["GPU"]) 22 | torch.autograd.set_detect_anomaly(True) 23 | checkpoints_path = Path('./checkpoints') / \ 24 | config["SUBJECT_WORD"] # model checkpoints path 25 | create_dir(checkpoints_path) 26 | create_dir('./pre_train_model') 27 | config_path = os.path.join(checkpoints_path, 'config.yml') 28 | 29 | # pre_train (data augmentation) 30 | MODE = 0 31 | print('\nmode-'+str(MODE)+': start pre_training(data augmentation)...\n') 32 | for i in range(1): 33 | skip_train = init_config(checkpoints_path, MODE=MODE, 34 | EVAL_INTERVAL_EPOCH=1, EPOCH=[90,i]) 35 | if not skip_train: 36 | main(MODE, config_path) 37 | src_path = Path('./pre_train_model') / \ 38 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth') 39 | copypth(dest_path, src_path) 40 | 41 | # train 42 | MODE = 2 43 | print('\nmode-'+str(MODE)+': start training...\n') 44 | for i in range(6): 45 | skip_train = init_config(checkpoints_path, MODE=MODE, 46 | EVAL_INTERVAL_EPOCH=1, EPOCH=[10,20,30,40,50,60,i]) 47 | if not skip_train: 48 | main(MODE, config_path) 49 | src_path = Path('./pre_train_model') / \ 50 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth') 51 | copypth(dest_path, src_path) 52 | 53 | # test 54 | MODE = 3 55 | print('\nmode-'+str(MODE)+': start testing...\n') 56 | main(MODE, config_path) 57 | 58 | # eval on val set 59 | # MODE = 4 60 | # print('\nmode-'+str(MODE)+': start eval...\n') 61 | # main(MODE,config_path) 62 | 63 | # eval on test set 64 | MODE = 5 65 | print('\nmode-'+str(MODE)+': start eval...\n') 66 | main(MODE, config_path) 67 | -------------------------------------------------------------------------------- /src/image_pool.py: -------------------------------------------------------------------------------- 1 | """ 2 | This part of the code is built based on the project: 3 | https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix 4 | """ 5 | import random 6 | import torch 7 | 8 | 9 | class ImagePool(): 10 | """This class implements an image buffer that stores previously generated images. 11 | 12 | This buffer enables us to update discriminators using a history of generated images 13 | rather than the ones produced by the latest generators. 14 | """ 15 | 16 | def __init__(self, pool_size): 17 | """Initialize the ImagePool class 18 | 19 | Parameters: 20 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 21 | """ 22 | self.pool_size = pool_size 23 | if self.pool_size > 0: # create an empty pool 24 | self.num_imgs = 0 25 | self.images = [] 26 | 27 | def query(self, images): 28 | """Return an image from the pool. 29 | 30 | Parameters: 31 | images: the latest generated images from the generator 32 | 33 | Returns images from the buffer. 34 | 35 | With probability 0.5, the buffer returns the input images. 36 | With probability 0.5, it returns images previously stored in the buffer 37 | and inserts the current images into the buffer. 
38 | """ 39 | if self.pool_size == 0: # if the buffer size is 0, do nothing 40 | return images 41 | return_images = [] 42 | for image in images: 43 | image = torch.unsqueeze(image.data, 0) 44 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 45 | self.num_imgs = self.num_imgs + 1 46 | self.images.append(image) 47 | return_images.append(image) 48 | else: 49 | p = random.uniform(0, 1) 50 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 51 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 52 | tmp = self.images[random_id].clone() 53 | self.images[random_id] = image 54 | return_images.append(tmp) 55 | else: # by another 50% chance, the buffer will return the current image 56 | return_images.append(image) 57 | return_images = torch.cat(return_images, 0) # collect all the images and return 58 | return return_images 59 | -------------------------------------------------------------------------------- /install.yaml: -------------------------------------------------------------------------------- 1 | name: DMTN 2 | dependencies: 3 | - _libgcc_mutex=0.1=main 4 | - autopep8=1.5.4=py_0 5 | - blas=1.0=mkl 6 | - bzip2=1.0.8=h7b6447c_0 7 | - ca-certificates=2020.12.8=h06a4308_0 8 | - cairo=1.14.12=h8948797_3 9 | - certifi=2020.12.5=py37h06a4308_0 10 | - cudatoolkit=10.2.89=hfd86e86_1 11 | - ffmpeg=4.0=hcdf2ecd_0 12 | - fontconfig=2.13.0=h9420a91_0 13 | - freeglut=3.0.0=hf484d3e_5 14 | - freetype=2.10.4=h5ab3b9f_0 15 | - glib=2.66.1=h92f7085_0 16 | - graphite2=1.3.14=h23475e2_0 17 | - harfbuzz=1.8.8=hffaf4a1_0 18 | - hdf5=1.10.2=hba1933b_1 19 | - icu=58.2=he6710b0_3 20 | - intel-openmp=2020.2=254 21 | - jasper=2.0.14=h07fcdf6_1 22 | - jpeg=9b=h024ee3a_2 23 | - lcms2=2.11=h396b838_0 24 | - ld_impl_linux-64=2.33.1=h53a641e_7 25 | - libedit=3.1.20191231=h14c3975_1 26 | - libffi=3.3=he6710b0_2 27 | - libgcc-ng=9.1.0=hdf63c60_0 28 | - libgfortran-ng=7.3.0=hdf63c60_0 29 | - libglu=9.0.0=hf484d3e_1 30 | - libopencv=3.4.2=hb342d67_1 31 | - libopus=1.3.1=h7b6447c_0 32 | - libpng=1.6.37=hbc83047_0 33 | - libstdcxx-ng=9.1.0=hdf63c60_0 34 | - libtiff=4.1.0=h2733197_1 35 | - libuuid=1.0.3=h1bed415_2 36 | - libuv=1.40.0=h7b6447c_0 37 | - libvpx=1.7.0=h439df22_0 38 | - libxcb=1.14=h7b6447c_0 39 | - libxml2=2.9.10=hb55368b_3 40 | - lz4-c=1.9.2=heb0550a_3 41 | - mkl=2020.2=256 42 | - mkl-service=2.3.0=py37he8ac12f_0 43 | - mkl_fft=1.2.0=py37h23d657b_0 44 | - mkl_random=1.1.1=py37h0573a6f_0 45 | - ncurses=6.2=he6710b0_1 46 | - ninja=1.10.2=py37hff7bd54_0 47 | - numpy=1.19.2=py37h54aff64_0 48 | - numpy-base=1.19.2=py37hfa32c7d_0 49 | - olefile=0.46=py37_0 50 | - openssl=1.1.1i=h27cfd23_0 51 | - pcre=8.44=he6710b0_0 52 | - pillow=8.0.1=py37he98fc37_0 53 | - pip=20.3.1=py37h06a4308_0 54 | - pixman=0.40.0=h7b6447c_0 55 | - py-opencv=3.4.2=py37hb342d67_1 56 | - pycodestyle=2.6.0=py_0 57 | - python=3.7.9=h7579374_0 58 | - pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0 59 | - pyyaml=5.3.1=py37h7b6447c_1 60 | - readline=8.0=h7b6447c_0 61 | - scipy=1.5.2=py37h0b6359f_0 62 | - setuptools=51.0.0=py37h06a4308_2 63 | - six=1.15.0=py37h06a4308_0 64 | - sqlite=3.33.0=h62c20be_0 65 | - tk=8.6.10=hbc83047_0 66 | - toml=0.10.1=py_0 67 | - torchaudio=0.7.2=py37 68 | - torchvision=0.2.1=py37_0 69 | - typing_extensions=3.7.4.3=py_0 70 | - wheel=0.36.1=pyhd3eb1b0_0 71 | - xz=5.2.5=h7b6447c_0 72 | - yaml=0.2.5=h7b6447c_0 73 | - zlib=1.2.11=h7b6447c_3 74 | - zstd=1.4.5=h9ceee32_0 75 | - 
pip: 76 | - augmentor==0.2.8 77 | - cycler==0.10.0 78 | - decorator==4.4.2 79 | - future==0.18.2 80 | - imageio==2.9.0 81 | - kiwisolver==1.3.1 82 | - matplotlib==3.3.3 83 | - networkx==2.5 84 | - pyparsing==2.4.7 85 | - python-dateutil==2.8.1 86 | - pywavelets==1.1.1 87 | - scikit-image==0.17.2 88 | - tifffile==2020.12.8 89 | - tqdm==4.54.1 90 | -------------------------------------------------------------------------------- /config/run_SRD_da.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import yaml 7 | 8 | 9 | from src.main import main 10 | from src.utils import create_dir, init_config, copypth 11 | import torch 12 | 13 | if __name__ == '__main__': 14 | # initial 15 | multiprocessing.set_start_method('spawn') 16 | config = yaml.load(open("config.yml", 'r'), Loader=yaml.FullLoader) 17 | dest_path = Path('checkpoints/') / \ 18 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'.pth') 19 | # cuda visible devices 20 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join( 21 | str(e) for e in config["GPU"]) 22 | torch.autograd.set_detect_anomaly(True) 23 | checkpoints_path = Path('./checkpoints') / \ 24 | config["SUBJECT_WORD"] # model checkpoints path 25 | create_dir(checkpoints_path) 26 | create_dir('./pre_train_model') 27 | config_path = os.path.join(checkpoints_path, 'config.yml') 28 | 29 | # pre_train (data augmentation) 30 | MODE = 0 31 | print('\nmode-'+str(MODE)+': start pre_training(data augmentation)...\n') 32 | for i in range(1): 33 | skip_train = init_config(checkpoints_path, MODE=MODE, 34 | EVAL_INTERVAL_EPOCH=1, EPOCH=[90,i]) 35 | if not skip_train: 36 | main(MODE, config_path) 37 | src_path = Path('./pre_train_model') / \ 38 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth') 39 | copypth(dest_path, src_path) 40 | 41 | # pre_train (no data augmentation) 42 | MODE = 1 43 | print('\nmode-'+str(MODE)+': start pre_training(no data augmentation)...\n') 44 | for i in range(1): 45 | skip_train = init_config(checkpoints_path, MODE=MODE, 46 | EVAL_INTERVAL_EPOCH=1, EPOCH=[30,i]) 47 | if not skip_train: 48 | main(MODE, config_path) 49 | src_path = Path('./pre_train_model') / \ 50 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_no_da.pth') 51 | copypth(dest_path, src_path) 52 | 53 | # train 54 | MODE = 2 55 | print('\nmode-'+str(MODE)+': start training...\n') 56 | for i in range(1): 57 | skip_train = init_config(checkpoints_path, MODE=MODE, 58 | EVAL_INTERVAL_EPOCH=1, EPOCH=[30,i]) 59 | if not skip_train: 60 | main(MODE, config_path) 61 | src_path = Path('./pre_train_model') / \ 62 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth') 63 | copypth(dest_path, src_path) 64 | 65 | # test 66 | MODE = 3 67 | print('\nmode-'+str(MODE)+': start testing...\n') 68 | main(MODE, config_path) 69 | 70 | # eval on val set 71 | # MODE = 4 72 | # print('\nmode-'+str(MODE)+': start eval...\n') 73 | # main(MODE,config_path) 74 | 75 | # eval on test set 76 | MODE = 5 77 | print('\nmode-'+str(MODE)+': start eval...\n') 78 | main(MODE, config_path) 79 | --------------------------------------------------------------------------------
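The _da drivers above and below add a second pre-training stage: MODE 0 pre-trains with data augmentation (presumably on the *_PRE flists, which include the synthetic images), MODE 1 re-trains without augmentation on the real pairs, and MODE 2 fine-tunes before testing. For reference, the MODE integers dispatched by src/main.py are:

# summary of the MODE values handled in src/main.py (not code from the repo)
STAGES = {
    0: "pre-train (data augmentation)",
    1: "pre-train (no data augmentation)",
    2: "train",
    3: "test",
    4: "eval on val set",
    5: "eval on test set",
}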
/config/run_ISTD_da.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import yaml 7 | 8 | 9 | from src.main import main 10 | from src.utils import create_dir, init_config, copypth 11 | import torch 12 | 13 | if __name__ == '__main__': 14 | # initial 15 | multiprocessing.set_start_method('spawn') 16 | config = yaml.load(open("config.yml", 'r'), Loader=yaml.FullLoader) 17 | dest_path = Path('checkpoints/') / \ 18 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'.pth') 19 | # cuda visible devices 20 | os.environ['CUDA_VISIBLE_DEVICES'] = ','.join( 21 | str(e) for e in config["GPU"]) 22 | torch.autograd.set_detect_anomaly(True) 23 | checkpoints_path = Path('./checkpoints') / \ 24 | config["SUBJECT_WORD"] # model checkpoints path 25 | create_dir(checkpoints_path) 26 | create_dir('./pre_train_model') 27 | config_path = os.path.join(checkpoints_path, 'config.yml') 28 | 29 | # pre_train (data augmentation) 30 | MODE = 0 31 | print('\nmode-'+str(MODE)+': start pre_training(data augmentation)...\n') 32 | for i in range(1): 33 | skip_train = init_config(checkpoints_path, MODE=MODE, 34 | EVAL_INTERVAL_EPOCH=1, EPOCH=[90,i]) 35 | if not skip_train: 36 | main(MODE, config_path) 37 | src_path = Path('./pre_train_model') / \ 38 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth') 39 | copypth(dest_path, src_path) 40 | 41 | # pre_train (no data augmentation) 42 | MODE = 1 43 | print('\nmode-'+str(MODE)+': start pre_training(no data augmentation)...\n') 44 | for i in range(1): 45 | skip_train = init_config(checkpoints_path, MODE=MODE, 46 | EVAL_INTERVAL_EPOCH=1, EPOCH=[30,i]) 47 | if not skip_train: 48 | main(MODE, config_path) 49 | src_path = Path('./pre_train_model') / \ 50 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_no_da.pth') 51 | copypth(dest_path, src_path) 52 | 53 | # train 54 | MODE = 2 55 | print('\nmode-'+str(MODE)+': start training...\n') 56 | for i in range(1): 57 | skip_train = init_config(checkpoints_path, MODE=MODE, 58 | EVAL_INTERVAL_EPOCH=0.1, EPOCH=[30,i]) 59 | if not skip_train: 60 | main(MODE, config_path) 61 | src_path = Path('./pre_train_model') / \ 62 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth') 63 | copypth(dest_path, src_path) 64 | 65 | # test 66 | MODE = 3 67 | print('\nmode-'+str(MODE)+': start testing...\n') 68 | main(MODE, config_path) 69 | 70 | # eval on val set 71 | # MODE = 4 72 | # print('\nmode-'+str(MODE)+': start eval...\n') 73 | # main(MODE,config_path) 74 | 75 | # eval on test set 76 | MODE = 5 77 | print('\nmode-'+str(MODE)+': start eval...\n') 78 | main(MODE, config_path) 79 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from src.network.network_DMTN import DMTN 7 | 8 | class BaseModel(nn.Module): 9 | def __init__(self, name, config): 10 | super(BaseModel, self).__init__() 11 | 12 | self.name = name 13 | self.config = config 14 | self.iteration = 0 15 | self.config = config 16 | 17 | self.weights_path = os.path.join(config.PATH, name + '.pth') 18 | 19 | def load(self): 20 | if os.path.exists(self.weights_path): 21 | print('Loading %s model...' 
% self.name) 22 | 23 | if torch.cuda.is_available(): 24 | data = torch.load(self.weights_path) 25 | print(self.weights_path) 26 | else: 27 | data = torch.load(self.weights_path, 28 | map_location=lambda storage, loc: storage) 29 | self.network_instance.load_state_dict(data['model']) 30 | self.iteration = data['iteration'] 31 | else: 32 | vgg19_weights_path = self.config.DATA_ROOT+"/vgg19-dcbb9e9d.pth" 33 | self.network_instance.network.vgg19.load_pretrained( 34 | vgg19_weights_path, self.config.DEVICE) 35 | for discriminator in self.network_instance.discriminator: 36 | discriminator.perceptual_loss.vgg19.load_pretrained( 37 | vgg19_weights_path, self.config.DEVICE) 38 | 39 | def save(self): 40 | print('\nsaving %s...\n' % self.weights_path) 41 | torch.save({ 42 | 'iteration': self.iteration, 43 | 'model': self.network_instance.state_dict() 44 | }, self.weights_path) 45 | 46 | if self.config.BACKUP: 47 | INTERVAL_ = 4 48 | if self.config.SAVE_INTERVAL and self.iteration % (self.config.SAVE_INTERVAL*INTERVAL_) == 0: 49 | print('\nsaving %s...\n' % self.name+'_backup') 50 | torch.save({ 51 | 'iteration': self.iteration, 52 | 'model': self.network_instance.state_dict() 53 | }, os.path.join(self.config.PATH, 'backups/' + self.name + '_' + str(self.iteration // (self.config.SAVE_INTERVAL*INTERVAL_)) + '.pth')) 54 | 55 | 56 | class Model(BaseModel): 57 | def __init__(self, config): 58 | super(Model, self).__init__(config.MODEL_NAME, config) 59 | self.INNER_OPTIMIZER = config.INNER_OPTIMIZER 60 | # networks choose 61 | if config.NETWORK == "DMTN": 62 | network_instance = DMTN(config, in_channels=3) 63 | else: 64 | network_instance = None 65 | self.add_module('network_instance', network_instance) 66 | 67 | def process(self, images, mask, GT, eval_mode=False): 68 | if not eval_mode: 69 | self.iteration += 1 70 | outputs, loss, logs = self.network_instance.process( 71 | images, mask, GT) 72 | 73 | return outputs, loss, logs 74 | 75 | def forward(self, images): 76 | inputs = images 77 | outputs = self.network_instance(inputs) 78 | return outputs 79 | 80 | def backward(self, loss=None): 81 | self.network_instance.backward(loss) 82 | -------------------------------------------------------------------------------- /script/generate_flist_srd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from pathlib import Path 4 | from flist import gen_flist 5 | from flist_train_val_test import gen_flist_train_val_test 6 | from process_syn_data import gen_syn_images_flist 7 | 8 | if __name__ == '__main__': 9 | # SRD 10 | SRD_path = "/gpfs/home/liujiawei/data-set/shadow_removal_dataset/SRD_Dataset_arg" 11 | 12 | save_path = SRD_path+"/data_val" 13 | 14 | if not Path(save_path).exists(): 15 | os.mkdir(save_path) 16 | 17 | seed = 10 18 | 19 | # sys image 20 | shadow_syn_list,shadow_free_syn_list,mask_syn_list = gen_syn_images_flist(SRD_path+'/train') 21 | 22 | # train and val (data augmentation) 23 | data_path = Path(SRD_path)/"train/train_A" 24 | flist_main_name = "SRD_shadow.flist" 25 | flist_save_path = Path(save_path)/flist_main_name 26 | images = gen_flist(data_path) 27 | images+=shadow_syn_list 28 | np.savetxt(flist_save_path, images, fmt='%s') 29 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'), 30 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""] 31 | id_list = gen_flist_train_val_test( 32 | flist_save_path, png_val_test_PATH, [8, 2, 0], seed, []) 33 | 34 | data_path = 
Path(SRD_path)/"train/train_B" 35 | flist_main_name = "SRD_mask.flist" 36 | flist_save_path = Path(save_path)/flist_main_name 37 | images = gen_flist(data_path) 38 | images+=mask_syn_list 39 | np.savetxt(flist_save_path, images, fmt='%s') 40 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'), 41 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""] 42 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [ 43 | 8, 2, 0], seed, id_list) 44 | 45 | data_path = Path(SRD_path)/"train/train_C" 46 | flist_main_name = "SRD_shadow_free.flist" 47 | flist_save_path = Path(save_path)/flist_main_name 48 | images = gen_flist(data_path) 49 | images+=shadow_free_syn_list 50 | np.savetxt(flist_save_path, images, fmt='%s') 51 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'), 52 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""] 53 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [ 54 | 8, 2, 0], seed, id_list) 55 | 56 | # train and val 57 | data_path = Path(SRD_path)/"train/train_A" 58 | flist_save_path = Path(save_path)/"SRD_shadow_train.flist" 59 | images = gen_flist(data_path) 60 | np.savetxt(flist_save_path, images, fmt='%s') 61 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'), 62 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""] 63 | id_list = gen_flist_train_val_test( 64 | flist_save_path, png_val_test_PATH, [8, 2, 0], seed, []) 65 | 66 | data_path = Path(SRD_path)/"train/train_B" 67 | flist_save_path = Path(save_path)/"SRD_mask_train.flist" 68 | images = gen_flist(data_path) 69 | np.savetxt(flist_save_path, images, fmt='%s') 70 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'), 71 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""] 72 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [ 73 | 8, 2, 0], seed, id_list) 74 | 75 | data_path = Path(SRD_path)/"train/train_C" 76 | flist_save_path = Path(save_path)/"SRD_shadow_free_train.flist" 77 | images = gen_flist(data_path) 78 | np.savetxt(flist_save_path, images, fmt='%s') 79 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'), 80 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""] 81 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [ 82 | 8, 2, 0], seed, id_list) 83 | 84 | # test 85 | data_path = Path(SRD_path)/"test/test_A" 86 | flist_save_path = Path(save_path)/"SRD_shadow_test.flist" 87 | images = gen_flist(data_path) 88 | np.savetxt(flist_save_path, images, fmt='%s') 89 | 90 | data_path = Path(SRD_path)/"test/test_B" 91 | flist_save_path = Path(save_path)/"SRD_mask_test.flist" 92 | images = gen_flist(data_path) 93 | np.savetxt(flist_save_path, images, fmt='%s') 94 | 95 | data_path = Path(SRD_path)/"test/test_C" 96 | flist_save_path = Path(save_path)/"SRD_shadow_free_test.flist" 97 | images = gen_flist(data_path) 98 | np.savetxt(flist_save_path, images, fmt='%s') 99 | 100 | -------------------------------------------------------------------------------- /script/generate_flist_istd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from pathlib import Path 4 | from flist import gen_flist 5 | from flist_train_val_test import gen_flist_train_val_test 6 | from process_syn_data import gen_syn_images_flist 7 | 8 | if __name__ == '__main__': 9 | # ISTD 10 | 
ISTD_path = "/gpfs/home/liujiawei/data-set/shadow_removal_dataset/ISTD_Dataset_arg" 11 | 12 | save_path = ISTD_path+"/data_val" 13 | 14 | if not Path(save_path).exists(): 15 | os.mkdir(save_path) 16 | 17 | seed = 10 18 | 19 | # sys image 20 | shadow_syn_list,shadow_free_syn_list,mask_syn_list = gen_syn_images_flist(ISTD_path+'/train') 21 | 22 | # train and val (data augmentation) 23 | data_path = Path(ISTD_path)/"train/train_A" 24 | flist_main_name = "ISTD_shadow.flist" 25 | flist_save_path = Path(save_path)/flist_main_name 26 | images = gen_flist(data_path) 27 | images+=shadow_syn_list 28 | np.savetxt(flist_save_path, images, fmt='%s') 29 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'), 30 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""] 31 | id_list = gen_flist_train_val_test( 32 | flist_save_path, png_val_test_PATH, [8, 2, 0], seed, []) 33 | 34 | data_path = Path(ISTD_path)/"train/train_B" 35 | flist_main_name = "ISTD_mask.flist" 36 | flist_save_path = Path(save_path)/flist_main_name 37 | images = gen_flist(data_path) 38 | images+=mask_syn_list 39 | np.savetxt(flist_save_path, images, fmt='%s') 40 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'), 41 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""] 42 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [ 43 | 8, 2, 0], seed, id_list) 44 | 45 | data_path = Path(ISTD_path)/"train/train_C" 46 | flist_main_name = "ISTD_shadow_free.flist" 47 | flist_save_path = Path(save_path)/flist_main_name 48 | images = gen_flist(data_path) 49 | images+=shadow_free_syn_list 50 | np.savetxt(flist_save_path, images, fmt='%s') 51 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_pre_train.flist'), 52 | Path(save_path)/str(Path(flist_save_path).stem+'_pre_val.flist'), ""] 53 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [ 54 | 8, 2, 0], seed, id_list) 55 | 56 | # train and val 57 | data_path = Path(ISTD_path)/"train/train_A" 58 | flist_save_path = Path(save_path)/"ISTD_shadow_train.flist" 59 | images = gen_flist(data_path) 60 | np.savetxt(flist_save_path, images, fmt='%s') 61 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'), 62 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""] 63 | id_list = gen_flist_train_val_test( 64 | flist_save_path, png_val_test_PATH, [8, 2, 0], seed, []) 65 | 66 | data_path = Path(ISTD_path)/"train/train_B" 67 | flist_save_path = Path(save_path)/"ISTD_mask_train.flist" 68 | images = gen_flist(data_path) 69 | np.savetxt(flist_save_path, images, fmt='%s') 70 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'), 71 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""] 72 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [ 73 | 8, 2, 0], seed, id_list) 74 | 75 | data_path = Path(ISTD_path)/"train/train_C" 76 | flist_save_path = Path(save_path)/"ISTD_shadow_free_train.flist" 77 | images = gen_flist(data_path) 78 | np.savetxt(flist_save_path, images, fmt='%s') 79 | png_val_test_PATH = [Path(save_path)/str(Path(flist_save_path).stem+'_train.flist'), 80 | Path(save_path)/str(Path(flist_save_path).stem+'_val.flist'), ""] 81 | gen_flist_train_val_test(flist_save_path, png_val_test_PATH, [ 82 | 8, 2, 0], seed, id_list) 83 | 84 | # test 85 | data_path = Path(ISTD_path)/"test/test_A" 86 | flist_save_path = Path(save_path)/"ISTD_shadow_test.flist" 87 | images = 
gen_flist(data_path) 88 | np.savetxt(flist_save_path, images, fmt='%s') 89 | 90 | data_path = Path(ISTD_path)/"test/test_B" 91 | flist_save_path = Path(save_path)/"ISTD_mask_test.flist" 92 | images = gen_flist(data_path) 93 | np.savetxt(flist_save_path, images, fmt='%s') 94 | 95 | data_path = Path(ISTD_path)/"test/test_C" 96 | flist_save_path = Path(save_path)/"ISTD_shadow_free_test.flist" 97 | images = gen_flist(data_path) 98 | np.savetxt(flist_save_path, images, fmt='%s') 99 | 100 | -------------------------------------------------------------------------------- /src/metrics.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from decimal import * 3 | 4 | import numpy as np 5 | import skimage.color as skcolor 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | class Metrics(nn.Module): 11 | def __init__(self): 12 | super(Metrics, self).__init__() 13 | nThreads = multiprocessing.cpu_count() 14 | self.multiprocessingi_utils = MultiprocessingiUtils(int(nThreads/4)) 15 | 16 | def rgb2lab_all(self, outputs, GT): 17 | outputs = self.multiprocessingi_utils.rgb2lab(outputs) 18 | GT = self.multiprocessingi_utils.rgb2lab(GT) 19 | 20 | return outputs, GT 21 | 22 | def rmse(self, outputs, mask, GT, dataset_mode=0): 23 | outputs, GT = self.rgb2lab_all(outputs, GT) 24 | 25 | outputs[outputs > 1.0] = 1.0 26 | outputs[outputs < 0] = 0 27 | 28 | mask[mask > 0] = 1 29 | mask[mask == 0] = 0 30 | mask = mask.expand(-1, 3, -1, -1) 31 | mask_inverse = 1 - mask 32 | 33 | error_map = (outputs - GT).abs()*255 34 | 35 | rmse_all = error_map.sum(dim=(1, 2, 3)) 36 | n_pxl_all = torch.from_numpy(np.array( 37 | [error_map.shape[2]*error_map.shape[3]])).cuda().expand(rmse_all.shape[0]).type(torch.float32) 38 | 39 | rmse_shadow = (error_map*mask).sum(dim=(1, 2, 3)) 40 | n_pxl_shadow = mask.sum(dim=(1, 2, 3)) / mask.shape[1] 41 | 42 | rmse_non_shadow = (error_map*mask_inverse).sum(dim=(1, 2, 3)) 43 | n_pxl_non_shadow = mask_inverse.sum( 44 | dim=(1, 2, 3)) / mask_inverse.shape[1] 45 | 46 | if dataset_mode != 0: 47 | return rmse_shadow, n_pxl_shadow, rmse_non_shadow, n_pxl_non_shadow, rmse_all, n_pxl_all 48 | else: 49 | rmse_shadow_eval = ( 50 | rmse_shadow / (n_pxl_shadow+(n_pxl_shadow == 0.0).type(torch.float32))).mean() 51 | rmse_non_shadow_eval = ( 52 | rmse_non_shadow / (n_pxl_non_shadow+(n_pxl_non_shadow == 0.0).type(torch.float32))).mean() 53 | rmse_all_eval = ( 54 | rmse_all/(n_pxl_all+(n_pxl_all == 0.0).type(torch.float32))).mean() 55 | # print('running rmse-shadow: %.4f, rmse-non-shadow: %.4f, rmse-all: %.4f' 56 | # % (rmse_shadow_eval, rmse_non_shadow_eval, rmse_all_eval)) 57 | return rmse_shadow_eval, rmse_non_shadow_eval, rmse_all_eval 58 | 59 | def collect_rmse(self, rmse_shadow, n_pxl_shadow, rmse_non_shadow, n_pxl_non_shadow, rmse_all, n_pxl_all): 60 | # GPU->CPU 61 | rmse_shadow = rmse_shadow.cpu().numpy() 62 | n_pxl_shadow = n_pxl_shadow.cpu().numpy() 63 | rmse_non_shadow = rmse_non_shadow.cpu().numpy() 64 | n_pxl_non_shadow = n_pxl_non_shadow.cpu().numpy() 65 | rmse_all = rmse_all.cpu().numpy() 66 | n_pxl_all = n_pxl_all.cpu().numpy() 67 | 68 | # decimal 69 | getcontext().prec = 50 70 | 71 | # sum 72 | rmse_shadow_ = Decimal(0) 73 | for add in rmse_shadow: 74 | rmse_shadow_ += Decimal(float(add)) 75 | n_pxl_shadow_ = Decimal(0) 76 | for add in n_pxl_shadow: 77 | n_pxl_shadow_ += Decimal(float(add)) 78 | rmse_non_shadow_ = Decimal(0) 79 | for add in rmse_non_shadow: 80 | rmse_non_shadow_ += Decimal(float(add)) 81 | n_pxl_non_shadow_ = 
Decimal(0) 82 | for add in n_pxl_non_shadow: 83 | n_pxl_non_shadow_ += Decimal(float(add)) 84 | rmse_all_ = Decimal(0) 85 | for add in rmse_all: 86 | rmse_all_ += Decimal(float(add)) 87 | n_pxl_all_ = Decimal(0) 88 | for add in n_pxl_all: 89 | n_pxl_all_ += Decimal(float(add)) 90 | 91 | # compute 92 | rmse_shadow_eval = rmse_shadow_ / n_pxl_shadow_ 93 | rmse_non_shadow_eval = rmse_non_shadow_ / n_pxl_non_shadow_ 94 | rmse_all_eval = rmse_all_ / n_pxl_all_ 95 | 96 | self.multiprocessingi_utils.close() 97 | 98 | return float(rmse_shadow_eval), float(rmse_non_shadow_eval), float(rmse_all_eval) 99 | 100 | 101 | def rgb2lab(img): 102 | img = img.transpose([1, 2, 0]) 103 | return skcolor.rgb2lab(img) 104 | 105 | 106 | class MultiprocessingiUtils(): 107 | def __init__(self, nThreads=4): 108 | self.nThreads = nThreads 109 | self.pool = multiprocessing.Pool(processes=self.nThreads) 110 | 111 | def rgb2lab(self, inputs): 112 | inputs = inputs.cpu().numpy() 113 | inputs_ = self.pool.map(rgb2lab, inputs) 114 | i = 0 115 | for input_ in inputs_: 116 | inputs[i, :, :, :] = input_.transpose([2, 0, 1])/255 117 | i += 1 118 | output = torch.from_numpy(inputs).cuda() 119 | return output 120 | 121 | def close(self): 122 | self.pool.close() 123 | self.pool.join() 124 | 125 | def creat_pool(self): 126 | self.pool = multiprocessing.Pool(processes=self.nThreads) 127 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import random 4 | 5 | import Augmentor 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import torchvision.transforms.functional as F 10 | from PIL import Image 11 | from torch.utils.data import DataLoader 12 | from torchvision import transforms as T 13 | from skimage.color import rgb2gray 14 | 15 | from .utils import imshow, create_mask 16 | 17 | 18 | class Dataset(torch.utils.data.Dataset): 19 | def __init__(self, config, data_flist, mask_flist, GT_flist, additional_mask=[], augment=True): 20 | super(Dataset, self).__init__() 21 | self.augment = augment 22 | self.data = self.load_flist(data_flist) 23 | self.mask = self.load_flist(mask_flist) 24 | self.GT = self.load_flist(GT_flist) 25 | if len(additional_mask) != 0: 26 | self.additional_mask = self.load_flist(additional_mask) 27 | else: 28 | self.additional_mask = [] 29 | 30 | self.input_size_h = config.INPUT_SIZE_H 31 | self.input_size_w = config.INPUT_SIZE_W 32 | self.mask_type = config.MASK 33 | self.dataset_name = config.SUBJECT_WORD 34 | 35 | def __len__(self): 36 | return len(self.data) 37 | 38 | def __getitem__(self, index): 39 | try: 40 | item = self.load_item(index) 41 | except: 42 | print('loading error: ' + self.data[index]) 43 | item = self.load_item(0) 44 | 45 | return item 46 | 47 | def load_name(self, index): 48 | name = self.data[index] 49 | return os.path.basename(name) 50 | 51 | def load_item(self, index): 52 | 53 | size_h = self.input_size_h 54 | size_w = self.input_size_w 55 | 56 | # load image 57 | img = cv2.imread(self.data[index], -1) 58 | mask_GT = cv2.imread(self.mask[index], -1) 59 | GT = cv2.imread(self.GT[index], -1) 60 | 61 | # augment data 62 | if self.augment: 63 | img, mask_GT, GT = self.data_augment( 64 | img, mask_GT, GT) 65 | else: 66 | imgh, imgw = img.shape[0:2] 67 | if not (size_h == imgh and size_w == imgw): 68 | img = self.resize(img, size_h, size_w) 69 | mask_GT = self.resize(mask_GT, size_h, size_w) 70 | GT = self.resize(GT, size_h, 
size_w) 71 | 72 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 73 | GT = cv2.cvtColor(GT, cv2.COLOR_BGR2RGB) 74 | 75 | # imshow(Image.fromarray(edge_truth)) 76 | return self.to_tensor(img), self.to_tensor(mask_GT), self.to_tensor(GT) 77 | 78 | def data_augment(self, img, mask_GT, GT): 79 | 80 | # https://github.com/mdbloice/Augmentor 81 | images = [[img, GT, mask_GT]] 82 | p = Augmentor.DataPipeline(images) 83 | p.flip_random(1) 84 | width=random.randint(256,640) 85 | height = round((width/self.input_size_w)*self.input_size_h) 86 | p.resize(1, width, height) 87 | g = p.generator(batch_size=1) 88 | augmented_images = next(g) 89 | images = augmented_images[0][0] 90 | GT = augmented_images[0][1] 91 | mask_GT = augmented_images[0][2] 92 | 93 | if len(self.additional_mask) != 0: 94 | additional_mask = self.load_mask( 95 | self.input_size_h, self.input_size_w, random.randint(0, len(self.additional_mask) - 1)) 96 | additional_mask = additional_mask.astype( 97 | np.single)/np.max(additional_mask)*np.random.uniform(0, 0.8) 98 | additional_mask = additional_mask[:, :, np.newaxis] 99 | images = images*(1-additional_mask) 100 | images = images.astype(np.uint8) 101 | return images, mask_GT, GT 102 | 103 | def load_mask(self, imgh, imgw, index): 104 | mask_type = self.mask_type 105 | 106 | # external + random block 107 | if mask_type == 4: 108 | mask_type = 1 if np.random.binomial(1, 0.5) == 1 else 3 109 | 110 | # external + random block + half 111 | elif mask_type == 5: 112 | mask_type = np.random.randint(1, 4) 113 | 114 | # random block 115 | if mask_type == 1: 116 | return create_mask(imgw, imgh, imgw // 2, imgh // 2) 117 | 118 | # half 119 | if mask_type == 2: 120 | # randomly choose right or left 121 | return create_mask(imgw, imgh, imgw // 2, imgh, 0 if random.random() < 0.5 else imgw // 2, 0) 122 | 123 | # external 124 | if mask_type == 3: 125 | mask_index = random.randint(0, len(self.additional_mask) - 1) 126 | mask = cv2.imread(self.additional_mask[mask_index]) 127 | mask = 255-mask 128 | mask = self.resize(mask, imgh, imgw) 129 | # threshold due to interpolation 130 | mask = (mask > 0).astype(np.uint8) * 255 131 | return mask 132 | 133 | # test mode: load mask non random 134 | if mask_type == 6: 135 | mask = cv2.imread(self.additional_mask[index]) 136 | mask = self.resize(mask, imgh, imgw) 137 | mask = rgb2gray(mask) 138 | mask = (mask > 0).astype(np.uint8) * 255 139 | return mask 140 | 141 | def to_tensor(self, img): 142 | img = Image.fromarray(img) # returns an image object. 143 | img_t = F.to_tensor(img).float() 144 | return img_t 145 | 146 | def resize(self, img, height, width): 147 | img = cv2.resize(img, (width, height), interpolation=cv2.INTER_CUBIC) 148 | return img 149 | 150 | def load_flist(self, flist): 151 | if isinstance(flist, list): 152 | return flist 153 | 154 | # flist: image file path, image directory path, text file flist path 155 | if isinstance(flist, str): 156 | if os.path.isdir(flist): 157 | flist = list(glob.glob(flist + '/*.jpg')) + \ 158 | list(glob.glob(flist + '/*.png')) 159 | flist.sort() 160 | return flist 161 | 162 | if os.path.isfile(flist): 163 | try: 164 | return np.genfromtxt(flist, dtype=str, encoding='utf-8') 165 | except: 166 | return [flist] 167 | 168 | return [] 169 | 170 | def create_iterator(self, batch_size): 171 | while True: 172 | sample_loader = DataLoader( 173 | dataset=self, 174 | batch_size=batch_size, 175 | drop_last=True 176 | ) 177 | 178 | for item in sample_loader: 179 | yield item 180 | --------------------------------------------------------------------------------
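A minimal sketch (not part of the repo) of wiring this Dataset to a PyTorch DataLoader; the flist paths are assumed to be DATA_ROOT plus the relative entries from the YAML configs, which matches how the flists were generated above:

from torch.utils.data import DataLoader
from src.config import Config
from src.dataset import Dataset

config = Config("checkpoints/ISTD/config.yml")  # path assumed; written by the run_* drivers
train_set = Dataset(config,
                    config.DATA_ROOT + config.TRAIN_FLIST,
                    config.DATA_ROOT + config.TRAIN_MASK_FLIST,
                    config.DATA_ROOT + config.TRAIN_GT_FLIST,
                    augment=True)
loader = DataLoader(train_set, batch_size=config.BATCH_SIZE, shuffle=True)
img, mask_GT, GT = next(iter(loader))  # float tensors in [0, 1], shape (B, C, H, W)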
/src/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | This part of the code is built based on the project: 3 | https://github.com/knazeri/edge-connect 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | import torchvision.models as models 8 | from .utils import blur, gauss_kernel, sobel, sobel_kernel 9 | from torch.autograd import Variable 10 | import numpy as np 11 | import os 12 | 13 | 14 | class AdversarialLoss(nn.Module): 15 | r""" 16 | Adversarial loss 17 | https://arxiv.org/abs/1711.10337 18 | """ 19 | 20 | def __init__(self, type='nsgan', target_real_label=1.0, target_fake_label=0.0): 21 | r""" 22 | type = nsgan | lsgan | hinge 23 | """ 24 | super(AdversarialLoss, self).__init__() 25 | 26 | self.type = type 27 | self.register_buffer('real_label', torch.tensor(target_real_label)) 28 | self.register_buffer('fake_label', torch.tensor(target_fake_label)) 29 | 30 | if type == 'nsgan': 31 | self.criterion = nn.BCELoss(reduction="mean") 32 | 33 | elif type == 'lsgan': 34 | self.criterion = nn.MSELoss(reduction="mean") 35 | 36 | elif type == 'hinge': 37 | self.criterion = nn.ReLU() 38 | 39 | def __call__(self, outputs, is_real, is_disc=None): 40 | if self.type == 'hinge': 41 | if is_disc: 42 | if is_real: 43 | outputs = -outputs 44 | return self.criterion(1 + outputs).mean() 45 | else: 46 | return (-outputs).mean() 47 | 48 | else: 49 | labels = (self.real_label if is_real else self.fake_label).expand_as( 50 | outputs) 51 | loss = self.criterion(outputs, labels) 52 | return loss 53 | 54 | 55 | class StyleLoss(nn.Module): 56 | r""" 57 | Style loss, VGG-based 58 | https://arxiv.org/abs/1603.08155 59 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 60 | """ 61 | 62 | def __init__(self): 63 | super(StyleLoss, self).__init__() 64 | self.add_module('vgg19', VGG19()) 65 | self.criterion = torch.nn.L1Loss(reduction="mean") 66 | 67 | def compute_gram(self, x): 68 | b, ch, h, w = x.size() 69 | f = x.view(b, ch, w * h) 70 | f_T = f.transpose(1, 2) 71 | G = f.bmm(f_T) / (h * w * ch) 72 | 73 | return G 74 | 75 | def __call__(self, x, y): 76 | # Compute features 77 | x_vgg, y_vgg = self.vgg19(x), self.vgg19(y) 78 | 79 | # Compute loss 80 | style_loss = 0.0 81 | style_loss += self.criterion(self.compute_gram( 82 | x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2'])) 83 | style_loss += self.criterion(self.compute_gram( 84 | x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4'])) 85 | 
style_loss += self.criterion(self.compute_gram( 86 | x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4'])) 87 | style_loss += self.criterion(self.compute_gram( 88 | x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2'])) 89 | 90 | return style_loss 91 | 92 | 93 | class PerceptualLoss(nn.Module): 94 | r""" 95 | Perceptual loss, VGG-based 96 | https://arxiv.org/abs/1603.08155 97 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py 98 | """ 99 | 100 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]): 101 | super(PerceptualLoss, self).__init__() 102 | self.add_module('vgg19', VGG19()) 103 | self.criterion = torch.nn.L1Loss(reduction="mean") 104 | self.weights = weights 105 | 106 | def __call__(self, x, y): 107 | # Compute features 108 | x_vgg, y_vgg = self.vgg19(x), self.vgg19(y) 109 | 110 | p1 = self.weights[0] * \ 111 | self.criterion(x_vgg['relu1_2'], y_vgg['relu1_2']) 112 | p2 = self.weights[1] * \ 113 | self.criterion(x_vgg['relu2_2'], y_vgg['relu2_2']) 114 | p3 = self.weights[2] * \ 115 | self.criterion(x_vgg['relu3_2'], y_vgg['relu3_2']) 116 | p4 = self.weights[3] * \ 117 | self.criterion(x_vgg['relu4_2'], y_vgg['relu4_2']) 118 | p5 = self.weights[4] * \ 119 | self.criterion(x_vgg['relu5_2'], y_vgg['relu5_2']) 120 | 121 | return p1+p2+p3+p4+p5 122 | 123 | 124 | class VGG19(torch.nn.Module): 125 | def __init__(self): 126 | super(VGG19, self).__init__() 127 | # https://pytorch.org/hub/pytorch_vision_vgg/ 128 | mean = np.array( 129 | [0.485, 0.456, 0.406], dtype=np.float32) 130 | mean = mean.reshape((1, 3, 1, 1)) 131 | self.mean = Variable(torch.from_numpy(mean)).cuda() 132 | std = np.array( 133 | [0.229, 0.224, 0.225], dtype=np.float32) 134 | std = std.reshape((1, 3, 1, 1)) 135 | self.std = Variable(torch.from_numpy(std)).cuda() 136 | self.initial_model() 137 | 138 | def forward(self, x): 139 | relu1_1 = self.relu1_1((x-self.mean)/self.std) 140 | relu1_2 = self.relu1_2(relu1_1) 141 | 142 | relu2_1 = self.relu2_1(relu1_2) 143 | relu2_2 = self.relu2_2(relu2_1) 144 | 145 | relu3_1 = self.relu3_1(relu2_2) 146 | relu3_2 = self.relu3_2(relu3_1) 147 | relu3_3 = self.relu3_3(relu3_2) 148 | relu3_4 = self.relu3_4(relu3_3) 149 | 150 | relu4_1 = self.relu4_1(relu3_4) 151 | relu4_2 = self.relu4_2(relu4_1) 152 | relu4_3 = self.relu4_3(relu4_2) 153 | relu4_4 = self.relu4_4(relu4_3) 154 | 155 | relu5_1 = self.relu5_1(relu4_4) 156 | relu5_2 = self.relu5_2(relu5_1) 157 | relu5_3 = self.relu5_3(relu5_2) 158 | relu5_4 = self.relu5_4(relu5_3) 159 | 160 | out = { 161 | 'relu1_1': relu1_1, 162 | 'relu1_2': relu1_2, 163 | 164 | 'relu2_1': relu2_1, 165 | 'relu2_2': relu2_2, 166 | 167 | 'relu3_1': relu3_1, 168 | 'relu3_2': relu3_2, 169 | 'relu3_3': relu3_3, 170 | 'relu3_4': relu3_4, 171 | 172 | 'relu4_1': relu4_1, 173 | 'relu4_2': relu4_2, 174 | 'relu4_3': relu4_3, 175 | 'relu4_4': relu4_4, 176 | 177 | 'relu5_1': relu5_1, 178 | 'relu5_2': relu5_2, 179 | 'relu5_3': relu5_3, 180 | 'relu5_4': relu5_4, 181 | } 182 | return out 183 | 184 | def load_pretrained(self, vgg19_weights_path, gpu): 185 | if os.path.exists(vgg19_weights_path): 186 | if torch.cuda.is_available(): 187 | data = torch.load(vgg19_weights_path) 188 | print("load vgg_pretrained_model:"+vgg19_weights_path) 189 | else: 190 | data = torch.load(vgg19_weights_path, 191 | map_location=lambda storage, loc: storage) 192 | self.initial_model(data) 193 | self.to(gpu) 194 | else: 195 | print("you need to download vgg_pretrained_model to "+str(vgg19_weights_path) + 196 | "\n'vgg19':
'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth'") 197 | raise Exception("Don't load vgg_pretrained_model") 198 | 199 | def initial_model(self,data=None): 200 | vgg19 = models.vgg19() 201 | if data is not None: 202 | vgg19.load_state_dict(data) 203 | features = vgg19.features 204 | self.relu1_1 = torch.nn.Sequential() 205 | self.relu1_2 = torch.nn.Sequential() 206 | 207 | self.relu2_1 = torch.nn.Sequential() 208 | self.relu2_2 = torch.nn.Sequential() 209 | 210 | self.relu3_1 = torch.nn.Sequential() 211 | self.relu3_2 = torch.nn.Sequential() 212 | self.relu3_3 = torch.nn.Sequential() 213 | self.relu3_4 = torch.nn.Sequential() 214 | 215 | self.relu4_1 = torch.nn.Sequential() 216 | self.relu4_2 = torch.nn.Sequential() 217 | self.relu4_3 = torch.nn.Sequential() 218 | self.relu4_4 = torch.nn.Sequential() 219 | 220 | self.relu5_1 = torch.nn.Sequential() 221 | self.relu5_2 = torch.nn.Sequential() 222 | self.relu5_3 = torch.nn.Sequential() 223 | self.relu5_4 = torch.nn.Sequential() 224 | 225 | for x in range(2): 226 | self.relu1_1.add_module(str(x), features[x]) 227 | 228 | for x in range(2, 4): 229 | self.relu1_2.add_module(str(x), features[x]) 230 | 231 | for x in range(4, 7): 232 | self.relu2_1.add_module(str(x), features[x]) 233 | 234 | for x in range(7, 9): 235 | self.relu2_2.add_module(str(x), features[x]) 236 | 237 | for x in range(9, 12): 238 | self.relu3_1.add_module(str(x), features[x]) 239 | 240 | for x in range(12, 14): 241 | self.relu3_2.add_module(str(x), features[x]) 242 | 243 | for x in range(14, 16): 244 | self.relu3_3.add_module(str(x), features[x]) 245 | 246 | for x in range(16, 18): 247 | self.relu3_4.add_module(str(x), features[x]) 248 | 249 | for x in range(18, 21): 250 | self.relu4_1.add_module(str(x), features[x]) 251 | 252 | for x in range(21, 23): 253 | self.relu4_2.add_module(str(x), features[x]) 254 | 255 | for x in range(23, 25): 256 | self.relu4_3.add_module(str(x), features[x]) 257 | 258 | for x in range(25, 27): 259 | self.relu4_4.add_module(str(x), features[x]) 260 | 261 | for x in range(27, 30): 262 | self.relu5_1.add_module(str(x), features[x]) 263 | 264 | for x in range(30, 32): 265 | self.relu5_2.add_module(str(x), features[x]) 266 | 267 | for x in range(32, 34): 268 | self.relu5_3.add_module(str(x), features[x]) 269 | 270 | for x in range(34, 36): 271 | self.relu5_4.add_module(str(x), features[x]) 272 | 273 | # don't need the gradients, just want the features 274 | # for param in self.parameters(): 275 | # param.requires_grad = False -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DMTN 2 | 3 | 4 | 5 | # 1. 
Resources 6 | [国内资源链接(密码:e2ww)](https://rec.ustc.edu.cn/share/16db9c80-15d5-11ef-9ca4-7b154f4fe8b0) 7 | 8 | ## 1.1 Dataset 9 | * SRD ([github](https://github.com/Liangqiong/DeShadowNet) | [paper](https://openaccess.thecvf.com/content_cvpr_2017/papers/Qu_DeshadowNet_A_Multi-Context_CVPR_2017_paper.pdf)) 10 | * ISTD ([github](https://github.com/DeepInsight-PCALab/ST-CGAN) | [paper](http://openaccess.thecvf.com/content_cvpr_2018/papers/Wang_Stacked_Conditional_Generative_CVPR_2018_paper.pdf)) 11 | * ISTD+DA, SRD+DA ([github](https://github.com/vinthony/ghost-free-shadow-removal) | [paper](https://arxiv.org/abs/1911.08718)) 12 | * [SSRD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EckYI_84wMdJpVgf5EqkmSABmnOD_-53YZ6v2KIQsiLeXA?e=67Gj47) 13 | 14 | The SSRD dataset does not contain the ground truth of shadow-free images due to the presence of self shadow in images. 15 | 16 | ## 1.2 Results | Model Weight 17 | 18 | **TEST RESULTS ON SRD:** 19 | * Results on SRD: [DMTN_SRD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EQt3ZoJAbZ5Cq_mhHxzrUVYBcsiaPLjnsN-SmhotYz-UOg?e=hS2RkJ) | Weight: [DMTN_SRD.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EWc4B9PP-rtGp4LxPWGOkfoB6oi6Coh1tu-qG5qxBk-7Cg?e=sdlSiA) 20 | * Results on SRD: [DMTN+Mask_SRD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EU0JfEPuOUNNlSQDc0TfYgQBOAtMpXjK5yRoa3q2H_bcnQ?e=MgsEu5) | Weight: [DMTN+Mask_SRD.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EU4NJ0CPbwpBrzyXH5FLlMEBzqwKhcXlxe8k4vQiXRrJUw?e=FPKlfF) 21 | 22 | **TEST RESULTS ON ISTD:** 23 | * Results on ISTD: [DMTN_ISTD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EbyzaV72N2FElC5nOsp3-ZYBuUoVLiy29rmXBMXVXXY6Lg?e=wKA55D) | Weight: [DMTN_ISTD.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EROUwnLgz9BGi3OtJa5SIs8BwdgYBZTXeMJ1NLcGfHAwCg?e=v1f51U) 24 | * Results on ISTD: [DMTN+Mask_ISTD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EcRgA4y1UAZIkRIabbm71iIBNhRH-JIugQDbInyWE3rpNQ?e=D8BiGD) | Weight: [DMTN+Mask_ISTD.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EcnDQNKeoRdBtUYjQdirl34BR73n--qRnFIo6RnxPvk-KQ?e=9V0LIR) 25 | * Results on ISTD+DA: [DMTN_ISTD_DA](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/ERGASEyFybBDm9rYZv4a3I4B6FwMmrhZMImk_-b7Lo-YeQ?e=MbzrMk) | Weight: [DMTN_ISTD_DA.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EWwqrUr7Qh9KugvJ2S5KsdMBYz6aiR-ufiX3kn3zB626lg?e=7QJRra) 26 | 27 | **TEST RESULTS ON ISTD+:** 28 | * Results on ISTD+: [DMTN_ISTD+](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EZnB81g7L3VPuGo2zhVclVEBPhsO6MBJYPtbOnqxmEDHuw?e=MZLmUM) | Weight: [DMTN_ISTD+.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/ETVno1MtDsdLknqDNKq60VwB9Bq-oq8kZ8B8aiwQBZXbQQ?e=B0S37N) 29 | * Results on ISTD+: [DMTN+Mask_ISTD+](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EZEQr_hD7XdGgPiesl0L8aABSugt0z5U6V9q2Wv-fEr-VA?e=zq5A7s) | Weight: [DMTN+Mask_ISTD+.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/ETo6UMeCGNhFjJ20o0RedaQBG7XIDcfbqucJ3A-hK6IQAQ?e=uN6sTs) 30 | 31 | **TEST RESULTS ON SSRD:** (DHAN and DMTN are pretrained on SRD dataset 
(size:420x320)) 32 | 33 | * Results of DMTN on SSRD: [DMTN_SSRD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/ET7vtW6b-RNFiK7hJe8coXoBjMMUj2vZ4nEj1SitH8wuKA?e=ZDnfYV) | Weight: [DMTN_SRD_420_320.pth](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EQgZbEFJCLZGiAM8rnbE-ZUBHXw3zyTrhdet7JDSCrYiuA?e=6PcofV) 34 | * Results of DHAN on SSRD:[DHAN_SSRD](https://mailustceducn-my.sharepoint.com/:u:/g/personal/nachifur_mail_ustc_edu_cn/EROyGJwa2C5JkO1bLDVV_AsBbXRPKoZbBy5EsjAsz6xujg?e=nw53O6) 35 | 36 | 37 | ## 1.3 Visual results 38 | 39 | 40 | *Visual comparison results of **penumbra** removal on the SRD dataset - (Powered by [MulimgViewer](https://github.com/nachifur/MulimgViewer))* 41 | 42 | 43 | 44 | *Visual comparison results of **self shadow** removal on the SSRD dataset - (Powered by [MulimgViewer](https://github.com/nachifur/MulimgViewer))* 45 | 46 | 47 | 48 | 49 | ## 1.4 Evaluation Code 50 | Currently, MATLAB evaluation codes are used in most state-of-the-art works for shadow removal. 51 | 52 | [Our evaluation code](https://github.com/nachifur/DMTN/blob/main/evaluation/evaluate_MAE_PSNR_SSIM.m) (i.e., 1+2) 53 | 1. MAE (i.e., RMSE in paper): https://github.com/tsingqguo/exposure-fusion-shadow-removal 54 | 2. PSNR+SSIM: https://github.com/zhuyr97/AAAI2022_Unfolding_Network_Shadow_Removal/tree/master/codes 55 | 56 | Notably, there are slight differences between the different evaluation codes. 57 | * [wang_cvpr2018](https://github.com/DeepInsight-PCALab/ST-CGAN), [le_iccv2019](https://github.com/cvlab-stonybrook/SID): no imresize; 58 | * [fu_cvpr2021](https://github.com/tsingqguo/exposure-fusion-shadow-removal): first imresize, then double; 59 | * [zhu_aaai2022](https://github.com/zhuyr97/AAAI2022_Unfolding_Network_Shadow_Removal): first double, then imresize; 60 | * Our evaluation code: MAE->fu_cvpr2021, psnr+ssim->zhu_aaai2022 61 | 62 | # 2. Environments 63 | **ubuntu18.04+cuda10.2+pytorch1.7.1** 64 | 1. create environments 65 | ``` 66 | conda env create -f install.yaml 67 | ``` 68 | 2. activate environments 69 | ``` 70 | conda activate DMTN 71 | ``` 72 | 73 | # 3. Data Processing 74 | For example, generate the dataset list of ISTD: 75 | 1. Download: 76 | * ISTD and SRD 77 | * [USR shadowfree images](https://github.com/xw-hu/Mask-ShadowGAN) 78 | * [Syn. Shadow](https://github.com/vinthony/ghost-free-shadow-removal) 79 | * [SRD shadow mask](https://github.com/vinthony/ghost-free-shadow-removal) 80 | * train_B_ISTD: 81 | ``` 82 | cp -r ISTD_Dataset_arg/train_B ISTD_Dataset_arg/train_B_ISTD 83 | cp -r ISTD_Dataset_arg/train_B SRD_Dataset_arg/train_B_ISTD 84 | ``` 85 | * [VGG19](https://download.pytorch.org/models/vgg19-dcbb9e9d.pth) 86 | ``` 87 | cp vgg19-dcbb9e9d.pth ISTD_Dataset_arg/ 88 | cp vgg19-dcbb9e9d.pth SRD_Dataset_arg/ 89 | ``` 90 | 2. The data folders should be: 91 | ``` 92 | ISTD_Dataset_arg 93 | * train 94 | - train_A # ISTD shadow image 95 | - train_B # ISTD shadow mask 96 | - train_C # ISTD shadowfree image 97 | - shadow_free # USR shadowfree images 98 | - synC # Syn. shadow 99 | - train_B_ISTD # ISTD shadow mask 100 | * test 101 | - test_A # ISTD shadow image 102 | - test_B # ISTD shadow mask 103 | - test_C # ISTD shadowfree image 104 | * vgg19-dcbb9e9d.pth 105 | 106 | SRD_Dataset_arg 107 | * train # renaming the original `Train` folder in `SRD`. 108 | - train_A # SRD shadow image, renaming the original `shadow` folder in `SRD`. 
109 | - train_B # SRD shadow mask 110 | - train_C # SRD shadowfree image, renaming the original `shadow_free` folder in `SRD`. 111 | - shadow_free # USR shadowfree images 112 | - synC # Syn. shadow 113 | - train_B_ISTD # ISTD shadow mask 114 | * test # renaming the original `test_data` folder in `SRD`. 115 | - train_A # SRD shadow image, renaming the original `shadow` folder in `SRD`. 116 | - train_B # SRD shadow mask 117 | - train_C # SRD shadowfree image, renaming the original `shadow_free` folder in `SRD`. 118 | * vgg19-dcbb9e9d.pth 119 | ``` 120 | 3. Edit `generate_flist_istd.py` (replace the path): 121 | 122 | ``` 123 | ISTD_path = "/Your_data_storage_path/ISTD_Dataset_arg" 124 | ``` 125 | 4. Generate the dataset lists. (These already contain ISTD+DA.) 126 | ``` 127 | conda activate DMTN 128 | cd script/ 129 | python generate_flist_istd.py 130 | ``` 131 | 5. Edit `config_ISTD.yml` (replace the path): 132 | ``` 133 | DATA_ROOT: /Your_data_storage_path/ISTD_Dataset_arg 134 | ``` 135 | 136 | # 4. Training+Test+Evaluation 137 | ## 4.1 Training+Test+Evaluation 138 | For example, training+test+evaluation on the ISTD dataset. 139 | ``` 140 | cp config/config_ISTD.yml config.yml 141 | cp config/run_ISTD.py run.py 142 | conda activate DMTN 143 | python run.py 144 | ``` 145 | ## 4.2 Only Test and Evaluation 146 | For example, test+evaluation on the ISTD dataset. 147 | 1. Download the weight file (`DMTN_ISTD.pth`) to `pre_train_model/ISTD` 148 | 2. Copy the files 149 | ``` 150 | cp config/config_ISTD.yml config.yml 151 | cp config/run_ISTD.py run.py 152 | mkdir -p checkpoints/ISTD/ 153 | cp config.yml checkpoints/ISTD/config.yml 154 | cp pre_train_model/ISTD/DMTN_ISTD.pth checkpoints/ISTD/ShadowRemoval.pth 155 | ``` 156 | 157 | 3. Edit `run.py` and comment out the training code, so that it looks like this: 158 | 159 | ``` 160 | # # pre_train (data augmentation) 161 | # MODE = 0 162 | # print('\nmode-'+str(MODE)+': start pre_training(data augmentation)...\n') 163 | # for i in range(1): 164 | # skip_train = init_config(checkpoints_path, MODE=MODE, 165 | # EVAL_INTERVAL_EPOCH=1, EPOCH=[90,i]) 166 | # if not skip_train: 167 | # main(MODE, config_path) 168 | # src_path = Path('./pre_train_model') / \ 169 | # config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth') 170 | # copypth(dest_path, src_path) 171 | 172 | # # train 173 | # MODE = 2 174 | # print('\nmode-'+str(MODE)+': start training...\n') 175 | # for i in range(1): 176 | # skip_train = init_config(checkpoints_path, MODE=MODE, 177 | # EVAL_INTERVAL_EPOCH=0.1, EPOCH=[60,i]) 178 | # if not skip_train: 179 | # main(MODE, config_path) 180 | # src_path = Path('./pre_train_model') / \ 181 | # config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth') 182 | # copypth(dest_path, src_path) 183 | ``` 184 | 4. Run 185 | 186 | ``` 187 | conda activate DMTN 188 | python run.py 189 | ``` 190 | ## 4.3 Show Results 191 | After evaluation, run the following to display the final RMSE: 192 | ``` 193 | python show_eval_result.py 194 | ``` 195 | Output: 196 | ``` 197 | running rmse-shadow: xxx, rmse-non-shadow: xxx, rmse-all: xxx # ISTD 198 | ``` 199 | These numbers come from the Python/PyTorch evaluation, which is only used to monitor training. To reproduce the evaluation results reported in the paper, run the [MATLAB code](#14-evaluation-code). 200 | 201 | ## 4.4 Test on SSRD 202 | 1. Edit `src/network/network_DMTN.py`. Modify this line (https://github.com/nachifur/DMTN/blob/main/src/network/network_DMTN.py#L339): 203 | ``` 204 | SSRD = True 205 | ``` 206 | 2. Test as in section `4.2 Only Test and Evaluation`. 207 |
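Alternatively, a one-line `sed` makes the same edit (a convenience sketch, assuming the flag at the linked line still reads `SSRD=False`, as in the current source): ``` sed -i 's/SSRD=False/SSRD=True/' src/network/network_DMTN.py ``` 208 | # 5.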
Acknowledgements 209 | Part of the code is based upon: 210 | * https://github.com/nachifur/LLPC 211 | * https://github.com/vinthony/ghost-free-shadow-removal 212 | * https://github.com/knazeri/edge-connect 213 | 214 | # 6. Citation 215 | ``` 216 | @ARTICLE{liu2023decoupled, 217 | author={Liu, Jiawei and Wang, Qiang and Fan, Huijie and Li, Wentao and Qu, Liangqiong and Tang, Yandong}, 218 | journal={IEEE Transactions on Multimedia}, 219 | title={A Decoupled Multi-Task Network for Shadow Removal}, 220 | year={2023}, 221 | volume={}, 222 | number={}, 223 | pages={1-14}, 224 | doi={10.1109/TMM.2023.3252271}} 225 | ``` 226 | # 7. Contact 227 | Please contact Jiawei Liu (liujiawei18@mails.ucas.ac.cn) if you have any questions. 228 | 229 | # 8. Revised Errors in the Paper 230 | Sorry! The following errors in the paper have been corrected: 231 | 1. In Section III-C-2)-`Fig. 7 (or Fig. 5(b)) shows...`, "we can achieve feature decoupling, i.e., some channels of F represent shadow images (~~`I_m`~~ `I_s`)". 232 | -------------------------------------------------------------------------------- /src/network/networks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | 7 | from src.image_pool import ImagePool 8 | 9 | from src.loss import AdversarialLoss, PerceptualLoss, StyleLoss 10 | 11 | 12 | class BaseNetwork(nn.Module): 13 | def __init__(self): 14 | super(BaseNetwork, self).__init__() 15 | 16 | def init_weights(self, init_type='normal', gain=0.02): 17 | ''' 18 | initialize network's weights 19 | init_type: normal | xavier | kaiming | orthogonal 20 | ''' 21 | 22 | def init_func(m): 23 | classname = m.__class__.__name__ 24 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 25 | if init_type == 'normal': 26 | nn.init.normal_(m.weight.data, 0.0, gain) 27 | elif init_type == 'xavier': 28 | nn.init.xavier_normal_(m.weight.data, gain=gain) 29 | elif init_type == 'kaiming': 30 | nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 31 | elif init_type == 'orthogonal': 32 | nn.init.orthogonal_(m.weight.data, gain=gain) 33 | 34 | if hasattr(m, 'bias') and m.bias is not None: 35 | nn.init.constant_(m.bias.data, 0.0) 36 | 37 | elif classname.find('BatchNorm2d') != -1: 38 | nn.init.normal_(m.weight.data, 1.0, gain) 39 | nn.init.constant_(m.bias.data, 0.0) 40 | 41 | self.apply(init_func) 42 | 43 | 44 | class Discriminator(BaseNetwork): 45 | def __init__(self, config, in_channels, init_weights=True): 46 | super(Discriminator, self).__init__() 47 | # config 48 | self.config = config 49 | self.use_sigmoid = self.config.GAN_LOSS != 'hinge' 50 | norm = self.config.DIS_NORM 51 | # network 52 | self.conv1 = conv2d_layer( 53 | in_channels, 64, kernel_size=4, stride=2, norm=norm) 54 | self.conv2 = conv2d_layer(64, 128, kernel_size=4, stride=2, norm=norm) 55 | self.conv3 = conv2d_layer(128, 256, kernel_size=4, stride=2, norm=norm) 56 | self.conv4 = conv2d_layer(256, 512, kernel_size=4, stride=1, norm=norm) 57 | self.conv5 = conv2d_layer(512, 1, kernel_size=4, stride=1, norm=norm) 58 | # loss 59 | self.add_module('adversarial_loss', 60 | AdversarialLoss(type=self.config.GAN_LOSS)) 61 | self.add_module('perceptual_loss', 62 | PerceptualLoss()) 63 | self.fake_pool = ImagePool(self.config.POOL_SIZE) 64 | # optimizer 65 | self.optimizer = optim.Adam( 66 | params=self.parameters(), 67 | lr=float(config.LR_D), 68 | betas=(config.BETA1,
config.BETA2) 69 | ) 70 | if init_weights: 71 | self.init_weights() 72 | 73 | def forward(self, x): 74 | conv1 = self.conv1(x) 75 | conv2 = self.conv2(conv1) 76 | conv3 = self.conv3(conv2) 77 | conv4 = self.conv4(conv3) 78 | conv5 = self.conv5(conv4) 79 | 80 | outputs = conv5 81 | if self.use_sigmoid: 82 | outputs = torch.sigmoid(conv5) 83 | 84 | return outputs, [conv1, conv2, conv3, conv4, conv5] 85 | 86 | def cal_loss(self, images, outputs, GT, cat_img=True): 87 | # discriminator loss 88 | loss = [] 89 | dis_loss = 0 90 | if cat_img: 91 | dis_input_real = torch.cat((images, GT), dim=1) 92 | dis_input_fake = torch.cat((images, F.interpolate(self.fake_pool.query(outputs.detach()), outputs.shape[2:], mode="bilinear", align_corners=True)), dim=1) 93 | else: 94 | dis_input_real = GT 95 | dis_input_fake = F.interpolate(self.fake_pool.query(outputs.detach()), outputs.shape[2:], mode="bilinear", align_corners=True) 96 | 97 | dis_real, dis_real_feat = self(dis_input_real) 98 | dis_fake, dis_fake_feat = self(dis_input_fake) 99 | dis_real_loss = self.adversarial_loss(dis_real, True, True) 100 | dis_fake_loss = self.adversarial_loss(dis_fake, False, True) 101 | dis_loss += (dis_real_loss + dis_fake_loss) / 2 102 | loss.append(dis_loss) 103 | 104 | # generator adversarial loss 105 | if cat_img: 106 | gen_input_fake = torch.cat((images, outputs), dim=1) 107 | else: 108 | gen_input_fake = outputs 109 | gen_fake, gen_fake_feat = self(gen_input_fake) 110 | gen_gan_loss = self.adversarial_loss(gen_fake, True, False) 111 | 112 | # generator perceptual loss 113 | gen_perceptual_loss = self.perceptual_loss(outputs, GT) 114 | 115 | loss.append(gen_gan_loss) 116 | loss.append(gen_perceptual_loss) 117 | return loss 118 | 119 | def backward(self, loss): 120 | self.optimizer.zero_grad() 121 | loss.backward() 122 | self.optimizer.step() 123 | 124 | 125 | def conv2d_layer(in_channels, channels, kernel_size=3, stride=1, padding=1, dilation=1, norm="batch", activation_fn="LeakyReLU", conv_mode="none", pad_mode="ReflectionPad2d"): 126 | """ 127 | norm: batch, batch_, spectral, instance, spectral_instance, none 128 | 129 | activation_fn: Sigmoid, ReLU, LeakyReLU, none 130 | 131 | conv_mode: transpose, upsample, none 132 | 133 | pad_mode: ReflectionPad2d, ReplicationPad2d, ZeroPad2d 134 | """ 135 | layer = [] 136 | # padding 137 | if conv_mode == "transpose": 138 | pass 139 | else: 140 | if pad_mode == "ReflectionPad2d": 141 | layer.append(nn.ReflectionPad2d(padding)) 142 | elif pad_mode == "ReplicationPad2d": 143 | layer.append(nn.ReplicationPad2d(padding)) 144 | else: 145 | layer.append(nn.ZeroPad2d(padding)) 146 | padding = 0 147 | 148 | # conv layer 149 | if norm == "spectral" or norm == "spectral_instance": 150 | bias = False 151 | # conv 152 | if conv_mode == "transpose": 153 | conv_ = nn.ConvTranspose2d 154 | elif conv_mode == "upsample": 155 | layer.append(nn.Upsample(mode='bilinear', scale_factor=stride)) 156 | conv_ = nn.Conv2d 157 | else: 158 | conv_ = nn.Conv2d 159 | else: 160 | bias = True 161 | # conv 162 | if conv_mode == "transpose": 163 | layer.append(nn.ConvTranspose2d(in_channels, channels, kernel_size, 164 | bias=bias, stride=stride, padding=padding, dilation=dilation)) 165 | elif conv_mode == "upsample": 166 | layer.append(nn.Upsample(mode='bilinear', scale_factor=stride)) 167 | layer.append(nn.Conv2d(in_channels, channels, kernel_size, 168 | bias=bias, stride=stride, padding=padding, dilation=dilation)) 169 | else: 170 | layer.append(nn.Conv2d(in_channels, channels, kernel_size, 171 | bias=bias,
stride=stride, padding=padding, dilation=dilation)) 172 | 173 | # norm 174 | if norm == "spectral": 175 | layer.append(spectral_norm(conv_(in_channels, channels, kernel_size, 176 | stride=stride, bias=bias, padding=padding, dilation=dilation), True)) 177 | elif norm == "instance": 178 | layer.append(nn.InstanceNorm2d( 179 | channels, affine=True, track_running_stats=False)) 180 | elif norm == "batch": 181 | layer.append(nn.BatchNorm2d( 182 | channels, affine=True, track_running_stats=False)) 183 | elif norm == "spectral_instance": 184 | layer.append(spectral_norm(conv_(in_channels, channels, kernel_size, 185 | stride=stride, bias=bias, padding=padding, dilation=dilation), True)) 186 | layer.append(nn.InstanceNorm2d( 187 | channels, affine=True, track_running_stats=False)) 188 | elif norm == "batch_": 189 | layer.append(BatchNorm_(channels)) 190 | else: 191 | pass 192 | 193 | # activation_fn 194 | if activation_fn == "Sigmoid": 195 | layer.append(nn.Sigmoid()) 196 | elif activation_fn == "ReLU": 197 | layer.append(nn.ReLU(True)) 198 | elif activation_fn == "none": 199 | pass 200 | else: 201 | layer.append(nn.LeakyReLU(0.2,inplace=True)) 202 | 203 | return nn.Sequential(*layer) 204 | 205 | 206 | class BatchNorm_(nn.Module): 207 | def __init__(self, channels): 208 | super(BatchNorm_, self).__init__() 209 | self.w0 = torch.nn.Parameter( 210 | torch.FloatTensor([1.0]), requires_grad=True) 211 | self.w1 = torch.nn.Parameter( 212 | torch.FloatTensor([0.0]), requires_grad=True) 213 | self.BatchNorm2d = nn.BatchNorm2d( 214 | channels, affine=True, track_running_stats=False) 215 | 216 | def forward(self, x): 217 | outputs = self.w0*x+self.w1*self.BatchNorm2d(x) 218 | return outputs 219 | 220 | def avgcov2d_layer(pool_kernel_size, pool_stride, in_channels, channels, conv_kernel_size=3, conv_stride=1, padding=1, dilation=1, norm="batch", activation_fn="LeakyReLU"): 221 | layer = [] 222 | layer.append(nn.AvgPool2d(pool_kernel_size, pool_stride)) 223 | layer.append(conv2d_layer(in_channels, channels, kernel_size=conv_kernel_size, stride=conv_stride, 224 | padding=padding, dilation=dilation, norm=norm, activation_fn=activation_fn)) 225 | return nn.Sequential(*layer) 226 | 227 | 228 | def spectral_norm(module, mode=True): 229 | if mode: 230 | return nn.utils.spectral_norm(module) 231 | 232 | return module 233 | 234 | 235 | def get_encoder(encoder_param, norm): 236 | encoder = [] 237 | index = 0 238 | for param in encoder_param: 239 | if index == 0: 240 | encoder.append(conv2d_layer(param[0], param[1], kernel_size=param[2], stride=param[3], padding=param[4], dilation=1, norm=norm, 241 | activation_fn="ReLU", conv_mode="none", pad_mode="ReflectionPad2d")) 242 | else: 243 | encoder.append(conv2d_layer(param[0], param[1], kernel_size=param[2], stride=param[3], padding=param[4], dilation=1, norm=norm, 244 | activation_fn="ReLU", conv_mode="none", pad_mode="ZeroPad2d")) 245 | index += 1 246 | return encoder 247 | 248 | 249 | def get_middle(middle_param, norm): 250 | blocks = [] 251 | for _ in range(middle_param[0]): 252 | block = ResnetBlock( 253 | middle_param[1], norm) 254 | blocks.append(block) 255 | return blocks 256 | 257 | 258 | def get_decoder(decoder_param, norm, Sigmoid=True): 259 | if Sigmoid: 260 | activation_fn = "Sigmoid" 261 | else: 262 | activation_fn = "none" 263 | decoder = [] 264 | index = 0 265 | for param in decoder_param: 266 | if index == len(decoder_param)-1: 267 | decoder.append(conv2d_layer(param[0], param[1], kernel_size=param[2], stride=param[3], padding=param[4], dilation=1, 
norm="none", 268 | activation_fn=activation_fn, conv_mode="none", pad_mode="ReflectionPad2d")) 269 | else: 270 | decoder.append(conv2d_layer(param[0], param[1], kernel_size=param[2], stride=param[3], padding=param[4], dilation=1, norm=norm, 271 | activation_fn="ReLU", conv_mode="transpose", pad_mode="ZeroPad2d")) 272 | index += 1 273 | return decoder 274 | 275 | 276 | def get_encoder_decoder(in_channels, ResnetBlockNum, Sigmoid=True, norm="batch"): 277 | encoder_param = [ 278 | [in_channels, 64, 7, 1, 3], 279 | [64, 128, 4, 2, 1], 280 | [128, 256, 4, 2, 1]] 281 | encoder = nn.Sequential( 282 | *get_encoder(encoder_param, norm)) 283 | middle_param = [ResnetBlockNum, 256] 284 | middle = nn.Sequential( 285 | *get_middle(middle_param, norm)) 286 | decoder_param = [ 287 | [256, 128, 4, 2, 1], 288 | [128, 64, 4, 2, 1], 289 | [64, 3, 7, 1, 3]] 290 | decoder = nn.Sequential( 291 | *get_decoder(decoder_param, norm, Sigmoid=Sigmoid)) 292 | return nn.Sequential(*[encoder, middle, decoder]) 293 | 294 | 295 | class ResnetBlock(nn.Module): 296 | def __init__(self, dim, norm): 297 | super(ResnetBlock, self).__init__() 298 | self.conv_block = nn.Sequential( 299 | conv2d_layer(dim, dim, kernel_size=3, stride=1, padding=1, dilation=1, norm=norm, 300 | activation_fn="ReLU", conv_mode="none", pad_mode="ReflectionPad2d"), 301 | 302 | conv2d_layer(dim, dim, kernel_size=3, stride=1, padding=2, dilation=2, norm=norm, 303 | activation_fn="ReLU", conv_mode="none", pad_mode="ReflectionPad2d"), 304 | 305 | conv2d_layer(dim, dim, kernel_size=3, stride=1, padding=1, dilation=1, norm=norm, 306 | activation_fn="ReLU", conv_mode="none", pad_mode="ReflectionPad2d"), 307 | ) 308 | 309 | def forward(self, x): 310 | out = x + self.conv_block(x) 311 | return out 312 | -------------------------------------------------------------------------------- /src/network/network_DMTN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from numpy.core.fromnumeric import shape 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | 8 | from src.loss import VGG19 9 | from src.network.networks import Discriminator, avgcov2d_layer, conv2d_layer 10 | from torch.autograd import Variable 11 | from src.utils import solve_factor, imshow 12 | import torchvision.transforms.functional as TF 13 | import scipy.linalg 14 | 15 | 16 | class BaseNetwork(nn.Module): 17 | def __init__(self): 18 | super(BaseNetwork, self).__init__() 19 | 20 | def init_weights(self, init_type='identity', gain=0.02): 21 | ''' 22 | initialize network's weights 23 | init_type: normal | xavier | kaiming | orthogonal 24 | ''' 25 | 26 | def init_func(m): 27 | classname = m.__class__.__name__ 28 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 29 | if init_type == 'normal': 30 | nn.init.normal_(m.weight.data, 0.0, gain) 31 | elif init_type == 'xavier': 32 | nn.init.xavier_normal_(m.weight.data, gain=gain) 33 | elif init_type == 'kaiming': 34 | nn.init.kaiming_normal_( 35 | m.weight.data, a=0, mode='fan_in') 36 | elif init_type == 'orthogonal': 37 | nn.init.orthogonal_(m.weight.data, gain=gain) 38 | elif init_type == "identity": 39 | if isinstance(m, nn.Linear): 40 | nn.init.normal_(m.weight, 0, gain) 41 | else: 42 | identity_initializer(m.weight.data) 43 | 44 | if hasattr(m, 'bias') and m.bias is not None: 45 | nn.init.constant_(m.bias.data, 0.0) 46 | 47 | elif classname.find('BatchNorm2d') != -1: 48 | 
nn.init.normal_(m.weight.data, 1.0, gain) 49 | nn.init.constant_(m.bias.data, 0.0) 50 | 51 | self.apply(init_func) 52 | 53 | 54 | def identity_initializer(data): 55 | shape = data.shape 56 | array = np.zeros(shape, dtype=float) 57 | cx, cy = shape[2]//2, shape[3]//2 58 | for i in range(np.minimum(shape[0], shape[1])): 59 | array[i, i, cx, cy] = 1 60 | return torch.tensor(array, dtype=torch.float32) 61 | 62 | 63 | class DMTN(BaseNetwork): 64 | """DMTN""" 65 | 66 | def __init__(self, config, in_channels=3, init_weights=True): 67 | super(DMTN, self).__init__() 68 | 69 | # gan 70 | channels = 64 71 | stage_num = [12, 2] 72 | self.network = DMTNSOURCE( 73 | in_channels, channels, norm=config.GAN_NORM, stage_num=stage_num) 74 | # gan loss 75 | if config.LOSS == "MSELoss": 76 | self.add_module('loss', nn.MSELoss(reduction="mean")) 77 | elif config.LOSS == "L1Loss": 78 | self.add_module('loss', nn.L1Loss(reduction="mean")) 79 | 80 | # gan optimizer 81 | self.optimizer = optim.Adam( 82 | params=self.network.parameters(), 83 | lr=float(config.LR), 84 | betas=(config.BETA1, config.BETA2) 85 | ) 86 | 87 | # dis 88 | self.ADV = config.ADV 89 | if self.ADV: 90 | discriminator = [] 91 | discriminator.append(Discriminator(config, in_channels=6)) 92 | self.discriminator = nn.Sequential(*discriminator) 93 | 94 | if init_weights: 95 | self.init_weights(config.INIT_TYPE) 96 | 97 | def process(self, images, mask, GT): 98 | loss = [] 99 | logs = [] 100 | inputs = images 101 | img, net_matte, outputs = self( 102 | inputs) 103 | 104 | matte_gt = GT-inputs 105 | matte_gt = matte_gt - \ 106 | (matte_gt.min(dim=2, keepdim=True).values).min( 107 | dim=3, keepdim=True).values 108 | matte_gt = matte_gt / \ 109 | (matte_gt.max(dim=2, keepdim=True).values).max( 110 | dim=3, keepdim=True).values 111 | match_loss_1 = 0 112 | match_loss_2 = 0 113 | matte_loss = 0 114 | 115 | match_loss_1 += self.cal_loss(outputs, GT)*255*10 116 | match_loss_2 += self.cal_loss(img, images)*255*10 117 | matte_loss += self.cal_loss(net_matte, matte_gt)*255 118 | 119 | if self.ADV: 120 | dis_loss_1, gen_gan_loss_1, perceptual_loss_1 = self.discriminator[0].cal_loss( 121 | images, outputs, GT) 122 | 123 | perceptual_loss_1 = perceptual_loss_1*1000 124 | 125 | gen_loss = perceptual_loss_1 + match_loss_1+match_loss_2 +\ 126 | matte_loss+gen_gan_loss_1 127 | 128 | loss.append(gen_loss) 129 | loss.append(dis_loss_1) 130 | 131 | logs.append(("l_match1", match_loss_1.item())) 132 | logs.append(("l_match2", match_loss_2.item())) 133 | logs.append(("l_matte", matte_loss.item())) 134 | logs.append(("l_perceptual_1", perceptual_loss_1.item())) 135 | logs.append(("l_adv1", gen_gan_loss_1.item())) 136 | logs.append(("l_gen", gen_loss.item())) 137 | logs.append(("l_dis1", dis_loss_1.item())) 138 | else: 139 | gen_loss = match_loss_1 + match_loss_2 + matte_loss 140 | gen_loss = gen_loss 141 | loss.append(gen_loss) 142 | 143 | logs.append(("l_match1", match_loss_1.item())) 144 | logs.append(("l_match2", match_loss_2.item())) 145 | logs.append(("l_matte", matte_loss.item())) 146 | logs.append(("l_gen", gen_loss.item())) 147 | 148 | return [net_matte, outputs], loss, logs 149 | 150 | def forward(self, x): 151 | outputs = self.network(x) 152 | return outputs 153 | 154 | def cal_loss(self, outputs, GT): 155 | matching_loss = self.loss(outputs, GT) 156 | return matching_loss 157 | 158 | def backward(self, loss): 159 | self.optimizer.zero_grad() 160 | loss[0].backward() 161 | self.optimizer.step() 162 | if self.ADV: 163 | i = 0 164 | for discriminator in 
self.discriminator: 165 | discriminator.backward(loss[1+i]) 166 | i += 1 167 | 168 | 169 | class FeatureDecouplingModule(nn.Module): 170 | '''Shadow feature decoupling module''' 171 | 172 | def __init__(self, in_channels=64, channels=3): 173 | super(FeatureDecouplingModule, self).__init__() 174 | kernel_size = 1 175 | 176 | w = torch.randn(channels, in_channels, kernel_size, kernel_size) 177 | self.w0 = torch.nn.Parameter(torch.FloatTensor( 178 | self.normalize_to_0_1(w)), requires_grad=True) 179 | w = torch.randn(channels, in_channels, kernel_size, kernel_size) 180 | self.w1 = torch.nn.Parameter(torch.FloatTensor( 181 | self.normalize_to_0_1(w)), requires_grad=True) 182 | w = torch.zeros(channels, in_channels, kernel_size, kernel_size) 183 | self.w2 = torch.nn.Parameter(torch.FloatTensor( 184 | w), requires_grad=True) 185 | 186 | self.bias_proportion = torch.nn.Parameter(torch.zeros( 187 | (channels, 1, 1, 1)), requires_grad=True) 188 | 189 | self.alpha_0 = torch.nn.Parameter(torch.ones( 190 | (1, channels, 1, 1)), requires_grad=True) 191 | self.alpha_1 = torch.nn.Parameter(torch.ones( 192 | (1, channels, 1, 1)), requires_grad=True) 193 | self.alpha_2 = torch.nn.Parameter(torch.ones( 194 | (1, channels, 1, 1)), requires_grad=True) 195 | 196 | self.bias_0 = torch.nn.Parameter(torch.zeros( 197 | (1, channels, 1, 1)), requires_grad=True) 198 | self.bias_1 = torch.nn.Parameter(torch.zeros( 199 | (1, channels, 1, 1)), requires_grad=True) 200 | self.bias_2 = torch.nn.Parameter(torch.zeros( 201 | (1, channels, 1, 1)), requires_grad=True) 202 | 203 | def forward(self, x): 204 | w0 = self.w0 205 | w1 = self.w1 206 | o, c, k_w, k_h = w0.shape 207 | 208 | w0_attention = 1+self.w2 209 | w1_attention = 1-self.w2 210 | 211 | w = w0*w0_attention 212 | median_w = torch.median(w, dim=1, keepdim=True) 213 | w0_correct = F.relu(w-median_w.values+self.bias_proportion) 214 | 215 | w0_correct = self.normalize_to_0_1(w0_correct) 216 | 217 | w = w1*w1_attention 218 | median_w = torch.median(w, dim=1, keepdim=True) 219 | w1_correct = F.relu(w-median_w.values-self.bias_proportion) 220 | w1_correct = self.normalize_to_0_1(w1_correct) 221 | 222 | w2_correct = w0_correct+w1_correct 223 | 224 | img = torch.sigmoid(self.alpha_0*F.conv2d(x, w0_correct)+self.bias_0) 225 | matte = torch.sigmoid(self.alpha_1*F.conv2d(x, w1_correct)+self.bias_1) 226 | img_free = torch.sigmoid( 227 | self.alpha_2*F.conv2d(x, w2_correct)+self.bias_2) 228 | 229 | return img, matte, img_free 230 | 231 | def normalize_to_0_1(self, w): 232 | w = w-w.min() 233 | w = w/w.max() 234 | return w 235 | 236 | 237 | class DMTNSOURCE(nn.Module): 238 | def __init__(self, in_channels=3, channels=64, norm="batch", stage_num=[6, 4]): 239 | super(DMTNSOURCE, self).__init__() 240 | self.stage_num = stage_num 241 | 242 | # Pre-trained VGG19 243 | self.add_module('vgg19', VGG19()) 244 | 245 | # SE 246 | cat_channels = in_channels+64+128+256+512+512 247 | self.se = nn.Sequential(SELayer(cat_channels), 248 | conv2d_layer(cat_channels, channels, kernel_size=3, padding=1, dilation=1, norm=norm)) 249 | 250 | self.down_sample = conv2d_layer( 251 | channels, 2*channels, kernel_size=4, stride=2, padding=1, dilation=1, norm=norm) 252 | 253 | # coarse 254 | coarse_list = [] 255 | for i in range(self.stage_num[0]): 256 | coarse_list.append(SemiConvModule( 257 | 2*channels, norm, mid_dilation=2**(i % 6))) 258 | self.coarse_list = nn.Sequential(*coarse_list) 259 | 260 | self.up_conv = conv2d_layer( 261 | 2*channels, channels, kernel_size=3, stride=1, padding=1, dilation=1, 
norm=norm) 262 | 263 | # fine 264 | fine_list = [] 265 | for i in range(self.stage_num[1]): 266 | fine_list.append(SemiConvModule( 267 | channels, norm, mid_dilation=2**(i % 6))) 268 | self.fine_list = nn.Sequential(*fine_list) 269 | 270 | self.se_coarse = nn.Sequential(SELayer(2*channels), 271 | conv2d_layer(2*channels, channels, kernel_size=3, padding=1, dilation=1, norm=norm)) 272 | 273 | # SPP 274 | self.spp = SPP(channels, norm=norm) 275 | 276 | # Shadow feature decoupling module' 277 | self.FDM = FeatureDecouplingModule(in_channels=channels, channels=3) 278 | 279 | def forward(self, x): 280 | size = (x.shape[2], x.shape[3]) 281 | 282 | # vgg 283 | x_vgg = self.vgg19(x) 284 | 285 | # hyper-column features 286 | x_cat = torch.cat(( 287 | x, 288 | F.interpolate(x_vgg['relu1_2'], size, 289 | mode="bilinear", align_corners=True), 290 | F.interpolate(x_vgg['relu2_2'], size, 291 | mode="bilinear", align_corners=True), 292 | F.interpolate(x_vgg['relu3_2'], size, 293 | mode="bilinear", align_corners=True), 294 | F.interpolate(x_vgg['relu4_2'], size, 295 | mode="bilinear", align_corners=True), 296 | F.interpolate(x_vgg['relu5_2'], size, mode="bilinear", align_corners=True)), dim=1) 297 | 298 | # SE 299 | x = self.se(x_cat) 300 | 301 | # coarse 302 | x_ = x 303 | x = self.down_sample(x) 304 | for i in range(self.stage_num[0]): 305 | x = self.coarse_list[i](x) 306 | 307 | size = (x_.shape[2], x_.shape[3]) 308 | x = F.interpolate(x, size, mode="bilinear", align_corners=True) 309 | x = self.up_conv(x) 310 | 311 | # fine 312 | x = self.se_coarse(torch.cat((x_, x), dim=1)) 313 | for i in range(self.stage_num[1]): 314 | x = self.fine_list[i](x) 315 | 316 | # spp 317 | x = self.spp(x) 318 | 319 | # output 320 | img, matte_out, img_free = self.FDM(x) 321 | 322 | return [img, matte_out, img_free] 323 | 324 | 325 | class SemiConvModule(nn.Module): 326 | def __init__(self, channels=64, norm="batch", mid_dilation=2): 327 | super(SemiConvModule, self).__init__() 328 | list_factor = solve_factor(channels) 329 | self.group = list_factor[int(len(list_factor)/2)-1] 330 | self.split_channels = int(channels/2) 331 | 332 | # Conv 333 | self.conv_dilation = conv2d_layer( 334 | self.split_channels, self.split_channels, kernel_size=3, padding=mid_dilation, dilation=mid_dilation, norm=norm) 335 | self.conv_3x3 = conv2d_layer( 336 | self.split_channels, self.split_channels, kernel_size=3, padding=1, dilation=1, norm=norm) 337 | 338 | def forward(self, x): 339 | SSRD=False 340 | if SSRD: 341 | x_conv = x[:, self.split_channels:, :, :] 342 | x_identity = x[:, 0:self.split_channels, :, :] 343 | else: 344 | x_conv = x[:, 0:self.split_channels, :, :] 345 | x_identity = x[:, self.split_channels:, :, :] 346 | 347 | x_conv = x_conv+self.conv_dilation(x_conv)+self.conv_3x3(x_conv) 348 | 349 | x = torch.cat((x_identity, x_conv), dim=1) 350 | x = self.channel_shuffle(x) 351 | return x 352 | 353 | def channel_shuffle(self, x): 354 | batchsize, num_channels, height, width = x.data.size() 355 | assert num_channels % self.group == 0 356 | group_channels = num_channels // self.group 357 | 358 | x = x.reshape(batchsize, group_channels, self.group, height, width) 359 | x = x.permute(0, 2, 1, 3, 4) 360 | x = x.reshape(batchsize, num_channels, height, width) 361 | 362 | return x 363 | 364 | def identity(self, x): 365 | return x 366 | 367 | 368 | class SPP(nn.Module): 369 | def __init__(self, channels=64, norm="batch"): 370 | super(SPP, self).__init__() 371 | self.net2 = avgcov2d_layer( 372 | 4, 4, channels, channels, 1, padding=0, 
norm=norm) 373 | self.net8 = avgcov2d_layer( 374 | 8, 8, channels, channels, 1, padding=0, norm=norm) 375 | self.net16 = avgcov2d_layer( 376 | 16, 16, channels, channels, 1, padding=0, norm=norm) 377 | self.net32 = avgcov2d_layer( 378 | 32, 32, channels, channels, 1, padding=0, norm=norm) 379 | self.output = conv2d_layer(channels*5, channels, 3, norm=norm) 380 | 381 | def forward(self, x): 382 | size = (x.shape[2], x.shape[3]) 383 | x = torch.cat(( 384 | F.interpolate(self.net2(x), size, mode="bilinear", 385 | align_corners=True), 386 | F.interpolate(self.net8(x), size, mode="bilinear", 387 | align_corners=True), 388 | F.interpolate(self.net16(x), size, 389 | mode="bilinear", align_corners=True), 390 | F.interpolate(self.net32(x), size, 391 | mode="bilinear", align_corners=True), 392 | x), dim=1) 393 | x = self.output(x) 394 | return x 395 | 396 | 397 | class SELayer(nn.Module): 398 | def __init__(self, channel, reduction=8): 399 | super(SELayer, self).__init__() 400 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 401 | self.fc = nn.Sequential( 402 | nn.Linear(channel, channel // reduction, bias=True), 403 | nn.ReLU(inplace=True), 404 | nn.Linear(channel // reduction, channel, bias=True), 405 | nn.Sigmoid() 406 | ) 407 | 408 | def forward(self, x): 409 | b, c, _, _ = x.size() 410 | y = self.avg_pool(x).view(b, c) 411 | y = self.fc(y).view(b, c, 1, 1) 412 | return x * y 413 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import sys 5 | import time 6 | from pathlib import Path 7 | from shutil import copyfile 8 | 9 | import cv2 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | import scipy.stats as st 13 | import torch 14 | import torch.nn.functional as F 15 | import yaml 16 | from PIL import Image 17 | 18 | def solve_factor(num): 19 | # solve factors for a number 20 | list_factor = [] 21 | i = 1 22 | if num > 2: 23 | while i <= num: 24 | i += 1 25 | if num % i == 0: 26 | list_factor.append(i) 27 | else: 28 | pass 29 | else: 30 | pass 31 | 32 | list_factor = list(set(list_factor)) 33 | list_factor = np.sort(list_factor) 34 | return list_factor 35 | 36 | def copypth(dest_path, src_path): 37 | if (src_path).is_file(): 38 | copyfile(src_path, dest_path) 39 | print(str(src_path)+" copy to "+str(dest_path)) 40 | 41 | 42 | def gauss_kernel(kernlen=21, nsig=3, channels=1): 43 | # https://github.com/dojure/FPIE/blob/master/utils.py 44 | interval = (2 * nsig + 1.) 
/ (kernlen) 45 | x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1) 46 | kern1d = np.diff(st.norm.cdf(x)) 47 | kernel_raw = np.sqrt(np.outer(kern1d, kern1d)) 48 | kernel = kernel_raw / kernel_raw.sum() 49 | out_filter = np.array(kernel, dtype=np.float32) 50 | out_filter = out_filter.reshape((1, 1, kernlen, kernlen)) 51 | out_filter = np.repeat(out_filter, channels, axis=1) 52 | return out_filter 53 | 54 | 55 | def blur(x, kernel_var): 56 | return F.conv2d(x, kernel_var, padding=1)  # note: padding=1 only preserves spatial size for 3x3 kernels 57 | 58 | 59 | def sobel_kernel(kernlen=3, channels=1): 60 | out_filter = np.array( 61 | [[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=np.float32) 62 | out_filter = out_filter.reshape((1, 1, kernlen, kernlen)) 63 | out_filter_x = np.repeat(out_filter, channels, axis=1) 64 | 65 | out_filter = np.array( 66 | [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=np.float32) 67 | out_filter = out_filter.reshape((1, 1, kernlen, kernlen)) 68 | out_filter_y = np.repeat(out_filter, channels, axis=1) 69 | return [out_filter_x, out_filter_y] 70 | 71 | 72 | def sobel(x, kernel_var): 73 | sobel_kernel_x = kernel_var[0] 74 | sobel_kernel_y = kernel_var[1] 75 | sobel_x = F.conv2d(x, sobel_kernel_x, padding=1) 76 | sobel_y = F.conv2d(x, sobel_kernel_y, padding=1) 77 | 78 | return sobel_x.abs()+sobel_y.abs() 79 | 80 | 81 | def create_dir(dir): 82 | if not os.path.exists(dir): 83 | os.makedirs(dir) 84 | 85 | 86 | def create_config(config_path, cover=False): 87 | if cover: 88 | copyfile('./config.yml', config_path) 89 | else: 90 | if not os.path.exists(config_path): 91 | copyfile('./config.yml', config_path) 92 | 93 | 94 | def init_config(checkpoints_path, MODE=0, EVAL_INTERVAL_EPOCH=1, EPOCH=[30, 0]): 95 | if EVAL_INTERVAL_EPOCH < 1: 96 | APPEND = 0 97 | else: 98 | APPEND = 1 99 | if len(EPOCH) > 2: 100 | lr_restart = True 101 | else: 102 | lr_restart = False 103 | # Re-training after training is interrupted abnormally 104 | skip_train = restart_train( 105 | checkpoints_path, MODE, APPEND, EPOCH, lr_restart) 106 | if skip_train: 107 | return skip_train 108 | 109 | # edit config 110 | config_path = os.path.join(checkpoints_path, 'config.yml') 111 | fr = open(config_path, 'r') 112 | config = yaml.load(fr, Loader=yaml.FullLoader) 113 | fr.close() 114 | 115 | EPOCHLIST = EPOCH 116 | ALL_EPOCH = EPOCH[-2] 117 | EPOCH = EPOCH[EPOCH[-1]] 118 | 119 | if MODE == 0 or MODE == 1: 120 | flist = config['TRAIN_FLIST_PRE'] 121 | else: 122 | flist = config['TRAIN_FLIST'] 123 | TRAIN_DATA_NUM = len(np.genfromtxt( 124 | config["DATA_ROOT"]+flist, dtype=str, encoding='utf-8')) 125 | print('train data number is: {}'.format(TRAIN_DATA_NUM)) 126 | 127 | if torch.cuda.is_available(): 128 | BATCH_SIZE = config['BATCH_SIZE'] 129 | MAX_ITERS = EPOCH * (TRAIN_DATA_NUM // BATCH_SIZE) # drop_last=True 130 | if config["DEBUG"]: 131 | INTERVAL = 10 132 | config['MAX_ITERS'] = 80 133 | config['EVAL_INTERVAL'] = INTERVAL 134 | config['DEBUG'] = 1 135 | config['SAMPLE_INTERVAL'] = INTERVAL 136 | config['SAVE_INTERVAL'] = INTERVAL 137 | else: 138 | INTERVAL = ((EVAL_INTERVAL_EPOCH * (TRAIN_DATA_NUM // BATCH_SIZE))//10)*10  # round down to a multiple of 10 iterations 139 | config['MAX_ITERS'] = MAX_ITERS 140 | config['EVAL_INTERVAL'] = INTERVAL 141 | config['DEBUG'] = 0 142 | config['SAMPLE_INTERVAL'] = 500 143 | config['SAVE_INTERVAL'] = INTERVAL 144 | 145 | if EVAL_INTERVAL_EPOCH < 1: 146 | config['FORCE_EXIT'] = 0 147 | else: 148 | config['FORCE_EXIT'] = 1 149 | 150 | config["MODE"] = MODE 151 | config['ALL_EPOCH'] = ALL_EPOCH 152 | config['EPOCH'] = EPOCH 153 | config['EPOCHLIST'] = EPOCHLIST
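# APPEND == 1 (eval interval of at least one epoch) lets restart_train resume from an existing eval log instead of overwriting the stage config.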
154 | config['APPEND'] = APPEND 155 | config['EVAL_INTERVAL_EPOCH'] = EVAL_INTERVAL_EPOCH 156 | save_config(config, config_path) 157 | else: 158 | print("cuda is unavailable") 159 | 160 | return skip_train 161 | 162 | 163 | def restart_train(checkpoints_path, MODE, APPEND, EPOCH, lr_restart): 164 | config_path = os.path.join(checkpoints_path, 'config.yml') 165 | fr = open('./config.yml', 'r') 166 | config = yaml.load(fr, Loader=yaml.FullLoader) 167 | fr.close() 168 | 169 | if MODE == 0: 170 | src_path = Path('./pre_train_model') / \ 171 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_da.pth') 172 | elif MODE == 1: 173 | src_path = Path('./pre_train_model') / \ 174 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_pre_no_da.pth') 175 | elif MODE == 2: 176 | src_path = Path('./pre_train_model') / \ 177 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'_final.pth') 178 | 179 | if lr_restart: 180 | src_path = Path(checkpoints_path)/("model_save_mode_" + 181 | str(MODE))/(str(EPOCH[EPOCH[-1]]-1)+".0.pth") 182 | 183 | log_eval_val_ap_path = Path(checkpoints_path) / \ 184 | ('log_eval_val_ap_'+str(MODE)+'.txt') 185 | 186 | 187 | 188 | if Path(src_path).is_file(): 189 | skip_train = True 190 | cover = False # retrain 191 | else: 192 | skip_train = False 193 | if APPEND == 1 and log_eval_val_ap_path.is_file(): 194 | cover = False 195 | else: 196 | cover = True # new train stage 197 | 198 | # append 199 | if (not Path(src_path).is_file()) and APPEND == 1 and log_eval_val_ap_path.is_file(): 200 | eval_val_ap = np.genfromtxt( 201 | log_eval_val_ap_path, dtype=str, encoding='utf-8').astype(float) 202 | src_path = Path(checkpoints_path)/("model_save_mode_" + 203 | str(MODE))/(str(eval_val_ap[0, 1]) + '.pth') 204 | 205 | if eval_val_ap[0, 1] == EPOCH[EPOCH[-1]-1]-1: 206 | cover = True 207 | 208 | # copy .pth 209 | dest_path = Path('checkpoints/') / \ 210 | config["SUBJECT_WORD"]/(config["MODEL_NAME"]+'.pth') 211 | copypth(dest_path, src_path) 212 | 213 | # create config 214 | create_config(config_path, cover=cover) 215 | print("cover config file-"+str(cover)) 216 | 217 | if skip_train: 218 | print("skip train stage of mode-"+str(MODE)) 219 | return skip_train 220 | 221 | 222 | def save_config(config, config_path): 223 | with open(config_path, 'w') as f_obj: 224 | yaml.dump(config, f_obj) 225 | 226 | 227 | def stitch_images(*outputs, img_per_row=2): 228 | inputs_all = [*outputs] 229 | inputs = inputs_all[0] 230 | gap = 5 231 | images = [inputs_all[0], *inputs_all[1], *inputs_all[2:]] 232 | 233 | columns = len(images) 234 | 235 | height, width = inputs[0][:, :, 0].shape 236 | img = Image.new('RGB', (width * img_per_row * columns + gap * 237 | (img_per_row - 1), height * int(len(inputs) / img_per_row))) 238 | 239 | for ix in range(len(inputs)): 240 | xoffset = int(ix % img_per_row) * width * \ 241 | columns + int(ix % img_per_row) * gap 242 | yoffset = int(ix / img_per_row) * height 243 | 244 | for cat in range(len(images)): 245 | im = np.array((images[cat][ix]).cpu()).astype(np.uint8).squeeze() 246 | im = Image.fromarray(im) 247 | img.paste(im, (xoffset + cat * width, yoffset)) 248 | 249 | return img 250 | 251 | 252 | def imshow(img, title=''): 253 | fig = plt.gcf() 254 | fig.canvas.set_window_title(title) 255 | plt.axis('off') 256 | if len(img.size) == 3: 257 | plt.imshow(img, interpolation='none') 258 | plt.show() 259 | else: 260 | plt.imshow(img, cmap='Greys_r') 261 | plt.show() 262 | 263 | 264 | def imsave(img, path): 265 | im = Image.fromarray(img.cpu().numpy().astype(np.uint8).squeeze()) 266 |
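# PIL infers the output format (PNG, JPEG, ...) from the file extension of path.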
im.save(path) 267 | 268 | 269 | def create_mask(width, height, mask_width, mask_height, x=None, y=None): 270 | # # https://github.com/knazeri/edge-connect 271 | mask = np.zeros((height, width)) 272 | mask_x = x if x is not None else random.randint(0, width - mask_width) 273 | mask_y = y if y is not None else random.randint(0, height - mask_height) 274 | mask[mask_y:mask_y + mask_height, mask_x:mask_x + mask_width] = 1 275 | return mask 276 | 277 | 278 | class Progbar(object): 279 | # https://github.com/knazeri/edge-connect 280 | """Displays a progress bar. 281 | 282 | Arguments: 283 | target: Total number of steps expected, None if unknown. 284 | width: Progress bar width on screen. 285 | verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) 286 | stateful_metrics: Iterable of string names of metrics that 287 | should *not* be averaged over time. Metrics in this list 288 | will be displayed as-is. All others will be averaged 289 | by the progbar before display. 290 | interval: Minimum visual progress update interval (in seconds). 291 | """ 292 | 293 | def __init__(self, target, width=25, verbose=1, interval=0.05, 294 | stateful_metrics=None): 295 | self.target = target 296 | self.width = width 297 | self.verbose = verbose 298 | self.interval = interval 299 | if stateful_metrics: 300 | self.stateful_metrics = set(stateful_metrics) 301 | else: 302 | self.stateful_metrics = set() 303 | 304 | self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and 305 | sys.stdout.isatty()) or 306 | 'ipykernel' in sys.modules or 307 | 'posix' in sys.modules) 308 | self._total_width = 0 309 | self._seen_so_far = 0 310 | # We use a dict + list to avoid garbage collection 311 | # issues found in OrderedDict 312 | self._values = {} 313 | self._values_order = [] 314 | self._start = time.time() 315 | self._last_update = 0 316 | 317 | def update(self, current, values=None): 318 | """Updates the progress bar. 319 | 320 | Arguments: 321 | current: Index of current step. 322 | values: List of tuples: 323 | `(name, value_for_last_step)`. 324 | If `name` is in `stateful_metrics`, 325 | `value_for_last_step` will be displayed as-is. 326 | Else, an average of the metric over time will be displayed. 
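Example: `bar = Progbar(target=100)`, then `bar.add(10, values=[('loss', 0.25)])` advances ten steps and folds the loss into a running average.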
327 | """ 328 | values = values or [] 329 | for k, v in values: 330 | if k not in self._values_order: 331 | self._values_order.append(k) 332 | if k not in self.stateful_metrics: 333 | if k not in self._values: 334 | self._values[k] = [v * (current - self._seen_so_far), 335 | current - self._seen_so_far] 336 | else: 337 | self._values[k][0] += v * (current - self._seen_so_far) 338 | self._values[k][1] += (current - self._seen_so_far) 339 | else: 340 | self._values[k] = v 341 | self._seen_so_far = current 342 | 343 | now = time.time() 344 | info = ' - %.0fs' % (now - self._start) 345 | if self.verbose == 1: 346 | if (now - self._last_update < self.interval and 347 | self.target is not None and current < self.target): 348 | return 349 | 350 | prev_total_width = self._total_width 351 | if self._dynamic_display: 352 | sys.stdout.write('\b' * prev_total_width) 353 | sys.stdout.write('\r') 354 | else: 355 | sys.stdout.write('\n') 356 | 357 | if self.target is not None: 358 | numdigits = int(np.floor(np.log10(self.target))) + 1 359 | barstr = '%%%dd/%d [' % (numdigits, self.target) 360 | bar = barstr % current 361 | prog = float(current) / self.target 362 | prog_width = int(self.width * prog) 363 | if prog_width > 0: 364 | bar += ('=' * (prog_width - 1)) 365 | if current < self.target: 366 | bar += '>' 367 | else: 368 | bar += '=' 369 | bar += ('.' * (self.width - prog_width)) 370 | bar += ']' 371 | else: 372 | bar = '%7d/Unknown' % current 373 | 374 | self._total_width = len(bar) 375 | sys.stdout.write(bar) 376 | 377 | if current: 378 | time_per_unit = (now - self._start) / current 379 | else: 380 | time_per_unit = 0 381 | if self.target is not None and current < self.target: 382 | eta = time_per_unit * (self.target - current) 383 | if eta > 3600: 384 | eta_format = '%d:%02d:%02d' % (eta // 3600, 385 | (eta % 3600) // 60, 386 | eta % 60) 387 | elif eta > 60: 388 | eta_format = '%d:%02d' % (eta // 60, eta % 60) 389 | else: 390 | eta_format = '%ds' % eta 391 | 392 | info = ' - ETA: %s' % eta_format 393 | else: 394 | if time_per_unit >= 1: 395 | info += ' %.0fs/step' % time_per_unit 396 | elif time_per_unit >= 1e-3: 397 | info += ' %.0fms/step' % (time_per_unit * 1e3) 398 | else: 399 | info += ' %.0fus/step' % (time_per_unit * 1e6) 400 | 401 | for k in self._values_order: 402 | info += ' - %s:' % k 403 | if isinstance(self._values[k], list): 404 | avg = np.mean( 405 | self._values[k][0] / max(1, self._values[k][1])) 406 | if abs(avg) > 1e-3: 407 | info += ' %.4f' % avg 408 | else: 409 | info += ' %.4e' % avg 410 | else: 411 | info += ' %s' % self._values[k] 412 | 413 | self._total_width += len(info) 414 | if prev_total_width > self._total_width: 415 | info += (' ' * (prev_total_width - self._total_width)) 416 | 417 | if self.target is not None and current >= self.target: 418 | info += '\n' 419 | 420 | sys.stdout.write(info) 421 | sys.stdout.flush() 422 | 423 | elif self.verbose == 2: 424 | if self.target is None or current >= self.target: 425 | for k in self._values_order: 426 | info += ' - %s:' % k 427 | avg = np.mean( 428 | self._values[k][0] / max(1, self._values[k][1])) 429 | if avg > 1e-3: 430 | info += ' %.4f' % avg 431 | else: 432 | info += ' %.4e' % avg 433 | info += '\n' 434 | 435 | sys.stdout.write(info) 436 | sys.stdout.flush() 437 | 438 | self._last_update = now 439 | 440 | def add(self, n, values=None): 441 | self.update(self._seen_so_far + n, values) 442 | -------------------------------------------------------------------------------- /src/model_top.py: 
--------------------------------------------------------------------------------
/src/model_top.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from shutil import copyfile
4 | 
5 | import numpy as np
6 | import torch
7 | import torchvision.transforms.functional as F
8 | import yaml
9 | from torch.utils.data import DataLoader
10 | 
11 | from .dataset import Dataset
12 | from .metrics import Metrics
13 | from .models import Model
14 | from .utils import (Progbar, create_dir, imsave, imshow, save_config,
15 |                     stitch_images)
16 | 
17 | 
18 | class ModelTop():
19 |     def __init__(self, config):
20 |         # config
21 |         self.config = config
22 |         if config.DEBUG == 1:
23 |             self.debug = True
24 |         else:
25 |             self.debug = False
26 |         self.model_name = config.MODEL_NAME
27 |         self.RESULTS_SAMPLE = self.config.RESULTS_SAMPLE
28 |         # model
29 |         self.model = Model(config).to(config.DEVICE)
30 |         # eval
31 |         self.metrics = Metrics().to(config.DEVICE)
32 |         # dataset
33 |         if config.MODE == 3:  # test
34 |             self.test_dataset = Dataset(
35 |                 config, config.DATA_ROOT+config.TEST_FLIST, config.DATA_ROOT+config.TEST_MASK_FLIST, config.DATA_ROOT+config.TEST_GT_FLIST, augment=False)
36 |         elif config.MODE == 4:  # eval on the validation split
37 |             self.val_dataset = Dataset(
38 |                 config, config.DATA_ROOT+config.VAL_FLIST, config.DATA_ROOT+config.VAL_MASK_FLIST, config.DATA_ROOT+config.VAL_GT_FLIST, augment=False)
39 |         elif config.MODE == 5:  # eval on the test split
40 |             self.val_dataset = Dataset(
41 |                 config, config.DATA_ROOT+config.TEST_FLIST, config.DATA_ROOT+config.TEST_MASK_FLIST, config.DATA_ROOT+config.TEST_GT_FLIST, augment=False)
42 |         else:
43 |             if config.MODE == 0:  # pre-train with augmentation
44 |                 self.train_dataset = Dataset(
45 |                     config, config.DATA_ROOT+config.TRAIN_FLIST_PRE, config.DATA_ROOT+config.TRAIN_MASK_FLIST_PRE, config.DATA_ROOT+config.TRAIN_GT_FLIST_PRE, augment=True)
46 |                 self.val_dataset = Dataset(
47 |                     config, config.DATA_ROOT+config.VAL_FLIST_PRE, config.DATA_ROOT+config.VAL_MASK_FLIST_PRE, config.DATA_ROOT+config.VAL_GT_FLIST_PRE, augment=True)
48 |             elif config.MODE == 1:  # pre-train without augmentation
49 |                 self.train_dataset = Dataset(
50 |                     config, config.DATA_ROOT+config.TRAIN_FLIST_PRE, config.DATA_ROOT+config.TRAIN_MASK_FLIST_PRE, config.DATA_ROOT+config.TRAIN_GT_FLIST_PRE, augment=False)
51 |                 self.val_dataset = Dataset(
52 |                     config, config.DATA_ROOT+config.VAL_FLIST_PRE, config.DATA_ROOT+config.VAL_MASK_FLIST_PRE, config.DATA_ROOT+config.VAL_GT_FLIST_PRE, augment=False)
53 |             elif config.MODE == 2:  # train on the target dataset
54 |                 self.train_dataset = Dataset(
55 |                     config, config.DATA_ROOT+config.TRAIN_FLIST, config.DATA_ROOT+config.TRAIN_MASK_FLIST, config.DATA_ROOT+config.TRAIN_GT_FLIST, augment=False)
56 |                 self.val_dataset = Dataset(
57 |                     config, config.DATA_ROOT+config.VAL_FLIST, config.DATA_ROOT+config.VAL_MASK_FLIST, config.DATA_ROOT+config.VAL_GT_FLIST, augment=False)
58 |             self.sample_iterator = self.val_dataset.create_iterator(
59 |                 config.SAMPLE_SIZE)
60 | 
61 |         # path
62 |         self.samples_path = os.path.join(config.PATH, 'samples')
63 |         self.results_path = os.path.join(config.PATH, 'results')
64 |         self.backups_path = os.path.join(config.PATH, 'backups')
65 |         self.results_samples_path = os.path.join(self.results_path, 'samples')
66 |         if self.config.BACKUP:
67 |             create_dir(self.backups_path)
68 |         if config.RESULTS is not None:
69 |             self.results_path = os.path.join(config.RESULTS)
70 |         create_dir("./pre_train_model/"+self.config.SUBJECT_WORD)
71 |         # log file
72 |         self.log_file = os.path.join(
73 |             config.PATH, 'log_' + self.model_name + '.dat')
74 | 
75 |         # early-stopping bookkeeping (avoid overfitting)
76 |         if config.MODE < 3:
77 |             data_save_path = os.path.join(
78 |                 self.config.PATH, 'log_eval_val_ap_id.txt')
79 |             if os.path.exists(data_save_path):
80 |                 self.eval_val_ap_id = np.genfromtxt(
81 |                     data_save_path, dtype=str, encoding='utf-8').astype(float)
82 |                 if config.APPEND == 0:
83 |                     self.eval_val_ap_id[config.MODE] = 0
84 |             else:
85 |                 self.eval_val_ap_id = [0.0, 0.0, 0.0]
86 | 
87 |             data_save_path = os.path.join(
88 |                 self.config.PATH, 'final_model_epoch.txt')
89 |             if os.path.exists(data_save_path):
90 |                 self.epoch = np.genfromtxt(
91 |                     data_save_path, dtype=str, encoding='utf-8').astype(float).astype(int)
92 |                 if config.APPEND == 0:
93 |                     self.epoch[config.MODE, 0] = 0
94 |             else:
95 |                 self.epoch = np.zeros((3, 2)).astype(int)
96 | 
97 |             data_save_path = os.path.join(
98 |                 self.config.PATH, 'log_eval_val_ap_'+str(self.config.MODE)+'.txt')
99 |             if os.path.exists(data_save_path):
100 |                 self.eval_val_ap = np.genfromtxt(
101 |                     data_save_path, dtype=str, encoding='utf-8').astype(float)
102 |                 if config.APPEND == 0:
103 |                     if config.FORCE_EXIT == 0:
104 |                         if config.MODE == 0 or config.MODE == 1:
105 |                             self.eval_val_ap = np.ones(
106 |                                 (config.PRE_TRAIN_EVAL_LEN, 2))*1e6
107 |                         elif config.MODE == 2:
108 |                             self.eval_val_ap = np.ones(
109 |                                 (config.TRAIN_EVAL_LEN, 2))*1e6
110 |                     else:
111 |                         self.eval_val_ap = np.ones((config.ALL_EPOCH, 2))*1e6
112 |             else:
113 |                 if config.FORCE_EXIT == 0:
114 |                     if config.MODE == 0 or config.MODE == 1:
115 |                         self.eval_val_ap = np.ones(
116 |                             (config.PRE_TRAIN_EVAL_LEN, 2))*1e6
117 |                     elif config.MODE == 2:
118 |                         self.eval_val_ap = np.ones(
119 |                             (config.TRAIN_EVAL_LEN, 2))*1e6
120 |                 else:
121 |                     self.eval_val_ap = np.ones((config.ALL_EPOCH, 2))*1e6
122 | 
123 |         # lr scheduler
124 |         self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
125 |             self.model.network_instance.optimizer, 'min', factor=0.5, patience=0, min_lr=1e-5)
126 |         if config.ADV:
127 |             self.scheduler_dis = [torch.optim.lr_scheduler.ReduceLROnPlateau(
128 |                 discriminator.optimizer, 'min', factor=0.5, patience=0, min_lr=5e-6) for discriminator in self.model.network_instance.discriminator]
129 | 
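# Editorial note on the schedule above (behavior per PyTorch's documented
# ReduceLROnPlateau): with mode='min', patience=0 and factor=0.5, the learning
# rate is halved at every eval step whose metric fails to improve on the best
# value seen so far, until min_lr is reached. A standalone toy sketch:
#
#     import torch
#     opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=2e-4)
#     sch = torch.optim.lr_scheduler.ReduceLROnPlateau(
#         opt, 'min', factor=0.5, patience=0, min_lr=1e-5)
#     for rmse in [5.0, 4.0, 4.2, 3.9]:  # toy validation RMSE sequence
#         sch.step(rmse)                 # lr halves only after the 4.2 step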
130 |     def load(self):
131 |         self.model.load()
132 | 
133 |     def save(self):
134 |         self.model.save()
135 | 
136 |     def train(self):
137 |         # initial
138 |         self.restart_train_check_lr_scheduler()
139 |         train_loader = DataLoader(
140 |             dataset=self.train_dataset,
141 |             batch_size=self.config.BATCH_SIZE,
142 |             num_workers=4,
143 |             drop_last=True,
144 |             shuffle=True
145 |         )
146 |         keep_training = True
147 |         mode = self.config.MODE
148 |         max_iteration = int(float(self.config.MAX_ITERS))
149 |         max_epoch = self.config.EPOCHLIST[self.config.EPOCHLIST[-1]]+1
150 |         total = len(self.train_dataset)
151 |         self.TRAIN_DATA_NUM = total
152 |         if total == 0:
153 |             print(
154 |                 'No training data was provided! Check \'TRAIN_FLIST\' value in the configuration file.')
155 |             return
156 |         # train
157 |         while keep_training:
158 |             # epoch
159 |             epoch = self.epoch[self.config.MODE, 0]
160 |             epoch += 1
161 |             self.epoch[self.config.MODE, 0] = epoch
162 |             print('\n\nTraining epoch: %d' % epoch)
163 |             # progbar
164 |             progbar = Progbar(
165 |                 total, width=20, stateful_metrics=['epoch', 'iter'])
166 | 
167 |             for items in train_loader:
168 |                 # initial
169 |                 self.model.train()
170 | 
171 |                 # get data
172 |                 images, mask, GT = self.cuda(
173 |                     *items)
174 |                 # imshow(F.to_pil_image((outputs)[0,:,:,:].cpu()))
175 |                 # train
176 |                 outputs, loss, logs = self.model.process(
177 |                     images, mask, GT)
178 |                 # backward
179 |                 self.model.backward(loss)
180 |                 iteration = self.model.iteration
181 | 
182 |                 # log-epoch, iteration
183 |                 logs = [
184 |                     ("epoch", epoch),
185 |                     ("iter", iteration),
186 |                 ] + logs
187 |                 # progbar
188 |                 progbar.add(len(images), values=logs if self.config.VERBOSE else [
189 |                             x for x in logs if not x[0].startswith('l_')])
190 | 
191 |                 # log model at checkpoints
192 |                 if self.config.LOG_INTERVAL and iteration % self.config.LOG_INTERVAL == 0:
193 |                     self.log(logs)
194 |                 # sample model at checkpoints
195 |                 if self.config.SAMPLE_INTERVAL and iteration % self.config.SAMPLE_INTERVAL == 0:
196 |                     self.sample()
197 |                 # save model at checkpoints
198 |                 if self.config.SAVE_INTERVAL and iteration % self.config.SAVE_INTERVAL == 0:
199 |                     self.save()
200 |                 # force_exit
201 |                 if epoch >= max_epoch:  # stop by epoch count (max_iteration is unused)
202 |                     force_exit = True
203 |                 else:
204 |                     force_exit = False
205 |                 # end condition
206 |                 if force_exit:
207 |                     keep_training = False
208 |                     self.force_exit()
209 |                     print('\n ***force_exit: max epoch***')
210 |                     break
211 |                 # evaluate model at checkpoints
212 |                 if (self.config.EVAL_INTERVAL and iteration % self.config.EVAL_INTERVAL == 0):
213 |                     print('\nstart eval...\n')
214 |                     pre_exit, ap = self.eval()
215 |                     self.scheduler.step(ap)
216 |                     if self.config.ADV:
217 |                         for scheduler_dis in self.scheduler_dis:
218 |                             scheduler_dis.step(ap)
219 | 
220 |                     with open(self.config.CONFIG_PATH, 'r') as f_obj:
221 |                         config = yaml.load(f_obj, Loader=yaml.FullLoader)
222 |                     config['LR'] = self.scheduler.optimizer.param_groups[0]['lr']
223 |                     if self.config.ADV: config['LR_D'] = self.scheduler_dis[0].optimizer.param_groups[0]['lr']
224 |                     save_config(config, self.config.CONFIG_PATH)
225 |                 else:
226 |                     pre_exit = False
227 |                 # debug
228 |                 if self.debug:
229 |                     if iteration >= 40:
230 |                         if self.config.MODE == 0:
231 |                             force_exit = True
232 |                             copyfile(Path('checkpoints')/self.config.SUBJECT_WORD/(self.config.MODEL_NAME+'.pth'),
233 |                                      Path('./pre_train_model')/self.config.SUBJECT_WORD/(self.config.MODEL_NAME+'_pre_da.pth'))
234 |                         if self.config.MODE == 1:
235 |                             force_exit = True
236 |                             copyfile(Path('checkpoints')/self.config.SUBJECT_WORD/(self.config.MODEL_NAME+'.pth'),
237 |                                      Path('./pre_train_model')/self.config.SUBJECT_WORD/(self.config.MODEL_NAME+'_pre_no_da.pth'))
238 |                 # end condition
239 |                 if pre_exit:
240 |                     keep_training = False
241 |                     break
242 | 
243 |         print('\nEnd training....')
244 | 
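# Editorial note on the early-stopping scheme shared by train() above and eval()
# below: eval_val_ap is a fixed-length FIFO of recent validation RMSE values
# (rows of [rmse_all, checkpoint id], initialized to 1e6 so early checkpoints
# never stop training). A new evaluation triggers pre_exit once it stops beating
# the window mean. Illustrative only:
#
#     import numpy as np
#     window = np.array([[5.2, 0.], [5.0, 1.], [5.1, 2.]])  # [rmse, ckpt id]
#     stop = 5.4 >= np.mean(window[:, 0])                   # True -> pre_exit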
245 |     def eval(self):
246 |         # torch.cuda.empty_cache()
247 |         if self.config.MODE == 4 or self.config.MODE == 5:
248 |             BATCH_SIZE = self.config.BATCH_SIZE  # *8
249 |             num_workers = 4  # 8
250 |         else:
251 |             BATCH_SIZE = self.config.BATCH_SIZE
252 |             num_workers = 4
253 |         val_loader = DataLoader(
254 |             dataset=self.val_dataset,
255 |             batch_size=BATCH_SIZE,
256 |             num_workers=num_workers,
257 |             drop_last=False,
258 |             shuffle=False
259 |         )
260 |         total = len(self.val_dataset)
261 |         self.metrics.multiprocessingi_utils.creat_pool()
262 | 
263 |         self.model.eval()
264 |         progbar = Progbar(total, width=20, stateful_metrics=['it'])
265 |         iteration = 0
266 |         log_eval_PR = [[0], [0]]  # unused
267 |         n_thresh = 99  # unused
268 |         # empty GPU accumulators; per-batch statistics are concatenated below
269 |         rmse_shadow = torch.Tensor(0).cuda()
270 |         n_pxl_shadow = torch.Tensor(0).cuda()
271 |         rmse_non_shadow = torch.Tensor(0).cuda()
272 |         n_pxl_non_shadow = torch.Tensor(0).cuda()
273 |         rmse_all = torch.Tensor(0).cuda()
274 |         n_pxl_all = torch.Tensor(0).cuda()
275 |         # eval each image
276 |         with torch.no_grad():
277 |             for items in val_loader:
278 |                 iteration += 1
279 |                 images, mask, GT = self.cuda(
280 |                     *items)
281 |                 # eval
282 |                 outputs = self.model(images)
283 | 
284 |                 rmse_shadow_, n_pxl_shadow_, rmse_non_shadow_, n_pxl_non_shadow_, rmse_all_, n_pxl_all_ = self.metrics.rmse(
285 |                     outputs[-1], mask, GT, dataset_mode=1)
286 | 
287 |                 rmse_shadow = torch.cat((rmse_shadow, rmse_shadow_), dim=0)
288 |                 n_pxl_shadow = torch.cat((n_pxl_shadow, n_pxl_shadow_), dim=0)
289 |                 rmse_non_shadow = torch.cat(
290 |                     (rmse_non_shadow, rmse_non_shadow_), dim=0)
291 |                 n_pxl_non_shadow = torch.cat(
292 |                     (n_pxl_non_shadow, n_pxl_non_shadow_), dim=0)
293 |                 rmse_all = torch.cat((rmse_all, rmse_all_), dim=0)
294 |                 n_pxl_all = torch.cat((n_pxl_all, n_pxl_all_), dim=0)
295 | 
296 |                 if self.debug:
297 |                     if iteration == 10:
298 |                         break
299 | 
300 |         rmse_shadow_eval, rmse_non_shadow_eval, rmse_all_eval = self.metrics.collect_rmse(
301 |             rmse_shadow, n_pxl_shadow, rmse_non_shadow, n_pxl_non_shadow, rmse_all, n_pxl_all)
302 | 
303 |         # print
304 |         print('running rmse-shadow: %.4f, rmse-non-shadow: %.4f, rmse-all: %.4f'
305 |               % (rmse_shadow_eval, rmse_non_shadow_eval, rmse_all_eval))
306 |         data_save_path = os.path.join(
307 |             self.config.PATH, 'show_result.txt')
308 |         np.savetxt(data_save_path, [
309 |             rmse_shadow_eval, rmse_non_shadow_eval, rmse_all_eval], fmt='%s')
310 | 
311 |         # early-stopping bookkeeping (pre_train and train modes)
312 |         if self.config.MODE == 0 or self.config.MODE == 1 or self.config.MODE == 2:
313 |             if self.config.FORCE_EXIT == 0:
314 |                 if rmse_all_eval < np.mean(self.eval_val_ap[:, 0]):
315 |                     exit_ = False
316 |                 else:
317 |                     exit_ = True
318 |             else:
319 |                 exit_ = False
320 | 
321 |             self.eval_val_ap = np.delete(self.eval_val_ap, -1, axis=0)
322 |             self.eval_val_ap = np.append(
323 |                 [[rmse_all_eval, self.eval_val_ap_id[self.config.MODE]]], self.eval_val_ap, axis=0)
324 | 
325 |             model_save_path = "model_save_mode_"+str(self.config.MODE)
326 |             model_save_path = Path(self.config.PATH)/model_save_path
327 |             create_dir(model_save_path)
328 |             self.model.weights_path = os.path.join(
329 |                 model_save_path, str(self.eval_val_ap_id[self.config.MODE]) + '.pth')
330 |             self.save()
331 |             self.model.weights_path = os.path.join(
332 |                 self.config.PATH, self.model_name + '.pth')
333 | 
334 |             data_save_path = os.path.join(
335 |                 self.config.PATH, 'final_model_epoch.txt')
336 |             np.savetxt(data_save_path, self.epoch, fmt='%s')
337 |             if self.eval_val_ap_id[self.config.MODE] == (len(self.eval_val_ap)-1):
338 |                 self.eval_val_ap_id[self.config.MODE] = 0.0
339 |             else:
340 |                 self.eval_val_ap_id[self.config.MODE] += 1.0
341 | 
342 |             data_save_path = os.path.join(
343 |                 self.config.PATH, 'log_eval_val_ap_'+str(self.config.MODE)+'.txt')
344 |             np.savetxt(data_save_path, self.eval_val_ap, fmt='%s')
345 |             data_save_path = os.path.join(
346 |                 self.config.PATH, 'log_eval_val_ap_id.txt')
347 |             np.savetxt(data_save_path, self.eval_val_ap_id, fmt='%s')
348 | 
349 |             if exit_:
350 |                 idmin = np.array(self.eval_val_ap[:, 0]).argmin()
351 | 
352 |                 data_save_path = os.path.join(
353 |                     self.config.PATH, 'final_model_epoch.txt')
354 |                 if self.config.EVAL_INTERVAL_EPOCH < 1:
355 |                     self.epoch[self.config.MODE,
356 |                                1] = self.eval_val_ap[idmin, 1]
357 |                 else:
358 |                     self.epoch[self.config.MODE,
359 |                                1] = self.epoch[self.config.MODE, 0]-idmin-1
360 |                 print('final model id:'+str(self.epoch[self.config.MODE, 1]))
361 |                 np.savetxt(data_save_path, self.epoch, fmt='%s')
362 | 
363 |                 pre_train_save_path = "./pre_train_model/"+self.config.SUBJECT_WORD
364 |                 if os.path.exists(os.path.join(model_save_path, str(self.eval_val_ap[idmin, 1]) + '.pth')):
365 |                     if self.config.MODE == 0:
366 |                         PATH_WEIDHT = os.path.join(
367 |                             pre_train_save_path, self.model_name + '_pre_da.pth')
368 |                     elif self.config.MODE == 1:
369 |                         PATH_WEIDHT = os.path.join(
370 |                             pre_train_save_path, self.model_name + '_pre_no_da.pth')
371 |                     elif self.config.MODE == 2:
372 |                         PATH_WEIDHT = os.path.join(
373 |                             pre_train_save_path, self.model_name + '_final.pth')
374 |                     copyfile(os.path.join(model_save_path, str(
375 |                         self.eval_val_ap[idmin, 1]) + '.pth'), PATH_WEIDHT)
376 |                     print(os.path.join(model_save_path, str(
377 |                         self.eval_val_ap[idmin, 1]) + '.pth')+" copied to "+PATH_WEIDHT)
378 |                     pre_model_save = torch.load(PATH_WEIDHT)
379 |                     torch.save(
380 |                         {'iteration': 0, 'model': pre_model_save['model']}, PATH_WEIDHT)
381 | 
382 |                     copyfile(PATH_WEIDHT, os.path.join(
383 |                         self.config.PATH, self.model_name + '.pth'))
384 |             return exit_, rmse_all_eval
385 | 
386 |     def force_exit(self):
387 |         model_save_path = "model_save_mode_"+str(self.config.MODE)
388 |         model_save_path = Path(self.config.PATH)/model_save_path
389 | 
390 |         idmin = np.array(self.eval_val_ap[:, 0]).argmin()
391 | 
392 |         data_save_path = os.path.join(
393 |             self.config.PATH, 'final_model_epoch.txt')
394 | 
395 |         self.epoch[self.config.MODE, 0] = self.epoch[self.config.MODE, 0]-1
396 |         if self.config.EVAL_INTERVAL_EPOCH < 1:
397 |             self.epoch[self.config.MODE,
398 |                        1] = self.eval_val_ap[idmin, 1]
399 |         else:
400 |             self.epoch[self.config.MODE,
401 |                        1] = self.epoch[self.config.MODE, 0]-idmin-1
402 |         print('\n\nfinal model id:'+str(self.epoch[self.config.MODE, 1]))
403 |         np.savetxt(data_save_path, self.epoch, fmt='%s')
404 | 
405 |         pre_train_save_path = "./pre_train_model/"+self.config.SUBJECT_WORD
406 |         if os.path.exists(os.path.join(model_save_path, str(self.eval_val_ap[idmin, 1]) + '.pth')):
407 |             if self.config.MODE == 0:
408 |                 PATH_WEIDHT = os.path.join(
409 |                     pre_train_save_path, self.model_name + '_pre_da.pth')
410 |             elif self.config.MODE == 1:
411 |                 PATH_WEIDHT = os.path.join(
412 |                     pre_train_save_path, self.model_name + '_pre_no_da.pth')
413 |             elif self.config.MODE == 2:
414 |                 PATH_WEIDHT = os.path.join(
415 |                     pre_train_save_path, self.model_name + '_final.pth')
416 |             copyfile(os.path.join(model_save_path, str(
417 |                 self.eval_val_ap[idmin, 1]) + '.pth'), PATH_WEIDHT)
418 |             print(os.path.join(model_save_path, str(
419 |                 self.eval_val_ap[idmin, 1]) + '.pth')+" copied to "+PATH_WEIDHT)
420 |             pre_model_save = torch.load(PATH_WEIDHT)
421 |             torch.save(
422 |                 {'iteration': 0, 'model': pre_model_save['model']}, PATH_WEIDHT)
423 | 
424 |             copyfile(PATH_WEIDHT, os.path.join(
425 |                 self.config.PATH, self.model_name + '.pth'))
426 | 
427 |     def test(self):
428 |         # initial
429 |         self.model.eval()
430 |         if self.RESULTS_SAMPLE:
431 |             save_path = os.path.join(
432 |                 self.results_samples_path, self.model_name)
433 |             create_dir(save_path)
434 |         else:
435 |             save_path = os.path.join(self.results_path, self.model_name)
436 |             create_dir(save_path)
437 |         if self.debug:
438 |             debug_path = os.path.join(save_path, "debug")
439 |             create_dir(debug_path)
440 |             save_path = debug_path
441 |         test_loader = DataLoader(
442 |             dataset=self.test_dataset,
443 |             batch_size=1,
444 |         )
445 |         # test
446 |         index = 0
447 |         with torch.no_grad():
448 |             for items in test_loader:
449 |                 name = self.test_dataset.load_name(index)
450 |                 index += 1
451 |                 images, mask, GT = self.cuda(
452 |                     *items)
453 |                 if self.RESULTS_SAMPLE:
454 |                     image_per_row = 2
455 |                     if self.config.SAMPLE_SIZE <= 6:
456 |                         image_per_row = 1
457 |                     outputs = self.model(
458 |                         images)
459 |                     i = 0
460 |                     for output in outputs:
461 |                         outputs[i] = self.postprocess(output)
462 |                         i += 1
463 |                     matte_gt = GT-images
464 |                     matte_gt = matte_gt - \
465 |                         (matte_gt.min(dim=2, keepdim=True).values).min(
466 |                             dim=3, keepdim=True).values
467 |                     matte_gt = matte_gt / \
468 |                         (matte_gt.max(dim=2, keepdim=True).values).max(
469 |                             dim=3, keepdim=True).values
470 |                     images = stitch_images(
471 |                         self.postprocess(images),
472 |                         outputs,
473 |                         self.postprocess(matte_gt),
474 |                         self.postprocess(GT),
475 |                         img_per_row=image_per_row,
476 |                     )
477 |                     images.save(os.path.join(save_path, name))
478 |                 else:
479 |                     outputs = self.model(images)
480 |                     outputs = self.postprocess(outputs[-1])[0]
481 |                     path = os.path.join(save_path, name)
482 |                     imsave(outputs, path)
483 |                 # debug
484 |                 if self.debug:
485 |                     if index == 10:
486 |                         break
487 |         print('\nEnd test....')
488 | 
489 |     def sample(self, it=None):
490 |         # initial, do not sample when validation set is empty
491 |         if len(self.val_dataset) == 0:
492 |             return
493 |         # torch.cuda.empty_cache()
494 |         self.model.eval()
495 |         iteration = self.model.iteration
496 | 
497 |         items = next(self.sample_iterator)
498 |         images, mask, GT = self.cuda(
499 |             *items)
500 |         image_per_row = 2
501 |         if self.config.SAMPLE_SIZE <= 6:
502 |             image_per_row = 1
503 |         outputs = self.model(
504 |             images)
505 |         i = 0
506 |         for output in outputs:
507 |             outputs[i] = self.postprocess(output)
508 |             i += 1
509 |         matte_gt = GT-images
510 |         matte_gt = matte_gt - \
511 |             (matte_gt.min(dim=2, keepdim=True).values).min(
512 |                 dim=3, keepdim=True).values
513 |         matte_gt = matte_gt / \
514 |             (matte_gt.max(dim=2, keepdim=True).values).max(
515 |                 dim=3, keepdim=True).values
516 |         images = stitch_images(
517 |             self.postprocess(images),
518 |             outputs,
519 |             self.postprocess(matte_gt),
520 |             self.postprocess(GT),
521 |             img_per_row=image_per_row,
522 |         )
523 | 
524 |         path = os.path.join(self.samples_path, self.model_name)
525 |         name = os.path.join(path, "mode_"+str(self.config.MODE) +
526 |                             "_"+str(iteration).zfill(5) + ".png")
527 |         create_dir(path)
528 |         print('\nsaving sample ' + name)
529 |         images.save(name)
530 | 
531 |     def log(self, logs):
532 |         with open(self.log_file, 'a') as f:
533 |             f.write('%s\n' % ' '.join([str(item[1]) for item in logs]))
534 | 
535 |     def cuda(self, *args):
536 |         return (item.to(self.config.DEVICE) for item in args)
537 | 
538 |     def postprocess(self, img):
539 |         # [0, 1] => [0, 255]
540 |         img = img * 255.0
541 |         img = img.permute(0, 2, 3, 1)
542 |         return img.int()
543 | 
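# Editorial note: sample() and test() above visualize the ground-truth shadow
# matte as GT - images, min-max normalized per image and per channel over H and W
# so the stitched panel spans [0, 1]. An equivalent standalone sketch
# (illustrative only):
#
#     import torch
#     x = torch.rand(2, 3, 8, 8)                # stand-in matte, N x C x H x W
#     x = x - x.amin(dim=(2, 3), keepdim=True)  # shift per-image/channel min to 0
#     x = x / x.amax(dim=(2, 3), keepdim=True)  # scale per-image/channel max to 1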
544 |     def restart_train_check_lr_scheduler(self):
545 |         checkpoints_path = Path('./checkpoints') / self.config.SUBJECT_WORD
546 |         log_eval_val_ap_path = Path(checkpoints_path) / \
547 |             ('log_eval_val_ap_'+str(self.config.MODE)+'.txt')
548 | 
549 |         if log_eval_val_ap_path.is_file():
550 |             eval_val_ap = np.genfromtxt(
551 |                 log_eval_val_ap_path, dtype=str, encoding='utf-8').astype(float)
552 |             EPOCH = self.config.EPOCHLIST
553 | 
554 |             if EPOCH[-1] != 0 and eval_val_ap[0, 1] != EPOCH[EPOCH[-1]-1]-1:
555 |                 ap = float(eval_val_ap[0, 0])
556 |                 self.scheduler.step(ap)
557 |                 if self.config.ADV:
558 |                     for scheduler_dis in self.scheduler_dis:
559 |                         scheduler_dis.step(ap)
560 | 
--------------------------------------------------------------------------------
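A minimal driver sketch for ModelTop (editorial; the actual entry points are
src/main.py and the run_*.py scripts). It assumes a config object exposing the
attributes referenced above (MODE, DEVICE, and the various path fields):

    import torch
    from src.model_top import ModelTop

    def run(config):
        # ModelTop moves the model and metrics to config.DEVICE
        config.DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        top = ModelTop(config)
        top.load()
        if config.MODE in (0, 1, 2):   # 0/1: pre-train (aug / no aug), 2: train
            top.train()
        elif config.MODE == 3:         # write shadow-removal results to disk
            top.test()
        else:                          # 4/5: evaluate on the val / test split
            top.eval()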