├── PF-AFN_test ├── checkpoints │ └── .md ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── aligned_dataset_test.cpython-36.pyc │ │ ├── aligned_dataset_test.cpython-37.pyc │ │ ├── base_data_loader.cpython-36.pyc │ │ ├── base_data_loader.cpython-37.pyc │ │ ├── base_dataset.cpython-36.pyc │ │ ├── base_dataset.cpython-37.pyc │ │ ├── custom_dataset_data_loader_test.cpython-36.pyc │ │ ├── custom_dataset_data_loader_test.cpython-37.pyc │ │ ├── data_loader_test.cpython-36.pyc │ │ ├── data_loader_test.cpython-37.pyc │ │ ├── image_folder.cpython-36.pyc │ │ └── image_folder.cpython-37.pyc │ ├── aligned_dataset_test.py │ ├── base_data_loader.py │ ├── base_dataset.py │ ├── custom_dataset_data_loader_test.py │ ├── data_loader_test.py │ └── image_folder.py ├── dataset │ ├── test_clothes │ │ ├── 003434_1.jpg │ │ ├── 006026_1.jpg │ │ ├── 010567_1.jpg │ │ ├── 014396_1.jpg │ │ ├── 017575_1.jpg │ │ └── 019119_1.jpg │ ├── test_edge │ │ ├── 003434_1.jpg │ │ ├── 006026_1.jpg │ │ ├── 010567_1.jpg │ │ ├── 014396_1.jpg │ │ ├── 017575_1.jpg │ │ └── 019119_1.jpg │ └── test_img │ │ ├── 000066_0.jpg │ │ ├── 004912_0.jpg │ │ ├── 005510_0.jpg │ │ ├── 014834_0.jpg │ │ ├── 015794_0.jpg │ │ └── 016962_0.jpg ├── demo.txt ├── models │ ├── __pycache__ │ │ ├── afwm.cpython-36.pyc │ │ ├── afwm.cpython-37.pyc │ │ ├── networks.cpython-36.pyc │ │ └── networks.cpython-37.pyc │ ├── afwm.py │ ├── correlation │ │ ├── README.md │ │ ├── __pycache__ │ │ │ ├── correlation.cpython-36.pyc │ │ │ └── correlation.cpython-37.pyc │ │ └── correlation.py │ └── networks.py ├── options │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── base_options.cpython-36.pyc │ │ ├── base_options.cpython-37.pyc │ │ ├── test_options.cpython-36.pyc │ │ └── test_options.cpython-37.pyc │ ├── base_options.py │ └── test_options.py ├── results │ └── demo │ │ └── PFAFN │ │ ├── 0.jpg │ │ ├── 1.jpg │ │ ├── 2.jpg │ │ ├── 3.jpg │ │ ├── 4.jpg │ │ └── 5.jpg ├── test.py ├── test.sh └── util │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── util.cpython-36.pyc │ └── util.cpython-37.pyc │ ├── image_pool.py │ └── util.py ├── PF-AFN_train ├── checkpoints │ └── readme.txt ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── aligned_dataset.cpython-36.pyc │ │ ├── aligned_dataset.cpython-37.pyc │ │ ├── aligned_dataset_fake.cpython-36.pyc │ │ ├── aligned_dataset_test.cpython-36.pyc │ │ ├── base_data_loader.cpython-36.pyc │ │ ├── base_dataset.cpython-36.pyc │ │ ├── base_dataset.cpython-37.pyc │ │ ├── custom_dataset_data_loader.cpython-36.pyc │ │ ├── custom_dataset_data_loader_test.cpython-36.pyc │ │ ├── data_loader.cpython-36.pyc │ │ ├── data_loader_test.cpython-36.pyc │ │ ├── image_folder.cpython-36.pyc │ │ └── image_folder.cpython-37.pyc │ ├── aligned_dataset.py │ ├── base_data_loader.py │ ├── base_dataset.py │ ├── custom_dataset_data_loader.py │ ├── data_loader.py │ └── image_folder.py ├── dataset │ └── readme.txt ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── afwm.cpython-37.pyc │ │ ├── base_model.cpython-36.pyc │ │ ├── flow_gmm.cpython-36.pyc │ │ ├── flow_gmm_add.cpython-36.pyc │ │ ├── flow_gmm_cor.cpython-36.pyc │ │ ├── flow_gmm_cor_add.cpython-36.pyc │ │ ├── flow_gmm_cor_more.cpython-36.pyc │ │ ├── flow_gmm_cor_more_add.cpython-36.pyc │ │ ├── 
flow_gmm_cor_more_feat.cpython-36.pyc │ │ ├── flow_gmm_cor_more_grid_offset_sep.cpython-36.pyc │ │ ├── flow_gmm_cor_more_grid_sep.cpython-36.pyc │ │ ├── flow_gmm_cor_more_new.cpython-36.pyc │ │ ├── flow_gmm_cor_more_offset.cpython-36.pyc │ │ ├── flow_gmm_cor_more_revise.cpython-36.pyc │ │ ├── flow_gmm_cor_more_revise_new.cpython-36.pyc │ │ ├── flow_gmm_cor_more_revise_new_sep.cpython-36.pyc │ │ ├── flow_gmm_cor_more_revise_new_sep_all.cpython-36.pyc │ │ ├── flow_gmm_cor_more_revise_new_sep_all_more.cpython-36.pyc │ │ ├── flow_gmm_cor_more_revise_new_sep_all_more_heatmap.cpython-36.pyc │ │ ├── flow_gmm_cor_more_revise_new_sep_all_more_no_refine.cpython-36.pyc │ │ ├── flow_gmm_cor_more_revise_new_sep_all_more_trans.cpython-36.pyc │ │ ├── flow_gmm_cor_more_sep.cpython-36.pyc │ │ ├── flow_gmm_cor_sep.cpython-36.pyc │ │ ├── flow_gmm_smooth.cpython-36.pyc │ │ ├── flow_gmm_vis.cpython-36.pyc │ │ ├── models.cpython-36.pyc │ │ ├── networks.cpython-36.pyc │ │ ├── networks.cpython-37.pyc │ │ ├── networks_flow.cpython-36.pyc │ │ ├── pix2pixHD_model.cpython-36.pyc │ │ └── predict_mask.cpython-36.pyc │ ├── afwm.py │ ├── correlation │ │ ├── README.md │ │ ├── __pycache__ │ │ │ ├── correlation.cpython-36.pyc │ │ │ └── correlation.cpython-37.pyc │ │ └── correlation.py │ └── networks.py ├── options │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── base_options.cpython-36.pyc │ │ ├── base_options.cpython-37.pyc │ │ ├── test_options.cpython-36.pyc │ │ ├── train_options.cpython-36.pyc │ │ └── train_options.cpython-37.pyc │ ├── base_options.py │ └── train_options.py ├── runs │ └── readme.txt ├── sample │ └── readme.txt ├── scripts │ ├── train_PBAFN_e2e.sh │ ├── train_PBAFN_stage1.sh │ ├── train_PFAFN_e2e.sh │ └── train_PFAFN_stage1.sh ├── train_PBAFN_e2e.py ├── train_PBAFN_stage1.py ├── train_PFAFN_e2e.py ├── train_PFAFN_stage1.py └── util │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── image_pool.cpython-36.pyc │ ├── util.cpython-36.pyc │ └── util.cpython-37.pyc │ ├── image_pool.py │ └── util.py ├── PFAFN_supp.pdf ├── README.md └── show ├── compare.jpg └── compare_both.jpg /PF-AFN_test/checkpoints/.md: -------------------------------------------------------------------------------- 1 | Please put the downloaded checkpoints folder "PFAFN" under the folder "checkpoints". 
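As a quick sanity check before running inference, the expected layout can be verified with a few lines of Python. The file names below ("warp_model_final.pth", "gen_model_final.pth") are illustrative assumptions only; they are not stated in this note, so substitute whatever the downloaded "PFAFN" folder actually contains.

```python
# Minimal layout check (a sketch, run from PF-AFN_test/); the two file names are
# assumed examples, not guaranteed by this note -- adjust them to your download.
import os

for name in ["warp_model_final.pth", "gen_model_final.pth"]:
    path = os.path.join("checkpoints", "PFAFN", name)
    print(("found   " if os.path.isfile(path) else "MISSING ") + path)
```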
2 | -------------------------------------------------------------------------------- /PF-AFN_test/data/__init__.py: -------------------------------------------------------------------------------- 1 | # data_init -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/aligned_dataset_test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/aligned_dataset_test.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/aligned_dataset_test.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/aligned_dataset_test.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/base_data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/base_data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/base_data_loader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/base_data_loader.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/base_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/base_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/custom_dataset_data_loader_test.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/custom_dataset_data_loader_test.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/custom_dataset_data_loader_test.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/custom_dataset_data_loader_test.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/data_loader_test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/data_loader_test.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/data_loader_test.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/data_loader_test.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/image_folder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/image_folder.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/__pycache__/image_folder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/image_folder.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/data/aligned_dataset_test.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_params, get_transform 3 | from PIL import Image 4 | import linecache 5 | 6 | class AlignedDataset(BaseDataset): 7 | def initialize(self, opt): 8 | self.opt = opt 9 | self.root = opt.dataroot 10 | 11 | self.fine_height=256 12 | self.fine_width=192 13 | 14 | self.dataset_size = len(open('demo.txt').readlines()) 15 | 16 | dir_I = '_img' 17 | self.dir_I = os.path.join(opt.dataroot, opt.phase + dir_I) 18 | 19 | dir_C = '_clothes' 20 | self.dir_C = os.path.join(opt.dataroot, opt.phase + dir_C) 21 | 22 | dir_E = '_edge' 23 | self.dir_E = os.path.join(opt.dataroot, opt.phase + dir_E) 24 | 25 | def __getitem__(self, index): 26 | 27 | file_path ='demo.txt' 28 | im_name, c_name = linecache.getline(file_path, index+1).strip().split() 29 | 30 | I_path = os.path.join(self.dir_I,im_name) 31 | I = Image.open(I_path).convert('RGB') 32 | 33 | params = get_params(self.opt, I.size) 34 | transform = get_transform(self.opt, params) 35 | transform_E = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 36 | 37 | I_tensor = transform(I) 38 | 39 | C_path = os.path.join(self.dir_C,c_name) 40 | C = Image.open(C_path).convert('RGB') 41 | C_tensor = transform(C) 42 | 43 | E_path = 
os.path.join(self.dir_E,c_name) 44 | E = Image.open(E_path).convert('L') 45 | E_tensor = transform_E(E) 46 | 47 | input_dict = { 'image': I_tensor,'clothes': C_tensor, 'edge': E_tensor} 48 | return input_dict 49 | 50 | def __len__(self): 51 | return self.dataset_size 52 | 53 | def name(self): 54 | return 'AlignedDataset' 55 | -------------------------------------------------------------------------------- /PF-AFN_test/data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseDataLoader(): 3 | def __init__(self): 4 | pass 5 | 6 | def initialize(self, opt): 7 | self.opt = opt 8 | pass 9 | 10 | def load_data(): 11 | return None 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /PF-AFN_test/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import random 6 | 7 | class BaseDataset(data.Dataset): 8 | def __init__(self): 9 | super(BaseDataset, self).__init__() 10 | 11 | def name(self): 12 | return 'BaseDataset' 13 | 14 | def initialize(self, opt): 15 | pass 16 | 17 | def get_params(opt, size): 18 | w, h = size 19 | new_h = h 20 | new_w = w 21 | if opt.resize_or_crop == 'resize_and_crop': 22 | new_h = new_w = opt.loadSize 23 | elif opt.resize_or_crop == 'scale_width_and_crop': 24 | new_w = opt.loadSize 25 | new_h = opt.loadSize * h // w 26 | 27 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) 28 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) 29 | 30 | flip = 0 31 | return {'crop_pos': (x, y), 'flip': flip} 32 | 33 | def get_transform_resize(opt, params, method=Image.BICUBIC, normalize=True): 34 | transform_list = [] 35 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 36 | osize = [256,192] 37 | transform_list.append(transforms.Scale(osize, method)) 38 | if 'crop' in opt.resize_or_crop: 39 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 40 | 41 | if opt.resize_or_crop == 'none': 42 | base = float(2 ** opt.n_downsample_global) 43 | if opt.netG == 'local': 44 | base *= (2 ** opt.n_local_enhancers) 45 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 46 | 47 | if opt.isTrain and not opt.no_flip: 48 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 49 | 50 | transform_list += [transforms.ToTensor()] 51 | 52 | if normalize: 53 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 54 | (0.5, 0.5, 0.5))] 55 | return transforms.Compose(transform_list) 56 | 57 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True): 58 | transform_list = [] 59 | if 'resize' in opt.resize_or_crop: 60 | osize = [opt.loadSize, opt.loadSize] 61 | transform_list.append(transforms.Scale(osize, method)) 62 | elif 'scale_width' in opt.resize_or_crop: 63 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 64 | osize = [256,192] 65 | transform_list.append(transforms.Scale(osize, method)) 66 | if 'crop' in opt.resize_or_crop: 67 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 68 | 69 | if opt.resize_or_crop == 'none': 70 | base = float(16) 71 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, 
base, method))) 72 | 73 | if opt.isTrain and not opt.no_flip: 74 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 75 | 76 | transform_list += [transforms.ToTensor()] 77 | 78 | if normalize: 79 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 80 | (0.5, 0.5, 0.5))] 81 | return transforms.Compose(transform_list) 82 | 83 | def normalize(): 84 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 85 | 86 | def __make_power_2(img, base, method=Image.BICUBIC): 87 | ow, oh = img.size 88 | h = int(round(oh / base) * base) 89 | w = int(round(ow / base) * base) 90 | if (h == oh) and (w == ow): 91 | return img 92 | return img.resize((w, h), method) 93 | 94 | def __scale_width(img, target_width, method=Image.BICUBIC): 95 | ow, oh = img.size 96 | if (ow == target_width): 97 | return img 98 | w = target_width 99 | h = int(target_width * oh / ow) 100 | return img.resize((w, h), method) 101 | 102 | def __crop(img, pos, size): 103 | ow, oh = img.size 104 | x1, y1 = pos 105 | tw = th = size 106 | if (ow > tw or oh > th): 107 | return img.crop((x1, y1, x1 + tw, y1 + th)) 108 | return img 109 | 110 | def __flip(img, flip): 111 | if flip: 112 | return img.transpose(Image.FLIP_LEFT_RIGHT) 113 | return img 114 | -------------------------------------------------------------------------------- /PF-AFN_test/data/custom_dataset_data_loader_test.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | from data.aligned_dataset_test import AlignedDataset 8 | dataset = AlignedDataset() 9 | 10 | print("dataset [%s] was created" % (dataset.name())) 11 | dataset.initialize(opt) 12 | return dataset 13 | 14 | class CustomDatasetDataLoader(BaseDataLoader): 15 | def name(self): 16 | return 'CustomDatasetDataLoader' 17 | 18 | def initialize(self, opt): 19 | BaseDataLoader.initialize(self, opt) 20 | self.dataset = CreateDataset(opt) 21 | self.dataloader = torch.utils.data.DataLoader( 22 | self.dataset, 23 | batch_size=opt.batchSize, 24 | shuffle = False, 25 | num_workers=int(opt.nThreads)) 26 | 27 | def load_data(self): 28 | return self.dataloader 29 | 30 | def __len__(self): 31 | return min(len(self.dataset), self.opt.max_dataset_size) 32 | -------------------------------------------------------------------------------- /PF-AFN_test/data/data_loader_test.py: -------------------------------------------------------------------------------- 1 | 2 | def CreateDataLoader(opt): 3 | from data.custom_dataset_data_loader_test import CustomDatasetDataLoader 4 | data_loader = CustomDatasetDataLoader() 5 | print(data_loader.name()) 6 | data_loader.initialize(opt) 7 | return data_loader 8 | -------------------------------------------------------------------------------- /PF-AFN_test/data/image_folder.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | 5 | IMG_EXTENSIONS = [ 6 | '.jpg', '.JPG', '.jpeg', '.JPEG', 7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 8 | ] 9 | 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | def make_dataset(dir): 15 | images = [] 16 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 17 | 18 | f = dir.split('/')[-1].split('_')[-1] 19 | print (dir, f) 20 | dirs= os.listdir(dir) 21 | 
for img in dirs: 22 | 23 | path = os.path.join(dir, img) 24 | #print(path) 25 | images.append(path) 26 | return images 27 | 28 | def make_dataset_test(dir): 29 | images = [] 30 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 31 | 32 | f = dir.split('/')[-1].split('_')[-1] 33 | for i in range(len([name for name in os.listdir(dir) if os.path.isfile(os.path.join(dir, name))])): 34 | if f == 'label' or f == 'labelref': 35 | img = str(i) + '.png' 36 | else: 37 | img = str(i) + '.jpg' 38 | path = os.path.join(dir, img) 39 | images.append(path) 40 | return images 41 | 42 | def default_loader(path): 43 | return Image.open(path).convert('RGB') 44 | 45 | 46 | class ImageFolder(data.Dataset): 47 | 48 | def __init__(self, root, transform=None, return_paths=False, 49 | loader=default_loader): 50 | imgs = make_dataset(root) 51 | if len(imgs) == 0: 52 | raise(RuntimeError("Found 0 images in: " + root + "\n" 53 | "Supported image extensions are: " + 54 | ",".join(IMG_EXTENSIONS))) 55 | 56 | self.root = root 57 | self.imgs = imgs 58 | self.transform = transform 59 | self.return_paths = return_paths 60 | self.loader = loader 61 | 62 | def __getitem__(self, index): 63 | path = self.imgs[index] 64 | img = self.loader(path) 65 | if self.transform is not None: 66 | img = self.transform(img) 67 | if self.return_paths: 68 | return img, path 69 | else: 70 | return img 71 | 72 | def __len__(self): 73 | return len(self.imgs) 74 | -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_clothes/003434_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/003434_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_clothes/006026_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/006026_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_clothes/010567_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/010567_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_clothes/014396_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/014396_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_clothes/017575_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/017575_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_clothes/019119_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/019119_1.jpg 
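The three dataset folders (test_img for person images, test_clothes for target garments, test_edge for the matching clothing masks) are consumed by the loader classes listed above: CreateDataLoader in data_loader_test.py builds a CustomDatasetDataLoader, which wraps an AlignedDataset that pairs a person image with a clothing image per line of demo.txt and reuses the clothing file name to fetch its edge map. A minimal sketch of driving this pipeline from PF-AFN_test/ follows; the TestOptions import and its parse() call are assumptions modeled on the options package, which is only partially reproduced here.

```python
# Sketch of iterating the test-time loader defined in the data/ package above.
# Assumes the working directory is PF-AFN_test/ so that "demo.txt" and "dataset/"
# resolve exactly as aligned_dataset_test.py expects.
from options.test_options import TestOptions      # assumed pix2pixHD-style options class
from data.data_loader_test import CreateDataLoader

opt = TestOptions().parse()                        # assumed parse() API
loader = CreateDataLoader(opt).load_data()         # a torch DataLoader, batch size opt.batchSize

for i, batch in enumerate(loader):
    # Each batch is the dict returned by AlignedDataset.__getitem__:
    #   'image'   -> person image, normalized to [-1, 1]
    #   'clothes' -> target clothing, normalized to [-1, 1]
    #   'edge'    -> clothing mask in [0, 1] (NEAREST resize, no normalization)
    print(i, batch['image'].shape, batch['clothes'].shape, batch['edge'].shape)
```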
-------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_edge/003434_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/003434_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_edge/006026_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/006026_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_edge/010567_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/010567_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_edge/014396_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/014396_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_edge/017575_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/017575_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_edge/019119_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/019119_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_img/000066_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/000066_0.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_img/004912_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/004912_0.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_img/005510_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/005510_0.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_img/014834_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/014834_0.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_img/015794_0.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/015794_0.jpg -------------------------------------------------------------------------------- /PF-AFN_test/dataset/test_img/016962_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/016962_0.jpg -------------------------------------------------------------------------------- /PF-AFN_test/demo.txt: -------------------------------------------------------------------------------- 1 | 000066_0.jpg 017575_1.jpg 2 | 016962_0.jpg 003434_1.jpg 3 | 004912_0.jpg 014396_1.jpg 4 | 005510_0.jpg 006026_1.jpg 5 | 014834_0.jpg 019119_1.jpg 6 | 015794_0.jpg 010567_1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/models/__pycache__/afwm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/__pycache__/afwm.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/models/__pycache__/afwm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/__pycache__/afwm.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/models/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/models/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/models/afwm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .correlation import correlation 5 | 6 | def apply_offset(offset): 7 | 8 | sizes = list(offset.size()[2:]) 9 | grid_list = torch.meshgrid([torch.arange(size, device=offset.device) for size in sizes]) 10 | grid_list = reversed(grid_list) 11 | 12 | grid_list = [grid.float().unsqueeze(0) + offset[:, dim, ...] 
13 | for dim, grid in enumerate(grid_list)] 14 | 15 | grid_list = [grid / ((size - 1.0) / 2.0) - 1.0 16 | for grid, size in zip(grid_list, reversed(sizes))] 17 | 18 | return torch.stack(grid_list, dim=-1) 19 | 20 | 21 | class ResBlock(nn.Module): 22 | def __init__(self, in_channels): 23 | super(ResBlock, self).__init__() 24 | self.block = nn.Sequential( 25 | nn.BatchNorm2d(in_channels), 26 | nn.ReLU(inplace=True), 27 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False), 28 | nn.BatchNorm2d(in_channels), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False) 31 | ) 32 | 33 | def forward(self, x): 34 | return self.block(x) + x 35 | 36 | 37 | class DownSample(nn.Module): 38 | def __init__(self, in_channels, out_channels): 39 | super(DownSample, self).__init__() 40 | self.block= nn.Sequential( 41 | nn.BatchNorm2d(in_channels), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False) 44 | ) 45 | 46 | def forward(self, x): 47 | return self.block(x) 48 | 49 | 50 | 51 | class FeatureEncoder(nn.Module): 52 | def __init__(self, in_channels, chns=[64,128,256,256,256]): 53 | super(FeatureEncoder, self).__init__() 54 | self.encoders = [] 55 | for i, out_chns in enumerate(chns): 56 | if i == 0: 57 | encoder = nn.Sequential(DownSample(in_channels, out_chns), 58 | ResBlock(out_chns), 59 | ResBlock(out_chns)) 60 | else: 61 | encoder = nn.Sequential(DownSample(chns[i-1], out_chns), 62 | ResBlock(out_chns), 63 | ResBlock(out_chns)) 64 | 65 | self.encoders.append(encoder) 66 | 67 | self.encoders = nn.ModuleList(self.encoders) 68 | 69 | 70 | def forward(self, x): 71 | encoder_features = [] 72 | for encoder in self.encoders: 73 | x = encoder(x) 74 | encoder_features.append(x) 75 | return encoder_features 76 | 77 | class RefinePyramid(nn.Module): 78 | def __init__(self, chns=[64,128,256,256,256], fpn_dim=256): 79 | super(RefinePyramid, self).__init__() 80 | self.chns = chns 81 | 82 | self.adaptive = [] 83 | for in_chns in list(reversed(chns)): 84 | adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1) 85 | self.adaptive.append(adaptive_layer) 86 | self.adaptive = nn.ModuleList(self.adaptive) 87 | 88 | self.smooth = [] 89 | for i in range(len(chns)): 90 | smooth_layer = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, padding=1) 91 | self.smooth.append(smooth_layer) 92 | self.smooth = nn.ModuleList(self.smooth) 93 | 94 | def forward(self, x): 95 | conv_ftr_list = x 96 | 97 | feature_list = [] 98 | last_feature = None 99 | for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))): 100 | feature = self.adaptive[i](conv_ftr) 101 | 102 | if last_feature is not None: 103 | feature = feature + F.interpolate(last_feature, scale_factor=2, mode='nearest') 104 | 105 | feature = self.smooth[i](feature) 106 | last_feature = feature 107 | feature_list.append(feature) 108 | 109 | return tuple(reversed(feature_list)) 110 | 111 | 112 | class AFlowNet(nn.Module): 113 | def __init__(self, num_pyramid, fpn_dim=256): 114 | super(AFlowNet, self).__init__() 115 | self.netMain = [] 116 | self.netRefine = [] 117 | for i in range(num_pyramid): 118 | netMain_layer = torch.nn.Sequential( 119 | torch.nn.Conv2d(in_channels=49, out_channels=128, kernel_size=3, stride=1, padding=1), 120 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 121 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), 122 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 123 | 
torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), 124 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 125 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1) 126 | ) 127 | 128 | netRefine_layer = torch.nn.Sequential( 129 | torch.nn.Conv2d(2 * fpn_dim, out_channels=128, kernel_size=3, stride=1, padding=1), 130 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 131 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), 132 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 133 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), 134 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 135 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1) 136 | ) 137 | self.netMain.append(netMain_layer) 138 | self.netRefine.append(netRefine_layer) 139 | 140 | self.netMain = nn.ModuleList(self.netMain) 141 | self.netRefine = nn.ModuleList(self.netRefine) 142 | 143 | 144 | def forward(self, x, x_warps, x_conds, warp_feature=True): 145 | last_flow = None 146 | 147 | for i in range(len(x_warps)): 148 | x_warp = x_warps[len(x_warps) - 1 - i] 149 | x_cond = x_conds[len(x_warps) - 1 - i] 150 | 151 | if last_flow is not None and warp_feature: 152 | x_warp_after = F.grid_sample(x_warp, last_flow.detach().permute(0, 2, 3, 1), 153 | mode='bilinear', padding_mode='border') 154 | else: 155 | x_warp_after = x_warp 156 | 157 | tenCorrelation = F.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=x_warp_after, tenSecond=x_cond, intStride=1), negative_slope=0.1, inplace=False) 158 | flow = self.netMain[i](tenCorrelation) 159 | flow = apply_offset(flow) 160 | 161 | if last_flow is not None: 162 | flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border') 163 | else: 164 | flow = flow.permute(0, 3, 1, 2) 165 | 166 | last_flow = flow 167 | x_warp = F.grid_sample(x_warp, flow.permute(0, 2, 3, 1),mode='bilinear', padding_mode='border') 168 | concat = torch.cat([x_warp,x_cond],1) 169 | flow = self.netRefine[i](concat) 170 | flow = apply_offset(flow) 171 | flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border') 172 | 173 | last_flow = F.interpolate(flow, scale_factor=2, mode='bilinear') 174 | 175 | x_warp = F.grid_sample(x, last_flow.permute(0, 2, 3, 1), 176 | mode='bilinear', padding_mode='border') 177 | return x_warp, last_flow, 178 | 179 | 180 | class AFWM(nn.Module): 181 | 182 | def __init__(self, opt, input_nc): 183 | super(AFWM, self).__init__() 184 | num_filters = [64,128,256,256,256] 185 | self.image_features = FeatureEncoder(3, num_filters) 186 | self.cond_features = FeatureEncoder(input_nc, num_filters) 187 | self.image_FPN = RefinePyramid(num_filters) 188 | self.cond_FPN = RefinePyramid(num_filters) 189 | self.aflow_net = AFlowNet(len(num_filters)) 190 | 191 | def forward(self, cond_input, image_input): 192 | cond_pyramids = self.cond_FPN(self.cond_features(cond_input)) # maybe use nn.Sequential 193 | image_pyramids = self.image_FPN(self.image_features(image_input)) 194 | 195 | x_warp, last_flow = self.aflow_net(image_input, image_pyramids, cond_pyramids) 196 | 197 | return x_warp, last_flow 198 | 199 | -------------------------------------------------------------------------------- /PF-AFN_test/models/correlation/README.md: -------------------------------------------------------------------------------- 1 | This is an adaptation of the FlowNet2 implementation in order to 
compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be making use or modify this particular implementation, please acknowledge it appropriately. -------------------------------------------------------------------------------- /PF-AFN_test/models/correlation/__pycache__/correlation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/correlation/__pycache__/correlation.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/models/correlation/__pycache__/correlation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/correlation/__pycache__/correlation.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/models/correlation/correlation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | 5 | import cupy 6 | import math 7 | import re 8 | 9 | kernel_Correlation_rearrange = ''' 10 | extern "C" __global__ void kernel_Correlation_rearrange( 11 | const int n, 12 | const float* input, 13 | float* output 14 | ) { 15 | int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; 16 | 17 | if (intIndex >= n) { 18 | return; 19 | } 20 | 21 | int intSample = blockIdx.z; 22 | int intChannel = blockIdx.y; 23 | 24 | float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; 25 | 26 | __syncthreads(); 27 | 28 | int intPaddedY = (intIndex / SIZE_3(input)) + 3*{{intStride}}; 29 | int intPaddedX = (intIndex % SIZE_3(input)) + 3*{{intStride}}; 30 | int intRearrange = ((SIZE_3(input) + 6*{{intStride}}) * intPaddedY) + intPaddedX; 31 | 32 | output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; 33 | } 34 | ''' 35 | 36 | kernel_Correlation_updateOutput = ''' 37 | extern "C" __global__ void kernel_Correlation_updateOutput( 38 | const int n, 39 | const float* rbot0, 40 | const float* rbot1, 41 | float* top 42 | ) { 43 | extern __shared__ char patch_data_char[]; 44 | 45 | float *patch_data = (float *)patch_data_char; 46 | 47 | // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 48 | int x1 = (blockIdx.x + 3) * {{intStride}}; 49 | int y1 = (blockIdx.y + 3) * {{intStride}}; 50 | int item = blockIdx.z; 51 | int ch_off = threadIdx.x; 52 | 53 | // Load 3D patch into shared shared memory 54 | for (int j = 0; j < 1; j++) { // HEIGHT 55 | for (int i = 0; i < 1; i++) { // WIDTH 56 | int ji_off = (j + i) * SIZE_3(rbot0); 57 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 58 | int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; 59 | int idxPatchData = ji_off + ch; 60 | patch_data[idxPatchData] = rbot0[idx1]; 61 | } 62 | } 63 | } 64 | 65 | __syncthreads(); 66 | 67 | __shared__ float sum[32]; 68 | 69 | // Compute correlation 70 | for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { 71 | sum[ch_off] = 0; 72 | 73 | int s2o = (top_channel % 7 - 3) * {{intStride}}; 74 | int s2p = (top_channel / 7 - 3) * 
{{intStride}}; 75 | 76 | for (int j = 0; j < 1; j++) { // HEIGHT 77 | for (int i = 0; i < 1; i++) { // WIDTH 78 | int ji_off = (j + i) * SIZE_3(rbot0); 79 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 80 | int x2 = x1 + s2o; 81 | int y2 = y1 + s2p; 82 | 83 | int idxPatchData = ji_off + ch; 84 | int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; 85 | 86 | sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; 87 | } 88 | } 89 | } 90 | 91 | __syncthreads(); 92 | 93 | if (ch_off == 0) { 94 | float total_sum = 0; 95 | for (int idx = 0; idx < 32; idx++) { 96 | total_sum += sum[idx]; 97 | } 98 | const int sumelems = SIZE_3(rbot0); 99 | const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; 100 | top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; 101 | } 102 | } 103 | } 104 | ''' 105 | 106 | kernel_Correlation_updateGradFirst = ''' 107 | #define ROUND_OFF 50000 108 | 109 | extern "C" __global__ void kernel_Correlation_updateGradFirst( 110 | const int n, 111 | const int intSample, 112 | const float* rbot0, 113 | const float* rbot1, 114 | const float* gradOutput, 115 | float* gradFirst, 116 | float* gradSecond 117 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 118 | int n = intIndex % SIZE_1(gradFirst); // channels 119 | int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 3*{{intStride}}; // w-pos 120 | int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 3*{{intStride}}; // h-pos 121 | 122 | // round_off is a trick to enable integer division with ceil, even for negative numbers 123 | // We use a large offset, for the inner part not to become negative. 
124 | const int round_off = ROUND_OFF; 125 | const int round_off_s1 = {{intStride}} * round_off; 126 | 127 | // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 128 | int xmin = (l - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}} 129 | int ymin = (m - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}} 130 | 131 | // Same here: 132 | int xmax = (l - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}}) / {{intStride}} 133 | int ymax = (m - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}}) / {{intStride}} 134 | 135 | float sum = 0; 136 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 137 | xmin = max(0,xmin); 138 | xmax = min(SIZE_3(gradOutput)-1,xmax); 139 | 140 | ymin = max(0,ymin); 141 | ymax = min(SIZE_2(gradOutput)-1,ymax); 142 | 143 | for (int p = -3; p <= 3; p++) { 144 | for (int o = -3; o <= 3; o++) { 145 | // Get rbot1 data: 146 | int s2o = {{intStride}} * o; 147 | int s2p = {{intStride}} * p; 148 | int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; 149 | float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] 150 | 151 | // Index offset for gradOutput in following loops: 152 | int op = (p+3) * 7 + (o+3); // index[o,p] 153 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 154 | 155 | for (int y = ymin; y <= ymax; y++) { 156 | for (int x = xmin; x <= xmax; x++) { 157 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 158 | sum += gradOutput[idxgradOutput] * bot1tmp; 159 | } 160 | } 161 | } 162 | } 163 | } 164 | const int sumelems = SIZE_1(gradFirst); 165 | const int bot0index = ((n * SIZE_2(gradFirst)) + (m-3*{{intStride}})) * SIZE_3(gradFirst) + (l-3*{{intStride}}); 166 | gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; 167 | } } 168 | ''' 169 | 170 | kernel_Correlation_updateGradSecond = ''' 171 | #define ROUND_OFF 50000 172 | 173 | extern "C" __global__ void kernel_Correlation_updateGradSecond( 174 | const int n, 175 | const int intSample, 176 | const float* rbot0, 177 | const float* rbot1, 178 | const float* gradOutput, 179 | float* gradFirst, 180 | float* gradSecond 181 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 182 | int n = intIndex % SIZE_1(gradSecond); // channels 183 | int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 3*{{intStride}}; // w-pos 184 | int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 3*{{intStride}}; // h-pos 185 | 186 | // round_off is a trick to enable integer division with ceil, even for negative numbers 187 | // We use a large offset, for the inner part not to become negative. 
188 | const int round_off = ROUND_OFF; 189 | const int round_off_s1 = {{intStride}} * round_off; 190 | 191 | float sum = 0; 192 | for (int p = -3; p <= 3; p++) { 193 | for (int o = -3; o <= 3; o++) { 194 | int s2o = {{intStride}} * o; 195 | int s2p = {{intStride}} * p; 196 | 197 | //Get X,Y ranges and clamp 198 | // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 199 | int xmin = (l - 3*{{intStride}} - s2o + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}} 200 | int ymin = (m - 3*{{intStride}} - s2p + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}} 201 | 202 | // Same here: 203 | int xmax = (l - 3*{{intStride}} - s2o + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}} - s2o) / {{intStride}} 204 | int ymax = (m - 3*{{intStride}} - s2p + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}} - s2p) / {{intStride}} 205 | 206 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 207 | xmin = max(0,xmin); 208 | xmax = min(SIZE_3(gradOutput)-1,xmax); 209 | 210 | ymin = max(0,ymin); 211 | ymax = min(SIZE_2(gradOutput)-1,ymax); 212 | 213 | // Get rbot0 data: 214 | int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; 215 | float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] 216 | 217 | // Index offset for gradOutput in following loops: 218 | int op = (p+3) * 7 + (o+3); // index[o,p] 219 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 220 | 221 | for (int y = ymin; y <= ymax; y++) { 222 | for (int x = xmin; x <= xmax; x++) { 223 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 224 | sum += gradOutput[idxgradOutput] * bot0tmp; 225 | } 226 | } 227 | } 228 | } 229 | } 230 | const int sumelems = SIZE_1(gradSecond); 231 | const int bot1index = ((n * SIZE_2(gradSecond)) + (m-3*{{intStride}})) * SIZE_3(gradSecond) + (l-3*{{intStride}}); 232 | gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; 233 | } } 234 | ''' 235 | 236 | def cupy_kernel(strFunction, objVariables): 237 | strKernel = globals()[strFunction].replace('{{intStride}}', str(objVariables['intStride'])) 238 | 239 | while True: 240 | objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) 241 | 242 | if objMatch is None: 243 | break 244 | # end 245 | 246 | intArg = int(objMatch.group(2)) 247 | 248 | strTensor = objMatch.group(4) 249 | intSizes = objVariables[strTensor].size() 250 | 251 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) 252 | # end 253 | 254 | while True: 255 | objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) 256 | 257 | if objMatch is None: 258 | break 259 | # end 260 | 261 | intArgs = int(objMatch.group(2)) 262 | strArgs = objMatch.group(4).split(',') 263 | 264 | strTensor = strArgs[0] 265 | intStrides = objVariables[strTensor].stride() 266 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] 267 | 268 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') 269 | # end 270 | 271 | return strKernel 272 | # end 273 | 274 | @cupy.util.memoize(for_each_device=True) 275 | def 
cupy_launch(strFunction, strKernel): 276 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) 277 | # end 278 | 279 | class _FunctionCorrelation(torch.autograd.Function): 280 | @staticmethod 281 | def forward(self, first, second, intStride): 282 | rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ]) 283 | rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ]) 284 | 285 | self.save_for_backward(first, second, rbot0, rbot1) 286 | 287 | self.intStride = intStride 288 | 289 | assert(first.is_contiguous() == True) 290 | assert(second.is_contiguous() == True) 291 | 292 | output = first.new_zeros([ first.shape[0], 49, int(math.ceil(first.shape[2] / intStride)), int(math.ceil(first.shape[3] / intStride)) ]) 293 | 294 | if first.is_cuda == True: 295 | n = first.shape[2] * first.shape[3] 296 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 297 | 'intStride': self.intStride, 298 | 'input': first, 299 | 'output': rbot0 300 | }))( 301 | grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]), 302 | block=tuple([ 16, 1, 1 ]), 303 | args=[ n, first.data_ptr(), rbot0.data_ptr() ] 304 | ) 305 | 306 | n = second.shape[2] * second.shape[3] 307 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 308 | 'intStride': self.intStride, 309 | 'input': second, 310 | 'output': rbot1 311 | }))( 312 | grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]), 313 | block=tuple([ 16, 1, 1 ]), 314 | args=[ n, second.data_ptr(), rbot1.data_ptr() ] 315 | ) 316 | 317 | n = output.shape[1] * output.shape[2] * output.shape[3] 318 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { 319 | 'intStride': self.intStride, 320 | 'rbot0': rbot0, 321 | 'rbot1': rbot1, 322 | 'top': output 323 | }))( 324 | grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]), 325 | block=tuple([ 32, 1, 1 ]), 326 | shared_mem=first.shape[1] * 4, 327 | args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ] 328 | ) 329 | 330 | elif first.is_cuda == False: 331 | raise NotImplementedError() 332 | 333 | # end 334 | 335 | return output 336 | # end 337 | 338 | @staticmethod 339 | def backward(self, gradOutput): 340 | first, second, rbot0, rbot1 = self.saved_tensors 341 | 342 | assert(gradOutput.is_contiguous() == True) 343 | 344 | gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None 345 | gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None 346 | 347 | if first.is_cuda == True: 348 | if gradFirst is not None: 349 | for intSample in range(first.shape[0]): 350 | n = first.shape[1] * first.shape[2] * first.shape[3] 351 | cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', { 352 | 'intStride': self.intStride, 353 | 'rbot0': rbot0, 354 | 'rbot1': rbot1, 355 | 'gradOutput': gradOutput, 356 | 'gradFirst': gradFirst, 357 | 'gradSecond': None 358 | }))( 359 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 360 | block=tuple([ 512, 1, 1 ]), 361 | args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ] 362 | ) 363 | # end 364 | # end 365 | 366 | if 
gradSecond is not None: 367 | for intSample in range(first.shape[0]): 368 | n = first.shape[1] * first.shape[2] * first.shape[3] 369 | cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', { 370 | 'intStride': self.intStride, 371 | 'rbot0': rbot0, 372 | 'rbot1': rbot1, 373 | 'gradOutput': gradOutput, 374 | 'gradFirst': None, 375 | 'gradSecond': gradSecond 376 | }))( 377 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 378 | block=tuple([ 512, 1, 1 ]), 379 | args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ] 380 | ) 381 | # end 382 | # end 383 | 384 | elif first.is_cuda == False: 385 | raise NotImplementedError() 386 | 387 | # end 388 | 389 | return gradFirst, gradSecond, None 390 | # end 391 | # end 392 | 393 | def FunctionCorrelation(tenFirst, tenSecond, intStride): 394 | return _FunctionCorrelation.apply(tenFirst, tenSecond, intStride) 395 | # end 396 | 397 | class ModuleCorrelation(torch.nn.Module): 398 | def __init__(self): 399 | super(ModuleCorrelation, self).__init__() 400 | # end 401 | 402 | def forward(self, tenFirst, tenSecond, intStride): 403 | return _FunctionCorrelation.apply(tenFirst, tenSecond, intStride) 404 | # end 405 | # end -------------------------------------------------------------------------------- /PF-AFN_test/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import os 5 | 6 | class UnetSkipConnectionBlock(nn.Module): 7 | def __init__(self, outer_nc, inner_nc, input_nc=None, 8 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 9 | super(UnetSkipConnectionBlock, self).__init__() 10 | self.outermost = outermost 11 | use_bias = norm_layer == nn.InstanceNorm2d 12 | 13 | if input_nc is None: 14 | input_nc = outer_nc 15 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 16 | stride=2, padding=1, bias=use_bias) 17 | downrelu = nn.LeakyReLU(0.2, True) 18 | uprelu = nn.ReLU(True) 19 | if norm_layer != None: 20 | downnorm = norm_layer(inner_nc) 21 | upnorm = norm_layer(outer_nc) 22 | 23 | if outermost: 24 | upsample = nn.Upsample(scale_factor=2, mode='bilinear') 25 | upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 26 | down = [downconv] 27 | up = [uprelu, upsample, upconv] 28 | model = down + [submodule] + up 29 | elif innermost: 30 | upsample = nn.Upsample(scale_factor=2, mode='bilinear') 31 | upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 32 | down = [downrelu, downconv] 33 | if norm_layer == None: 34 | up = [uprelu, upsample, upconv] 35 | else: 36 | up = [uprelu, upsample, upconv, upnorm] 37 | model = down + up 38 | else: 39 | upsample = nn.Upsample(scale_factor=2, mode='bilinear') 40 | upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 41 | if norm_layer == None: 42 | down = [downrelu, downconv] 43 | up = [uprelu, upsample, upconv] 44 | else: 45 | down = [downrelu, downconv, downnorm] 46 | up = [uprelu, upsample, upconv, upnorm] 47 | 48 | if use_dropout: 49 | model = down + [submodule] + up + [nn.Dropout(0.5)] 50 | else: 51 | model = down + [submodule] + up 52 | 53 | self.model = nn.Sequential(*model) 54 | 55 | def forward(self, x): 56 | if self.outermost: 57 | return self.model(x) 58 | else: 59 | return torch.cat([x, self.model(x)], 1) 60 | 61 | class 
ResidualBlock(nn.Module): 62 | def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d): 63 | super(ResidualBlock, self).__init__() 64 | self.relu = nn.ReLU(True) 65 | if norm_layer == None: 66 | self.block = nn.Sequential( 67 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 68 | nn.ReLU(inplace=True), 69 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 70 | ) 71 | else: 72 | self.block = nn.Sequential( 73 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 74 | norm_layer(in_features), 75 | nn.ReLU(inplace=True), 76 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 77 | norm_layer(in_features) 78 | ) 79 | 80 | def forward(self, x): 81 | residual = x 82 | out = self.block(x) 83 | out += residual 84 | out = self.relu(out) 85 | return out 86 | 87 | class ResUnetGenerator(nn.Module): 88 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 89 | norm_layer=nn.BatchNorm2d, use_dropout=False): 90 | super(ResUnetGenerator, self).__init__() 91 | unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) 92 | 93 | for i in range(num_downs - 5): 94 | unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 95 | unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 96 | unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 97 | unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 98 | unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 99 | 100 | self.model = unet_block 101 | 102 | def forward(self, input): 103 | return self.model(input) 104 | 105 | 106 | class ResUnetSkipConnectionBlock(nn.Module): 107 | def __init__(self, outer_nc, inner_nc, input_nc=None, 108 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 109 | super(ResUnetSkipConnectionBlock, self).__init__() 110 | self.outermost = outermost 111 | use_bias = norm_layer == nn.InstanceNorm2d 112 | 113 | if input_nc is None: 114 | input_nc = outer_nc 115 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3, 116 | stride=2, padding=1, bias=use_bias) 117 | 118 | res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)] 119 | res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)] 120 | 121 | downrelu = nn.ReLU(True) 122 | uprelu = nn.ReLU(True) 123 | if norm_layer != None: 124 | downnorm = norm_layer(inner_nc) 125 | upnorm = norm_layer(outer_nc) 126 | 127 | if outermost: 128 | upsample = nn.Upsample(scale_factor=2, mode='nearest') 129 | upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 130 | down = [downconv, downrelu] + res_downconv 131 | up = [upsample, upconv] 132 | model = down + [submodule] + up 133 | elif innermost: 134 | upsample = nn.Upsample(scale_factor=2, mode='nearest') 135 | upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 136 | down = [downconv, downrelu] + res_downconv 137 | if norm_layer == None: 138 | up = [upsample, upconv, uprelu] + res_upconv 139 | else: 140 | up = [upsample, upconv, upnorm, uprelu] + res_upconv 141 | model = down + up 142 | else: 143 | 
upsample = nn.Upsample(scale_factor=2, mode='nearest') 144 | upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 145 | if norm_layer == None: 146 | down = [downconv, downrelu] + res_downconv 147 | up = [upsample, upconv, uprelu] + res_upconv 148 | else: 149 | down = [downconv, downnorm, downrelu] + res_downconv 150 | up = [upsample, upconv, upnorm, uprelu] + res_upconv 151 | 152 | if use_dropout: 153 | model = down + [submodule] + up + [nn.Dropout(0.5)] 154 | else: 155 | model = down + [submodule] + up 156 | 157 | self.model = nn.Sequential(*model) 158 | 159 | def forward(self, x): 160 | if self.outermost: 161 | return self.model(x) 162 | else: 163 | return torch.cat([x, self.model(x)], 1) 164 | 165 | 166 | def save_checkpoint(model, save_path): 167 | if not os.path.exists(os.path.dirname(save_path)): 168 | os.makedirs(os.path.dirname(save_path)) 169 | torch.save(model.state_dict(), save_path) 170 | 171 | 172 | def load_checkpoint(model, checkpoint_path): 173 | 174 | if not os.path.exists(checkpoint_path): 175 | print('No checkpoint!') 176 | return 177 | 178 | checkpoint = torch.load(checkpoint_path) 179 | checkpoint_new = model.state_dict() 180 | for param in checkpoint_new: 181 | checkpoint_new[param] = checkpoint[param] 182 | 183 | model.load_state_dict(checkpoint_new) 184 | 185 | 186 | 187 | -------------------------------------------------------------------------------- /PF-AFN_test/options/__init__.py: -------------------------------------------------------------------------------- 1 | # options_init -------------------------------------------------------------------------------- /PF-AFN_test/options/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/options/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/options/__pycache__/base_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/base_options.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/options/__pycache__/base_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/base_options.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/options/__pycache__/test_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/test_options.cpython-36.pyc -------------------------------------------------------------------------------- 
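A minimal usage sketch for the ResUnetGenerator and load_checkpoint helpers defined in PF-AFN_test/models/networks.py above. The constructor arguments and checkpoint path simply mirror the call made later in PF-AFN_test/test.py and the default path in test_options.py; the only assumptions are that the PFAFN checkpoints have been downloaded and a CUDA device is available.

```python
import torch.nn as nn
from models.networks import ResUnetGenerator, load_checkpoint

# 7 input channels (person image + warped cloth + warped edge), 4 output channels
# (3-channel rendered person + 1-channel composition mask), 5 U-Net levels --
# the same configuration test.py uses for the try-on generator.
gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
load_checkpoint(gen_model, 'checkpoints/PFAFN/gen_model_final.pth')
gen_model.eval()
gen_model.cuda()
```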
/PF-AFN_test/options/__pycache__/test_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/test_options.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | class BaseOptions(): 5 | def __init__(self): 6 | self.parser = argparse.ArgumentParser() 7 | self.initialized = False 8 | 9 | def initialize(self): 10 | self.parser.add_argument('--name', type=str, default='demo', help='name of the experiment. It decides where to store samples and models') 11 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 12 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 13 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator') 14 | self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit") 15 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose') 16 | 17 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 18 | self.parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size') 19 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') 20 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 21 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 22 | 23 | self.parser.add_argument('--dataroot', type=str, 24 | default='dataset/') 25 | self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 26 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 27 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') 28 | self.parser.add_argument('--nThreads', default=1, type=int, help='# threads for loading data') 29 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 30 | 31 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size') 32 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. 
Requires tensorflow installed') 33 | 34 | self.initialized = True 35 | 36 | def parse(self, save=True): 37 | if not self.initialized: 38 | self.initialize() 39 | self.opt = self.parser.parse_args() 40 | self.opt.isTrain = self.isTrain # train or test 41 | 42 | str_ids = self.opt.gpu_ids.split(',') 43 | self.opt.gpu_ids = [] 44 | for str_id in str_ids: 45 | id = int(str_id) 46 | if id >= 0: 47 | self.opt.gpu_ids.append(id) 48 | 49 | if len(self.opt.gpu_ids) > 0: 50 | torch.cuda.set_device(self.opt.gpu_ids[0]) 51 | 52 | args = vars(self.opt) 53 | 54 | print('------------ Options -------------') 55 | for k, v in sorted(args.items()): 56 | print('%s: %s' % (str(k), str(v))) 57 | print('-------------- End ----------------') 58 | 59 | return self.opt 60 | -------------------------------------------------------------------------------- /PF-AFN_test/options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TestOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | 7 | self.parser.add_argument('--warp_checkpoint', type=str, default='checkpoints/PFAFN/warp_model_final.pth', help='load the pretrained model from the specified location') 8 | self.parser.add_argument('--gen_checkpoint', type=str, default='checkpoints/PFAFN/gen_model_final.pth', help='load the pretrained model from the specified location') 9 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 10 | 11 | self.isTrain = False 12 | -------------------------------------------------------------------------------- /PF-AFN_test/results/demo/PFAFN/0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/0.jpg -------------------------------------------------------------------------------- /PF-AFN_test/results/demo/PFAFN/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/1.jpg -------------------------------------------------------------------------------- /PF-AFN_test/results/demo/PFAFN/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/2.jpg -------------------------------------------------------------------------------- /PF-AFN_test/results/demo/PFAFN/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/3.jpg -------------------------------------------------------------------------------- /PF-AFN_test/results/demo/PFAFN/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/4.jpg -------------------------------------------------------------------------------- /PF-AFN_test/results/demo/PFAFN/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/5.jpg 
-------------------------------------------------------------------------------- /PF-AFN_test/test.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.test_options import TestOptions 3 | from data.data_loader_test import CreateDataLoader 4 | from models.networks import ResUnetGenerator, load_checkpoint 5 | from models.afwm import AFWM 6 | import torch.nn as nn 7 | import os 8 | import numpy as np 9 | import torch 10 | import cv2 11 | import torch.nn.functional as F 12 | 13 | opt = TestOptions().parse() 14 | 15 | start_epoch, epoch_iter = 1, 0 16 | 17 | data_loader = CreateDataLoader(opt) 18 | dataset = data_loader.load_data() 19 | dataset_size = len(data_loader) 20 | print(dataset_size) 21 | 22 | warp_model = AFWM(opt, 3) 23 | print(warp_model) 24 | warp_model.eval() 25 | warp_model.cuda() 26 | load_checkpoint(warp_model, opt.warp_checkpoint) 27 | 28 | gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d) 29 | print(gen_model) 30 | gen_model.eval() 31 | gen_model.cuda() 32 | load_checkpoint(gen_model, opt.gen_checkpoint) 33 | 34 | total_steps = (start_epoch-1) * dataset_size + epoch_iter 35 | step = 0 36 | step_per_batch = dataset_size / opt.batchSize 37 | 38 | for epoch in range(1,2): 39 | 40 | for i, data in enumerate(dataset, start=epoch_iter): 41 | iter_start_time = time.time() 42 | total_steps += opt.batchSize 43 | epoch_iter += opt.batchSize 44 | 45 | real_image = data['image'] 46 | clothes = data['clothes'] 47 | ##edge is extracted from the clothes image with the built-in function in python 48 | edge = data['edge'] 49 | edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int)) 50 | clothes = clothes * edge 51 | 52 | flow_out = warp_model(real_image.cuda(), clothes.cuda()) 53 | warped_cloth, last_flow, = flow_out 54 | warped_edge = F.grid_sample(edge.cuda(), last_flow.permute(0, 2, 3, 1), 55 | mode='bilinear', padding_mode='zeros') 56 | 57 | gen_inputs = torch.cat([real_image.cuda(), warped_cloth, warped_edge], 1) 58 | gen_outputs = gen_model(gen_inputs) 59 | p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1) 60 | p_rendered = torch.tanh(p_rendered) 61 | m_composite = torch.sigmoid(m_composite) 62 | m_composite = m_composite * warped_edge 63 | p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite) 64 | 65 | path = 'results/' + opt.name 66 | os.makedirs(path, exist_ok=True) 67 | sub_path = path + '/PFAFN' 68 | os.makedirs(sub_path,exist_ok=True) 69 | 70 | if step % 1 == 0: 71 | a = real_image.float().cuda() 72 | b= clothes.cuda() 73 | c = p_tryon 74 | combine = torch.cat([a[0],b[0],c[0]], 2).squeeze() 75 | cv_img=(combine.permute(1,2,0).detach().cpu().numpy()+1)/2 76 | rgb=(cv_img*255).astype(np.uint8) 77 | bgr=cv2.cvtColor(rgb,cv2.COLOR_RGB2BGR) 78 | cv2.imwrite(sub_path+'/'+str(step)+'.jpg',bgr) 79 | 80 | step += 1 81 | if epoch_iter >= dataset_size: 82 | break 83 | 84 | 85 | -------------------------------------------------------------------------------- /PF-AFN_test/test.sh: -------------------------------------------------------------------------------- 1 | python test.py --name demo --resize_or_crop None --batchSize 1 --gpu_ids 0 2 | -------------------------------------------------------------------------------- /PF-AFN_test/util/__init__.py: -------------------------------------------------------------------------------- 1 | # util_init -------------------------------------------------------------------------------- 
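One portability note on PF-AFN_test/test.py above: the edge binarization calls .astype(np.int), and NumPy deprecated the np.int alias in 1.20 and removed it in 1.24, so the demo raises an AttributeError on current NumPy. Below is a hedged, behavior-preserving sketch of just that line; everything else in the loop stays as in the repository, and the edge tensor here is only a stand-in for data['edge'] from the test loader.

```python
import numpy as np
import torch

edge = torch.rand(1, 1, 256, 192)  # stand-in for data['edge'] produced by the test loader

# Original: edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int))
# np.int was only an alias for the built-in int, so casting with int (or np.int64)
# yields the same 0/1 mask on NumPy versions where the alias no longer exists.
edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(int))
```

As test.sh shows, the demo is run with python test.py --name demo --resize_or_crop None --batchSize 1 --gpu_ids 0, and the person / cloth / try-on triplets are written side by side into results/demo/PFAFN/.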
/PF-AFN_test/util/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/util/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/util/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/util/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_test/util/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/util/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_test/util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.autograd import Variable 4 | class ImagePool(): 5 | def __init__(self, pool_size): 6 | self.pool_size = pool_size 7 | if self.pool_size > 0: 8 | self.num_imgs = 0 9 | self.images = [] 10 | 11 | def query(self, images): 12 | if self.pool_size == 0: 13 | return images 14 | return_images = [] 15 | for image in images.data: 16 | image = torch.unsqueeze(image, 0) 17 | if self.num_imgs < self.pool_size: 18 | self.num_imgs = self.num_imgs + 1 19 | self.images.append(image) 20 | return_images.append(image) 21 | else: 22 | p = random.uniform(0, 1) 23 | if p > 0.5: 24 | random_id = random.randint(0, self.pool_size-1) 25 | tmp = self.images[random_id].clone() 26 | self.images[random_id] = image 27 | return_images.append(tmp) 28 | else: 29 | return_images.append(image) 30 | return_images = Variable(torch.cat(return_images, 0)) 31 | return return_images 32 | -------------------------------------------------------------------------------- /PF-AFN_test/util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | from PIL import Image 5 | import numpy as np 6 | import os 7 | 8 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True): 9 | if isinstance(image_tensor, list): 10 | image_numpy = [] 11 | for i in range(len(image_tensor)): 12 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 13 | return image_numpy 14 | image_numpy = image_tensor.cpu().float().numpy() 15 | 16 | image_numpy = (image_numpy + 1) / 2.0 17 | image_numpy = np.clip(image_numpy, 0, 1) 18 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: 19 | image_numpy = image_numpy[:,:,0] 20 | 21 | return image_numpy 22 | 23 | def tensor2label(label_tensor, n_label, imtype=np.uint8): 24 | if n_label == 0: 25 | return tensor2im(label_tensor, imtype) 26 | label_tensor = label_tensor.cpu().float() 27 | if label_tensor.size()[0] > 1: 28 | label_tensor = 
label_tensor.max(0, keepdim=True)[1] 29 | label_tensor = Colorize(n_label)(label_tensor) 30 | label_numpy = label_tensor.numpy() 31 | label_numpy = label_numpy / 255.0 32 | 33 | return label_numpy 34 | 35 | def save_image(image_numpy, image_path): 36 | image_pil = Image.fromarray(image_numpy) 37 | image_pil.save(image_path) 38 | 39 | def mkdirs(paths): 40 | if isinstance(paths, list) and not isinstance(paths, str): 41 | for path in paths: 42 | mkdir(path) 43 | else: 44 | mkdir(paths) 45 | 46 | def mkdir(path): 47 | if not os.path.exists(path): 48 | os.makedirs(path) 49 | 50 | 51 | def uint82bin(n, count=8): 52 | """returns the binary of integer n, count refers to amount of bits""" 53 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)]) 54 | 55 | def labelcolormap(N): 56 | if N == 35: # cityscape 57 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81), 58 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153), 59 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0), 60 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70), 61 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)], 62 | dtype=np.uint8) 63 | else: 64 | cmap = np.zeros((N, 3), dtype=np.uint8) 65 | for i in range(N): 66 | r, g, b = 0, 0, 0 67 | id = i 68 | for j in range(7): 69 | str_id = uint82bin(id) 70 | r = r ^ (np.uint8(str_id[-1]) << (7-j)) 71 | g = g ^ (np.uint8(str_id[-2]) << (7-j)) 72 | b = b ^ (np.uint8(str_id[-3]) << (7-j)) 73 | id = id >> 3 74 | cmap[i, 0] = r 75 | cmap[i, 1] = g 76 | cmap[i, 2] = b 77 | return cmap 78 | 79 | class Colorize(object): 80 | def __init__(self, n=35): 81 | self.cmap = labelcolormap(n) 82 | self.cmap = torch.from_numpy(self.cmap[:n]) 83 | 84 | def __call__(self, gray_image): 85 | size = gray_image.size() 86 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 87 | 88 | for label in range(0, len(self.cmap)): 89 | mask = (label == gray_image[0]).cpu() 90 | color_image[0][mask] = self.cmap[label][0] 91 | color_image[1][mask] = self.cmap[label][1] 92 | color_image[2][mask] = self.cmap[label][2] 93 | 94 | return color_image 95 | -------------------------------------------------------------------------------- /PF-AFN_train/checkpoints/readme.txt: -------------------------------------------------------------------------------- 1 | The checkpoints will be saved here. 
2 | -------------------------------------------------------------------------------- /PF-AFN_train/data/__init__.py: -------------------------------------------------------------------------------- 1 | # data_init -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/aligned_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/aligned_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/aligned_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/aligned_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/aligned_dataset_fake.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/aligned_dataset_fake.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/aligned_dataset_test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/aligned_dataset_test.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/base_data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/base_data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/base_dataset.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/base_dataset.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/custom_dataset_data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/custom_dataset_data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/custom_dataset_data_loader_test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/custom_dataset_data_loader_test.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/data_loader_test.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/data_loader_test.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/image_folder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/image_folder.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/__pycache__/image_folder.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/image_folder.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_params, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import torch 6 | import json 7 | import numpy as np 8 | import os.path as osp 9 | from PIL import ImageDraw 10 | 11 | 12 | class AlignedDataset(BaseDataset): 13 | def initialize(self, opt): 14 | self.opt = opt 15 | self.root = opt.dataroot 16 | self.diction={} 17 | 18 | if opt.isTrain or opt.use_encoded_image: 19 | dir_A = '_A' if self.opt.label_nc == 0 else '_label' 20 | self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A) 21 | self.A_paths = sorted(make_dataset(self.dir_A)) 22 | 23 | self.fine_height=256 24 | self.fine_width=192 25 | self.radius=5 26 | 27 | dir_B = '_B' if self.opt.label_nc == 0 else '_img' 28 | self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B) 29 | self.B_paths = 
sorted(make_dataset(self.dir_B)) 30 | 31 | self.dataset_size = len(self.A_paths) 32 | 33 | if opt.isTrain or opt.use_encoded_image: 34 | dir_E = '_edge' 35 | self.dir_E = os.path.join(opt.dataroot, opt.phase + dir_E) 36 | self.E_paths = sorted(make_dataset(self.dir_E)) 37 | 38 | if opt.isTrain or opt.use_encoded_image: 39 | dir_C = '_color' 40 | self.dir_C = os.path.join(opt.dataroot, opt.phase + dir_C) 41 | self.C_paths = sorted(make_dataset(self.dir_C)) 42 | 43 | 44 | def __getitem__(self, index): 45 | 46 | A_path = self.A_paths[index] 47 | A = Image.open(A_path).convert('L') 48 | 49 | params = get_params(self.opt, A.size) 50 | if self.opt.label_nc == 0: 51 | transform_A = get_transform(self.opt, params) 52 | A_tensor = transform_A(A.convert('RGB')) 53 | else: 54 | transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 55 | A_tensor = transform_A(A) * 255.0 56 | 57 | B_path = self.B_paths[index] 58 | B = Image.open(B_path).convert('RGB') 59 | transform_B = get_transform(self.opt, params) 60 | B_tensor = transform_B(B) 61 | 62 | C_path = self.C_paths[index] 63 | C = Image.open(C_path).convert('RGB') 64 | C_tensor = transform_B(C) 65 | 66 | E_path = self.E_paths[index] 67 | E = Image.open(E_path).convert('L') 68 | E_tensor = transform_A(E) 69 | 70 | index_un = np.random.randint(14221) 71 | C_un_path = self.C_paths[index_un] 72 | C_un = Image.open(C_un_path).convert('RGB') 73 | C_un_tensor = transform_B(C_un) 74 | 75 | E_un_path = self.E_paths[index_un] 76 | E_un = Image.open(E_un_path).convert('L') 77 | E_un_tensor = transform_A(E_un) 78 | 79 | pose_name =B_path.replace('.png', '_keypoints.json').replace('.jpg','_keypoints.json').replace('train_img','train_pose') 80 | with open(osp.join(pose_name), 'r') as f: 81 | pose_label = json.load(f) 82 | try: 83 | pose_data = pose_label['people'][0]['pose_keypoints'] 84 | except IndexError: 85 | pose_data = [0 for i in range(54)] 86 | pose_data = np.array(pose_data) 87 | pose_data = pose_data.reshape((-1,3)) 88 | 89 | point_num = pose_data.shape[0] 90 | pose_map = torch.zeros(point_num, self.fine_height, self.fine_width) 91 | r = self.radius 92 | im_pose = Image.new('L', (self.fine_width, self.fine_height)) 93 | pose_draw = ImageDraw.Draw(im_pose) 94 | for i in range(point_num): 95 | one_map = Image.new('L', (self.fine_width, self.fine_height)) 96 | draw = ImageDraw.Draw(one_map) 97 | pointx = pose_data[i,0] 98 | pointy = pose_data[i,1] 99 | if pointx > 1 and pointy > 1: 100 | draw.rectangle((pointx-r, pointy-r, pointx+r, pointy+r), 'white', 'white') 101 | pose_draw.rectangle((pointx-r, pointy-r, pointx+r, pointy+r), 'white', 'white') 102 | one_map = transform_B(one_map.convert('RGB')) 103 | pose_map[i] = one_map[0] 104 | P_tensor=pose_map 105 | 106 | densepose_name = B_path.replace('.png', '.npy').replace('.jpg','.npy').replace('train_img','train_densepose') 107 | dense_mask = np.load(densepose_name).astype(np.float32) 108 | dense_mask = transform_A(dense_mask) 109 | 110 | if self.opt.isTrain: 111 | input_dict = { 'label': A_tensor, 'image': B_tensor, 'path': A_path, 'img_path': B_path ,'color_path': C_path,'color_un_path': C_un_path, 112 | 'edge': E_tensor, 'color': C_tensor, 'edge_un': E_un_tensor, 'color_un': C_un_tensor, 'pose':P_tensor, 'densepose':dense_mask 113 | } 114 | 115 | return input_dict 116 | 117 | def __len__(self): 118 | return len(self.A_paths) // (self.opt.batchSize * self.opt.num_gpus) * (self.opt.batchSize * self.opt.num_gpus) 119 | 120 | def name(self): 121 | return 'AlignedDataset' 122 | 
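The pose branch of AlignedDataset.__getitem__ above looks up an OpenPose-style keypoint file under train_pose for every image, named <image>_keypoints.json. A hedged sketch of the minimal structure the loader actually reads follows: the 54-zero fallback together with reshape((-1, 3)) implies 18 keypoints stored as flat (x, y, confidence) triples, and any field other than 'people' and 'pose_keypoints' is ignored by the loader.

```python
# Illustrative only -- a stand-in for one train_pose/<image>_keypoints.json file
# with a single detected person.  Real files carry 18 triples (54 numbers); only
# the first two keypoints are spelled out here.
pose_label = {
    "people": [
        {"pose_keypoints": [112.0, 35.0, 0.9,    # keypoint 0: x, y, confidence
                            110.0, 60.0, 0.8]}   # keypoint 1: x, y, confidence
    ]
}
pose_data = pose_label["people"][0]["pose_keypoints"]  # the loader reshapes this to (-1, 3)
```

Keypoints with x and y both greater than 1 are then drawn as radius-5 squares onto a 256x192 pose map, one channel per keypoint.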
-------------------------------------------------------------------------------- /PF-AFN_train/data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | class BaseDataLoader(): 2 | def __init__(self): 3 | pass 4 | 5 | def initialize(self, opt): 6 | self.opt = opt 7 | pass 8 | 9 | def load_data(): 10 | return None 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /PF-AFN_train/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | import numpy as np 5 | import random 6 | 7 | class BaseDataset(data.Dataset): 8 | def __init__(self): 9 | super(BaseDataset, self).__init__() 10 | 11 | def name(self): 12 | return 'BaseDataset' 13 | 14 | def initialize(self, opt): 15 | pass 16 | 17 | def get_params(opt, size): 18 | w, h = size 19 | new_h = h 20 | new_w = w 21 | if opt.resize_or_crop == 'resize_and_crop': 22 | new_h = new_w = opt.loadSize 23 | elif opt.resize_or_crop == 'scale_width_and_crop': 24 | new_w = opt.loadSize 25 | new_h = opt.loadSize * h // w 26 | 27 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) 28 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) 29 | 30 | #flip = random.random() > 0.5 31 | flip = 0 32 | return {'crop_pos': (x, y), 'flip': flip} 33 | 34 | def get_transform_resize(opt, params, method=Image.BICUBIC, normalize=True): 35 | transform_list = [] 36 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 37 | osize = [256,192] 38 | transform_list.append(transforms.Scale(osize, method)) 39 | if 'crop' in opt.resize_or_crop: 40 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 41 | 42 | if opt.resize_or_crop == 'none': 43 | base = float(2 ** opt.n_downsample_global) 44 | if opt.netG == 'local': 45 | base *= (2 ** opt.n_local_enhancers) 46 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 47 | 48 | if opt.isTrain and not opt.no_flip: 49 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 50 | 51 | transform_list += [transforms.ToTensor()] 52 | 53 | if normalize: 54 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 55 | (0.5, 0.5, 0.5))] 56 | return transforms.Compose(transform_list) 57 | 58 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True): 59 | transform_list = [] 60 | if 'resize' in opt.resize_or_crop: 61 | osize = [opt.loadSize, opt.loadSize] 62 | transform_list.append(transforms.Scale(osize, method)) 63 | elif 'scale_width' in opt.resize_or_crop: 64 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 65 | osize = [256,192] 66 | transform_list.append(transforms.Scale(osize, method)) 67 | if 'crop' in opt.resize_or_crop: 68 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 69 | 70 | if opt.resize_or_crop == 'none': 71 | base = float(2 ** opt.n_downsample_global) 72 | if opt.netG == 'local': 73 | base *= (2 ** opt.n_local_enhancers) 74 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 75 | 76 | if opt.isTrain and not opt.no_flip: 77 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 78 | 79 | transform_list += [transforms.ToTensor()] 
80 | 81 | if normalize: 82 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 83 | (0.5, 0.5, 0.5))] 84 | return transforms.Compose(transform_list) 85 | 86 | def normalize(): 87 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 88 | 89 | def __make_power_2(img, base, method=Image.BICUBIC): 90 | ow, oh = img.size 91 | h = int(round(oh / base) * base) 92 | w = int(round(ow / base) * base) 93 | if (h == oh) and (w == ow): 94 | return img 95 | return img.resize((w, h), method) 96 | 97 | def __scale_width(img, target_width, method=Image.BICUBIC): 98 | ow, oh = img.size 99 | if (ow == target_width): 100 | return img 101 | w = target_width 102 | h = int(target_width * oh / ow) 103 | return img.resize((w, h), method) 104 | 105 | def __crop(img, pos, size): 106 | ow, oh = img.size 107 | x1, y1 = pos 108 | tw = th = size 109 | if (ow > tw or oh > th): 110 | return img.crop((x1, y1, x1 + tw, y1 + th)) 111 | return img 112 | 113 | def __flip(img, flip): 114 | if flip: 115 | return img.transpose(Image.FLIP_LEFT_RIGHT) 116 | return img 117 | -------------------------------------------------------------------------------- /PF-AFN_train/data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | from data.aligned_dataset import AlignedDataset 8 | dataset = AlignedDataset() 9 | 10 | print("dataset [%s] was created" % (dataset.name())) 11 | dataset.initialize(opt) 12 | return dataset 13 | 14 | class CustomDatasetDataLoader(BaseDataLoader): 15 | def name(self): 16 | return 'CustomDatasetDataLoader' 17 | 18 | def initialize(self, opt): 19 | BaseDataLoader.initialize(self, opt) 20 | self.dataset = CreateDataset(opt) 21 | self.dataloader = torch.utils.data.DataLoader( 22 | self.dataset, 23 | batch_size=opt.batchSize, 24 | shuffle=not opt.serial_batches, 25 | num_workers=int(opt.nThreads)) 26 | 27 | def load_data(self): 28 | return self.dataloader 29 | 30 | def __len__(self): 31 | return min(len(self.dataset), self.opt.max_dataset_size) 32 | -------------------------------------------------------------------------------- /PF-AFN_train/data/data_loader.py: -------------------------------------------------------------------------------- 1 | def CreateDataLoader(opt): 2 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 3 | data_loader = CustomDatasetDataLoader() 4 | print(data_loader.name()) 5 | data_loader.initialize(opt) 6 | return data_loader 7 | -------------------------------------------------------------------------------- /PF-AFN_train/data/image_folder.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import os 4 | 5 | IMG_EXTENSIONS = [ 6 | '.jpg', '.JPG', '.jpeg', '.JPEG', 7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 8 | ] 9 | 10 | 11 | def is_image_file(filename): 12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 13 | 14 | def make_dataset(dir): 15 | images = [] 16 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 17 | 18 | f = dir.split('/')[-1].split('_')[-1] 19 | print (dir, f) 20 | dirs= os.listdir(dir) 21 | for img in dirs: 22 | 23 | path = os.path.join(dir, img) 24 | #print(path) 25 | images.append(path) 26 | return images 27 | 28 | def make_dataset_test(dir): 29 | images = [] 30 | assert os.path.isdir(dir), 
'%s is not a valid directory' % dir 31 | 32 | f = dir.split('/')[-1].split('_')[-1] 33 | for i in range(len([name for name in os.listdir(dir) if os.path.isfile(os.path.join(dir, name))])): 34 | if f == 'label' or f == 'labelref': 35 | img = str(i) + '.png' 36 | else: 37 | img = str(i) + '.jpg' 38 | path = os.path.join(dir, img) 39 | #print(path) 40 | images.append(path) 41 | return images 42 | 43 | def default_loader(path): 44 | return Image.open(path).convert('RGB') 45 | 46 | 47 | class ImageFolder(data.Dataset): 48 | 49 | def __init__(self, root, transform=None, return_paths=False, 50 | loader=default_loader): 51 | imgs = make_dataset(root) 52 | if len(imgs) == 0: 53 | raise(RuntimeError("Found 0 images in: " + root + "\n" 54 | "Supported image extensions are: " + 55 | ",".join(IMG_EXTENSIONS))) 56 | 57 | self.root = root 58 | self.imgs = imgs 59 | self.transform = transform 60 | self.return_paths = return_paths 61 | self.loader = loader 62 | 63 | def __getitem__(self, index): 64 | path = self.imgs[index] 65 | img = self.loader(path) 66 | if self.transform is not None: 67 | img = self.transform(img) 68 | if self.return_paths: 69 | return img, path 70 | else: 71 | return img 72 | 73 | def __len__(self): 74 | return len(self.imgs) 75 | -------------------------------------------------------------------------------- /PF-AFN_train/dataset/readme.txt: -------------------------------------------------------------------------------- 1 | The downloaded dataset should be put here. 2 | -------------------------------------------------------------------------------- /PF-AFN_train/models/__init__.py: -------------------------------------------------------------------------------- 1 | # model_init -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/afwm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/afwm.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/base_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/base_model.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm.cpython-36.pyc 
-------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_add.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_add.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_add.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_add.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_add.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_add.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_feat.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_feat.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_grid_offset_sep.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_grid_offset_sep.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_grid_sep.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_grid_sep.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_new.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_new.cpython-36.pyc 
-------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_offset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_offset.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_heatmap.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_heatmap.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_no_refine.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_no_refine.cpython-36.pyc -------------------------------------------------------------------------------- 
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_trans.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_trans.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_more_sep.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_sep.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_cor_sep.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_sep.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_smooth.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_smooth.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/flow_gmm_vis.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_vis.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/networks_flow.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/networks_flow.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/pix2pixHD_model.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/pix2pixHD_model.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/__pycache__/predict_mask.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/predict_mask.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/afwm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from options.train_options import TrainOptions 6 | from .correlation import correlation # the custom cost volume layer 7 | opt = TrainOptions().parse() 8 | 9 | def apply_offset(offset): 10 | sizes = list(offset.size()[2:]) 11 | grid_list = torch.meshgrid([torch.arange(size, device=offset.device) for size in sizes]) 12 | grid_list = reversed(grid_list) 13 | # apply offset 14 | grid_list = [grid.float().unsqueeze(0) + offset[:, dim, ...] 15 | for dim, grid in enumerate(grid_list)] 16 | # normalize 17 | grid_list = [grid / ((size - 1.0) / 2.0) - 1.0 18 | for grid, size in zip(grid_list, reversed(sizes))] 19 | 20 | return torch.stack(grid_list, dim=-1) 21 | 22 | 23 | def TVLoss(x): 24 | tv_h = x[:, :, 1:, :] - x[:, :, :-1, :] 25 | tv_w = x[:, :, :, 1:] - x[:, :, :, :-1] 26 | 27 | return torch.mean(torch.abs(tv_h)) + torch.mean(torch.abs(tv_w)) 28 | 29 | 30 | # backbone 31 | class ResBlock(nn.Module): 32 | def __init__(self, in_channels): 33 | super(ResBlock, self).__init__() 34 | self.block = nn.Sequential( 35 | nn.BatchNorm2d(in_channels), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False), 38 | nn.BatchNorm2d(in_channels), 39 | nn.ReLU(inplace=True), 40 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False) 41 | ) 42 | 43 | def forward(self, x): 44 | return self.block(x) + x 45 | 46 | 47 | class DownSample(nn.Module): 48 | def __init__(self, in_channels, out_channels): 49 | super(DownSample, self).__init__() 50 | self.block= nn.Sequential( 51 | nn.BatchNorm2d(in_channels), 52 | nn.ReLU(inplace=True), 53 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False) 54 | ) 55 | 56 | def forward(self, x): 57 | return self.block(x) 58 | 59 | 60 | 61 | class FeatureEncoder(nn.Module): 62 | def __init__(self, in_channels, chns=[64,128,256,256,256]): 63 | # in_channels = 3 for images, and is larger (e.g., 17+1+1) for agnositc representation 64 | super(FeatureEncoder, self).__init__() 65 | self.encoders = [] 66 | for i, out_chns in enumerate(chns): 67 | if i == 0: 68 | encoder = nn.Sequential(DownSample(in_channels, out_chns), 69 | ResBlock(out_chns), 70 | ResBlock(out_chns)) 71 | else: 72 | encoder = nn.Sequential(DownSample(chns[i-1], out_chns), 73 | ResBlock(out_chns), 74 | ResBlock(out_chns)) 75 | 76 | self.encoders.append(encoder) 77 | 78 | self.encoders = nn.ModuleList(self.encoders) 79 | 80 | 81 | def forward(self, x): 82 | encoder_features = [] 83 | for encoder in self.encoders: 84 | x = encoder(x) 85 | encoder_features.append(x) 86 | return encoder_features 87 | 88 | class RefinePyramid(nn.Module): 
89 | def __init__(self, chns=[64,128,256,256,256], fpn_dim=256): 90 | super(RefinePyramid, self).__init__() 91 | self.chns = chns 92 | 93 | # adaptive 94 | self.adaptive = [] 95 | for in_chns in list(reversed(chns)): 96 | adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1) 97 | self.adaptive.append(adaptive_layer) 98 | self.adaptive = nn.ModuleList(self.adaptive) 99 | # output conv 100 | self.smooth = [] 101 | for i in range(len(chns)): 102 | smooth_layer = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, padding=1) 103 | self.smooth.append(smooth_layer) 104 | self.smooth = nn.ModuleList(self.smooth) 105 | 106 | def forward(self, x): 107 | conv_ftr_list = x 108 | 109 | feature_list = [] 110 | last_feature = None 111 | for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))): 112 | # adaptive 113 | feature = self.adaptive[i](conv_ftr) 114 | # fuse 115 | if last_feature is not None: 116 | feature = feature + F.interpolate(last_feature, scale_factor=2, mode='nearest') 117 | # smooth 118 | feature = self.smooth[i](feature) 119 | last_feature = feature 120 | feature_list.append(feature) 121 | 122 | return tuple(reversed(feature_list)) 123 | 124 | 125 | class AFlowNet(nn.Module): 126 | def __init__(self, num_pyramid, fpn_dim=256): 127 | super(AFlowNet, self).__init__() 128 | self.netMain = [] 129 | self.netRefine = [] 130 | for i in range(num_pyramid): 131 | netMain_layer = torch.nn.Sequential( 132 | torch.nn.Conv2d(in_channels=49, out_channels=128, kernel_size=3, stride=1, padding=1), 133 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 134 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), 135 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 136 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), 137 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 138 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1) 139 | ) 140 | 141 | netRefine_layer = torch.nn.Sequential( 142 | torch.nn.Conv2d(2 * fpn_dim, out_channels=128, kernel_size=3, stride=1, padding=1), 143 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 144 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), 145 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 146 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), 147 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 148 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1) 149 | ) 150 | self.netMain.append(netMain_layer) 151 | self.netRefine.append(netRefine_layer) 152 | 153 | self.netMain = nn.ModuleList(self.netMain) 154 | self.netRefine = nn.ModuleList(self.netRefine) 155 | 156 | 157 | def forward(self, x, x_edge, x_warps, x_conds, warp_feature=True): 158 | last_flow = None 159 | last_flow_all = [] 160 | delta_list = [] 161 | x_all = [] 162 | x_edge_all = [] 163 | cond_fea_all = [] 164 | delta_x_all = [] 165 | delta_y_all = [] 166 | filter_x = [[0, 0, 0], 167 | [1, -2, 1], 168 | [0, 0, 0]] 169 | filter_y = [[0, 1, 0], 170 | [0, -2, 0], 171 | [0, 1, 0]] 172 | filter_diag1 = [[1, 0, 0], 173 | [0, -2, 0], 174 | [0, 0, 1]] 175 | filter_diag2 = [[0, 0, 1], 176 | [0, -2, 0], 177 | [1, 0, 0]] 178 | weight_array = np.ones([3, 3, 1, 4]) 179 | weight_array[:, :, 0, 0] = filter_x 180 | weight_array[:, :, 0, 1] = filter_y 181 | weight_array[:, :, 0, 2] = filter_diag1 182 | weight_array[:, :, 0, 3] = filter_diag2 183 | 184 | weight_array = 
torch.cuda.FloatTensor(weight_array).permute(3,2,0,1) 185 | self.weight = nn.Parameter(data=weight_array, requires_grad=False) 186 | 187 | for i in range(len(x_warps)): 188 | x_warp = x_warps[len(x_warps) - 1 - i] 189 | x_cond = x_conds[len(x_warps) - 1 - i] 190 | cond_fea_all.append(x_cond) 191 | 192 | if last_flow is not None and warp_feature: 193 | x_warp_after = F.grid_sample(x_warp, last_flow.detach().permute(0, 2, 3, 1), 194 | mode='bilinear', padding_mode='border') 195 | else: 196 | x_warp_after = x_warp 197 | 198 | tenCorrelation = F.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=x_warp_after, tenSecond=x_cond, intStride=1), negative_slope=0.1, inplace=False) 199 | flow = self.netMain[i](tenCorrelation) 200 | delta_list.append(flow) 201 | flow = apply_offset(flow) 202 | if last_flow is not None: 203 | flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border') 204 | else: 205 | flow = flow.permute(0, 3, 1, 2) 206 | 207 | last_flow = flow 208 | x_warp = F.grid_sample(x_warp, flow.permute(0, 2, 3, 1),mode='bilinear', padding_mode='border') 209 | concat = torch.cat([x_warp,x_cond],1) 210 | flow = self.netRefine[i](concat) 211 | delta_list.append(flow) 212 | flow = apply_offset(flow) 213 | flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border') 214 | 215 | last_flow = F.interpolate(flow, scale_factor=2, mode='bilinear') 216 | last_flow_all.append(last_flow) 217 | cur_x = F.interpolate(x, scale_factor=0.5**(len(x_warps)-1-i), mode='bilinear') 218 | cur_x_warp = F.grid_sample(cur_x, last_flow.permute(0, 2, 3, 1),mode='bilinear', padding_mode='border') 219 | x_all.append(cur_x_warp) 220 | cur_x_edge = F.interpolate(x_edge, scale_factor=0.5**(len(x_warps)-1-i), mode='bilinear') 221 | cur_x_warp_edge = F.grid_sample(cur_x_edge, last_flow.permute(0, 2, 3, 1),mode='bilinear', padding_mode='zeros') 222 | x_edge_all.append(cur_x_warp_edge) 223 | flow_x,flow_y = torch.split(last_flow,1,dim=1) 224 | delta_x = F.conv2d(flow_x, self.weight) 225 | delta_y = F.conv2d(flow_y,self.weight) 226 | delta_x_all.append(delta_x) 227 | delta_y_all.append(delta_y) 228 | 229 | x_warp = F.grid_sample(x, last_flow.permute(0, 2, 3, 1), 230 | mode='bilinear', padding_mode='border') 231 | return x_warp, last_flow, cond_fea_all, last_flow_all, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all 232 | 233 | 234 | class AFWM(nn.Module): 235 | 236 | def __init__(self, opt, input_nc): 237 | super(AFWM, self).__init__() 238 | num_filters = [64,128,256,256,256] 239 | self.image_features = FeatureEncoder(3, num_filters) 240 | self.cond_features = FeatureEncoder(input_nc, num_filters) 241 | self.image_FPN = RefinePyramid(num_filters) 242 | self.cond_FPN = RefinePyramid(num_filters) 243 | self.aflow_net = AFlowNet(len(num_filters)) 244 | self.old_lr = opt.lr 245 | self.old_lr_warp = opt.lr*0.2 246 | 247 | def forward(self, cond_input, image_input, image_edge): 248 | cond_pyramids = self.cond_FPN(self.cond_features(cond_input)) # maybe use nn.Sequential 249 | image_pyramids = self.image_FPN(self.image_features(image_input)) 250 | 251 | x_warp, last_flow, last_flow_all, flow_all, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all = self.aflow_net(image_input, image_edge, image_pyramids, cond_pyramids) 252 | 253 | return x_warp, last_flow, last_flow_all, flow_all, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all 254 | 255 | 256 | def update_learning_rate(self,optimizer): 257 | lrd = opt.lr / opt.niter_decay 258 | lr = self.old_lr - lrd 259 | for param_group in 
optimizer.param_groups: 260 | param_group['lr'] = lr 261 | if opt.verbose: 262 | print('update learning rate: %f -> %f' % (self.old_lr, lr)) 263 | self.old_lr = lr 264 | 265 | def update_learning_rate_warp(self,optimizer): 266 | lrd = 0.2 * opt.lr / opt.niter_decay 267 | lr = self.old_lr_warp - lrd 268 | for param_group in optimizer.param_groups: 269 | param_group['lr'] = lr 270 | if opt.verbose: 271 | print('update learning rate: %f -> %f' % (self.old_lr_warp, lr)) 272 | self.old_lr_warp = lr 273 | 274 | -------------------------------------------------------------------------------- /PF-AFN_train/models/correlation/README.md: -------------------------------------------------------------------------------- 1 | This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be making use or modify this particular implementation, please acknowledge it appropriately. -------------------------------------------------------------------------------- /PF-AFN_train/models/correlation/__pycache__/correlation.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/correlation/__pycache__/correlation.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/correlation/__pycache__/correlation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/correlation/__pycache__/correlation.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/models/correlation/correlation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | 5 | import cupy 6 | import math 7 | import re 8 | 9 | kernel_Correlation_rearrange = ''' 10 | extern "C" __global__ void kernel_Correlation_rearrange( 11 | const int n, 12 | const float* input, 13 | float* output 14 | ) { 15 | int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; 16 | 17 | if (intIndex >= n) { 18 | return; 19 | } 20 | 21 | int intSample = blockIdx.z; 22 | int intChannel = blockIdx.y; 23 | 24 | float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; 25 | 26 | __syncthreads(); 27 | 28 | int intPaddedY = (intIndex / SIZE_3(input)) + 3*{{intStride}}; 29 | int intPaddedX = (intIndex % SIZE_3(input)) + 3*{{intStride}}; 30 | int intRearrange = ((SIZE_3(input) + 6*{{intStride}}) * intPaddedY) + intPaddedX; 31 | 32 | output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; 33 | } 34 | ''' 35 | 36 | kernel_Correlation_updateOutput = ''' 37 | extern "C" __global__ void kernel_Correlation_updateOutput( 38 | const int n, 39 | const float* rbot0, 40 | const float* rbot1, 41 | float* top 42 | ) { 43 | extern __shared__ char patch_data_char[]; 44 | 45 | float *patch_data = (float *)patch_data_char; 46 | 47 | // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 48 | int x1 = (blockIdx.x + 3) * {{intStride}}; 49 | int y1 = (blockIdx.y + 3) * {{intStride}}; 50 | int item = 
blockIdx.z; 51 | int ch_off = threadIdx.x; 52 | 53 | // Load 3D patch into shared shared memory 54 | for (int j = 0; j < 1; j++) { // HEIGHT 55 | for (int i = 0; i < 1; i++) { // WIDTH 56 | int ji_off = (j + i) * SIZE_3(rbot0); 57 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 58 | int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; 59 | int idxPatchData = ji_off + ch; 60 | patch_data[idxPatchData] = rbot0[idx1]; 61 | } 62 | } 63 | } 64 | 65 | __syncthreads(); 66 | 67 | __shared__ float sum[32]; 68 | 69 | // Compute correlation 70 | for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { 71 | sum[ch_off] = 0; 72 | 73 | int s2o = (top_channel % 7 - 3) * {{intStride}}; 74 | int s2p = (top_channel / 7 - 3) * {{intStride}}; 75 | 76 | for (int j = 0; j < 1; j++) { // HEIGHT 77 | for (int i = 0; i < 1; i++) { // WIDTH 78 | int ji_off = (j + i) * SIZE_3(rbot0); 79 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 80 | int x2 = x1 + s2o; 81 | int y2 = y1 + s2p; 82 | 83 | int idxPatchData = ji_off + ch; 84 | int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; 85 | 86 | sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; 87 | } 88 | } 89 | } 90 | 91 | __syncthreads(); 92 | 93 | if (ch_off == 0) { 94 | float total_sum = 0; 95 | for (int idx = 0; idx < 32; idx++) { 96 | total_sum += sum[idx]; 97 | } 98 | const int sumelems = SIZE_3(rbot0); 99 | const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; 100 | top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; 101 | } 102 | } 103 | } 104 | ''' 105 | 106 | kernel_Correlation_updateGradFirst = ''' 107 | #define ROUND_OFF 50000 108 | 109 | extern "C" __global__ void kernel_Correlation_updateGradFirst( 110 | const int n, 111 | const int intSample, 112 | const float* rbot0, 113 | const float* rbot1, 114 | const float* gradOutput, 115 | float* gradFirst, 116 | float* gradSecond 117 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 118 | int n = intIndex % SIZE_1(gradFirst); // channels 119 | int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 3*{{intStride}}; // w-pos 120 | int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 3*{{intStride}}; // h-pos 121 | 122 | // round_off is a trick to enable integer division with ceil, even for negative numbers 123 | // We use a large offset, for the inner part not to become negative. 
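  // The ceil-division trick, spelled out (illustrative note, example values only): with
  // round_off_s1 = stride * ROUND_OFF added before the integer division and ROUND_OFF
  // subtracted after it, (v + round_off_s1 - 1) / stride + 1 - ROUND_OFF equals
  // ceil(v / stride) even for negative v, because the offset keeps the truncating
  // division away from zero. Example: v = -3, stride = 2, ROUND_OFF = 50000 gives
  // (-3 + 100000 - 1) / 2 + 1 - 50000 = 49998 + 1 - 50000 = -1 = ceil(-1.5).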
124 | const int round_off = ROUND_OFF; 125 | const int round_off_s1 = {{intStride}} * round_off; 126 | 127 | // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 128 | int xmin = (l - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}} 129 | int ymin = (m - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}} 130 | 131 | // Same here: 132 | int xmax = (l - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}}) / {{intStride}} 133 | int ymax = (m - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}}) / {{intStride}} 134 | 135 | float sum = 0; 136 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 137 | xmin = max(0,xmin); 138 | xmax = min(SIZE_3(gradOutput)-1,xmax); 139 | 140 | ymin = max(0,ymin); 141 | ymax = min(SIZE_2(gradOutput)-1,ymax); 142 | 143 | for (int p = -3; p <= 3; p++) { 144 | for (int o = -3; o <= 3; o++) { 145 | // Get rbot1 data: 146 | int s2o = {{intStride}} * o; 147 | int s2p = {{intStride}} * p; 148 | int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; 149 | float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] 150 | 151 | // Index offset for gradOutput in following loops: 152 | int op = (p+3) * 7 + (o+3); // index[o,p] 153 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 154 | 155 | for (int y = ymin; y <= ymax; y++) { 156 | for (int x = xmin; x <= xmax; x++) { 157 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 158 | sum += gradOutput[idxgradOutput] * bot1tmp; 159 | } 160 | } 161 | } 162 | } 163 | } 164 | const int sumelems = SIZE_1(gradFirst); 165 | const int bot0index = ((n * SIZE_2(gradFirst)) + (m-3*{{intStride}})) * SIZE_3(gradFirst) + (l-3*{{intStride}}); 166 | gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; 167 | } } 168 | ''' 169 | 170 | kernel_Correlation_updateGradSecond = ''' 171 | #define ROUND_OFF 50000 172 | 173 | extern "C" __global__ void kernel_Correlation_updateGradSecond( 174 | const int n, 175 | const int intSample, 176 | const float* rbot0, 177 | const float* rbot1, 178 | const float* gradOutput, 179 | float* gradFirst, 180 | float* gradSecond 181 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 182 | int n = intIndex % SIZE_1(gradSecond); // channels 183 | int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 3*{{intStride}}; // w-pos 184 | int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 3*{{intStride}}; // h-pos 185 | 186 | // round_off is a trick to enable integer division with ceil, even for negative numbers 187 | // We use a large offset, for the inner part not to become negative. 
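  // Note: unlike kernel_Correlation_updateGradFirst, the x/y ranges below are additionally
  // shifted by the per-channel displacement (-s2o, -s2p), and the multiplier is read from
  // rbot0 at (m - s2p, l - s2o) rather than from rbot1 at (m + s2p, l + s2o).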
188 | const int round_off = ROUND_OFF; 189 | const int round_off_s1 = {{intStride}} * round_off; 190 | 191 | float sum = 0; 192 | for (int p = -3; p <= 3; p++) { 193 | for (int o = -3; o <= 3; o++) { 194 | int s2o = {{intStride}} * o; 195 | int s2p = {{intStride}} * p; 196 | 197 | //Get X,Y ranges and clamp 198 | // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 199 | int xmin = (l - 3*{{intStride}} - s2o + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}} 200 | int ymin = (m - 3*{{intStride}} - s2p + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}} 201 | 202 | // Same here: 203 | int xmax = (l - 3*{{intStride}} - s2o + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}} - s2o) / {{intStride}} 204 | int ymax = (m - 3*{{intStride}} - s2p + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}} - s2p) / {{intStride}} 205 | 206 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 207 | xmin = max(0,xmin); 208 | xmax = min(SIZE_3(gradOutput)-1,xmax); 209 | 210 | ymin = max(0,ymin); 211 | ymax = min(SIZE_2(gradOutput)-1,ymax); 212 | 213 | // Get rbot0 data: 214 | int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; 215 | float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] 216 | 217 | // Index offset for gradOutput in following loops: 218 | int op = (p+3) * 7 + (o+3); // index[o,p] 219 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 220 | 221 | for (int y = ymin; y <= ymax; y++) { 222 | for (int x = xmin; x <= xmax; x++) { 223 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 224 | sum += gradOutput[idxgradOutput] * bot0tmp; 225 | } 226 | } 227 | } 228 | } 229 | } 230 | const int sumelems = SIZE_1(gradSecond); 231 | const int bot1index = ((n * SIZE_2(gradSecond)) + (m-3*{{intStride}})) * SIZE_3(gradSecond) + (l-3*{{intStride}}); 232 | gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; 233 | } } 234 | ''' 235 | 236 | def cupy_kernel(strFunction, objVariables): 237 | strKernel = globals()[strFunction].replace('{{intStride}}', str(objVariables['intStride'])) 238 | 239 | while True: 240 | objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) 241 | 242 | if objMatch is None: 243 | break 244 | # end 245 | 246 | intArg = int(objMatch.group(2)) 247 | 248 | strTensor = objMatch.group(4) 249 | intSizes = objVariables[strTensor].size() 250 | 251 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) 252 | # end 253 | 254 | while True: 255 | objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) 256 | 257 | if objMatch is None: 258 | break 259 | # end 260 | 261 | intArgs = int(objMatch.group(2)) 262 | strArgs = objMatch.group(4).split(',') 263 | 264 | strTensor = strArgs[0] 265 | intStrides = objVariables[strTensor].stride() 266 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] 267 | 268 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') 269 | # end 270 | 271 | return strKernel 272 | # end 273 | 274 | @cupy.util.memoize(for_each_device=True) 275 | def 
cupy_launch(strFunction, strKernel): 276 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) 277 | # end 278 | 279 | class _FunctionCorrelation(torch.autograd.Function): 280 | @staticmethod 281 | def forward(self, first, second, intStride): 282 | rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ]) 283 | rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ]) 284 | 285 | self.save_for_backward(first, second, rbot0, rbot1) 286 | 287 | self.intStride = intStride 288 | 289 | assert(first.is_contiguous() == True) 290 | assert(second.is_contiguous() == True) 291 | 292 | output = first.new_zeros([ first.shape[0], 49, int(math.ceil(first.shape[2] / intStride)), int(math.ceil(first.shape[3] / intStride)) ]) 293 | 294 | if first.is_cuda == True: 295 | n = first.shape[2] * first.shape[3] 296 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 297 | 'intStride': self.intStride, 298 | 'input': first, 299 | 'output': rbot0 300 | }))( 301 | grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]), 302 | block=tuple([ 16, 1, 1 ]), 303 | args=[ n, first.data_ptr(), rbot0.data_ptr() ] 304 | ) 305 | 306 | n = second.shape[2] * second.shape[3] 307 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 308 | 'intStride': self.intStride, 309 | 'input': second, 310 | 'output': rbot1 311 | }))( 312 | grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]), 313 | block=tuple([ 16, 1, 1 ]), 314 | args=[ n, second.data_ptr(), rbot1.data_ptr() ] 315 | ) 316 | 317 | n = output.shape[1] * output.shape[2] * output.shape[3] 318 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { 319 | 'intStride': self.intStride, 320 | 'rbot0': rbot0, 321 | 'rbot1': rbot1, 322 | 'top': output 323 | }))( 324 | grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]), 325 | block=tuple([ 32, 1, 1 ]), 326 | shared_mem=first.shape[1] * 4, 327 | args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ] 328 | ) 329 | 330 | elif first.is_cuda == False: 331 | raise NotImplementedError() 332 | 333 | # end 334 | 335 | return output 336 | # end 337 | 338 | @staticmethod 339 | def backward(self, gradOutput): 340 | first, second, rbot0, rbot1 = self.saved_tensors 341 | 342 | assert(gradOutput.is_contiguous() == True) 343 | 344 | gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None 345 | gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None 346 | 347 | if first.is_cuda == True: 348 | if gradFirst is not None: 349 | for intSample in range(first.shape[0]): 350 | n = first.shape[1] * first.shape[2] * first.shape[3] 351 | cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', { 352 | 'intStride': self.intStride, 353 | 'rbot0': rbot0, 354 | 'rbot1': rbot1, 355 | 'gradOutput': gradOutput, 356 | 'gradFirst': gradFirst, 357 | 'gradSecond': None 358 | }))( 359 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 360 | block=tuple([ 512, 1, 1 ]), 361 | args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ] 362 | ) 363 | # end 364 | # end 365 | 366 | if 
gradSecond is not None: 367 | for intSample in range(first.shape[0]): 368 | n = first.shape[1] * first.shape[2] * first.shape[3] 369 | cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', { 370 | 'intStride': self.intStride, 371 | 'rbot0': rbot0, 372 | 'rbot1': rbot1, 373 | 'gradOutput': gradOutput, 374 | 'gradFirst': None, 375 | 'gradSecond': gradSecond 376 | }))( 377 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 378 | block=tuple([ 512, 1, 1 ]), 379 | args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ] 380 | ) 381 | # end 382 | # end 383 | 384 | elif first.is_cuda == False: 385 | raise NotImplementedError() 386 | 387 | # end 388 | 389 | return gradFirst, gradSecond, None 390 | # end 391 | # end 392 | 393 | def FunctionCorrelation(tenFirst, tenSecond, intStride): 394 | return _FunctionCorrelation.apply(tenFirst, tenSecond, intStride) 395 | # end 396 | 397 | class ModuleCorrelation(torch.nn.Module): 398 | def __init__(self): 399 | super(ModuleCorrelation, self).__init__() 400 | # end 401 | 402 | def forward(self, tenFirst, tenSecond, intStride): 403 | return _FunctionCorrelation.apply(tenFirst, tenSecond, intStride) 404 | # end 405 | # end -------------------------------------------------------------------------------- /PF-AFN_train/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | from torchvision import models 5 | from options.train_options import TrainOptions 6 | import os 7 | 8 | opt = TrainOptions().parse() 9 | 10 | class ResidualBlock(nn.Module): 11 | def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d): 12 | super(ResidualBlock, self).__init__() 13 | self.relu = nn.ReLU(True) 14 | if norm_layer == None: 15 | self.block = nn.Sequential( 16 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 19 | ) 20 | else: 21 | self.block = nn.Sequential( 22 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 23 | norm_layer(in_features), 24 | nn.ReLU(inplace=True), 25 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False), 26 | norm_layer(in_features) 27 | ) 28 | 29 | def forward(self, x): 30 | residual = x 31 | out = self.block(x) 32 | out += residual 33 | out = self.relu(out) 34 | return out 35 | 36 | 37 | class ResUnetGenerator(nn.Module): 38 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 39 | norm_layer=nn.BatchNorm2d, use_dropout=False): 40 | super(ResUnetGenerator, self).__init__() 41 | # construct unet structure 42 | unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) 43 | 44 | for i in range(num_downs - 5): 45 | unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 46 | unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 47 | unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 48 | unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 49 | unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 
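        # The generator is assembled inside-out: the innermost ResUnetSkipConnectionBlock works
        # at the coarsest scale, and each enclosing block adds a stride-2 conv on the way down
        # and a nearest-neighbour upsample + 3x3 conv on the way up, with two ResidualBlocks
        # after each conv (the outermost up path omits them). Every non-outermost block
        # concatenates its input with its output as the skip connection. A minimal sketch of
        # the call the training scripts below make (shapes are illustrative, assuming
        # 256x192 inputs):
        #
        #   gen = ResUnetGenerator(input_nc=7, output_nc=4, num_downs=5, ngf=64,
        #                          norm_layer=nn.BatchNorm2d)
        #   rendered_and_mask = gen(torch.randn(1, 7, 256, 192))   # -> (1, 4, 256, 192)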
50 | 51 | self.model = unet_block 52 | self.old_lr = opt.lr 53 | self.old_lr_gmm = 0.1*opt.lr 54 | 55 | def forward(self, input): 56 | return self.model(input) 57 | 58 | 59 | # Defines the submodule with skip connection. 60 | # X -------------------identity---------------------- X 61 | # |-- downsampling -- |submodule| -- upsampling --| 62 | class ResUnetSkipConnectionBlock(nn.Module): 63 | def __init__(self, outer_nc, inner_nc, input_nc=None, 64 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 65 | super(ResUnetSkipConnectionBlock, self).__init__() 66 | self.outermost = outermost 67 | use_bias = norm_layer == nn.InstanceNorm2d 68 | 69 | if input_nc is None: 70 | input_nc = outer_nc 71 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3, 72 | stride=2, padding=1, bias=use_bias) 73 | # add two resblock 74 | res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)] 75 | res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)] 76 | 77 | downrelu = nn.ReLU(True) 78 | uprelu = nn.ReLU(True) 79 | if norm_layer != None: 80 | downnorm = norm_layer(inner_nc) 81 | upnorm = norm_layer(outer_nc) 82 | 83 | if outermost: 84 | upsample = nn.Upsample(scale_factor=2, mode='nearest') 85 | upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 86 | down = [downconv, downrelu] + res_downconv 87 | up = [upsample, upconv] 88 | model = down + [submodule] + up 89 | elif innermost: 90 | upsample = nn.Upsample(scale_factor=2, mode='nearest') 91 | upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 92 | down = [downconv, downrelu] + res_downconv 93 | if norm_layer == None: 94 | up = [upsample, upconv, uprelu] + res_upconv 95 | else: 96 | up = [upsample, upconv, upnorm, uprelu] + res_upconv 97 | model = down + up 98 | else: 99 | upsample = nn.Upsample(scale_factor=2, mode='nearest') 100 | upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias) 101 | if norm_layer == None: 102 | down = [downconv, downrelu] + res_downconv 103 | up = [upsample, upconv, uprelu] + res_upconv 104 | else: 105 | down = [downconv, downnorm, downrelu] + res_downconv 106 | up = [upsample, upconv, upnorm, uprelu] + res_upconv 107 | 108 | if use_dropout: 109 | model = down + [submodule] + up + [nn.Dropout(0.5)] 110 | else: 111 | model = down + [submodule] + up 112 | 113 | self.model = nn.Sequential(*model) 114 | 115 | def forward(self, x): 116 | if self.outermost: 117 | return self.model(x) 118 | else: 119 | return torch.cat([x, self.model(x)], 1) 120 | 121 | 122 | class Vgg19(nn.Module): 123 | def __init__(self, requires_grad=False): 124 | super(Vgg19, self).__init__() 125 | vgg_pretrained_features = models.vgg19(pretrained=True).features 126 | self.slice1 = nn.Sequential() 127 | self.slice2 = nn.Sequential() 128 | self.slice3 = nn.Sequential() 129 | self.slice4 = nn.Sequential() 130 | self.slice5 = nn.Sequential() 131 | for x in range(2): 132 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 133 | for x in range(2, 7): 134 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 135 | for x in range(7, 12): 136 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 137 | for x in range(12, 21): 138 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 139 | for x in range(21, 30): 140 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 141 | if not requires_grad: 142 | for param in 
self.parameters(): 143 | param.requires_grad = False 144 | 145 | def forward(self, X): 146 | h_relu1 = self.slice1(X) 147 | h_relu2 = self.slice2(h_relu1) 148 | h_relu3 = self.slice3(h_relu2) 149 | h_relu4 = self.slice4(h_relu3) 150 | h_relu5 = self.slice5(h_relu4) 151 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 152 | return out 153 | 154 | class VGGLoss(nn.Module): 155 | def __init__(self, layids = None): 156 | super(VGGLoss, self).__init__() 157 | self.vgg = Vgg19() 158 | self.vgg.cuda() 159 | self.criterion = nn.L1Loss() 160 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 161 | self.layids = layids 162 | 163 | def forward(self, x, y): 164 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 165 | loss = 0 166 | if self.layids is None: 167 | self.layids = list(range(len(x_vgg))) 168 | for i in self.layids: 169 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 170 | return loss 171 | 172 | def save_checkpoint(model, save_path): 173 | if not os.path.exists(os.path.dirname(save_path)): 174 | os.makedirs(os.path.dirname(save_path)) 175 | torch.save(model.state_dict(), save_path) 176 | 177 | 178 | def load_checkpoint_parallel(model, checkpoint_path): 179 | 180 | if not os.path.exists(checkpoint_path): 181 | print('No checkpoint!') 182 | return 183 | 184 | checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(opt.local_rank)) 185 | checkpoint_new = model.state_dict() 186 | for param in checkpoint_new: 187 | checkpoint_new[param] = checkpoint[param] 188 | model.load_state_dict(checkpoint_new) 189 | 190 | def load_checkpoint_part_parallel(model, checkpoint_path): 191 | 192 | if not os.path.exists(checkpoint_path): 193 | print('No checkpoint!') 194 | return 195 | checkpoint = torch.load(checkpoint_path,map_location='cuda:{}'.format(opt.local_rank)) 196 | checkpoint_new = model.state_dict() 197 | for param in checkpoint_new: 198 | if 'cond_' not in param and 'aflow_net.netRefine' not in param: 199 | checkpoint_new[param] = checkpoint[param] 200 | model.load_state_dict(checkpoint_new) 201 | 202 | 203 | -------------------------------------------------------------------------------- /PF-AFN_train/options/__init__.py: -------------------------------------------------------------------------------- 1 | # options_init -------------------------------------------------------------------------------- /PF-AFN_train/options/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/options/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/options/__pycache__/base_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/base_options.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/options/__pycache__/base_options.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/base_options.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/options/__pycache__/test_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/test_options.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/options/__pycache__/train_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/train_options.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/options/__pycache__/train_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/train_options.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | 6 | class BaseOptions(): 7 | def __init__(self): 8 | self.parser = argparse.ArgumentParser() 9 | self.initialized = False 10 | 11 | def initialize(self): 12 | # experiment specifics 13 | self.parser.add_argument('--name', type=str, default='flow', help='name of the experiment. It decides where to store samples and models') 14 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 15 | self.parser.add_argument('--num_gpus', type=int, default=1, help='the number of gpus') 16 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 17 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 18 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator') 19 | self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 
8, 16, 32 bit") 20 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose') 21 | 22 | # input/output sizes 23 | self.parser.add_argument('--batchSize', type=int, default=32, help='input batch size') 24 | self.parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size') 25 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') 26 | self.parser.add_argument('--label_nc', type=int, default=20, help='# of input label channels') 27 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 28 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 29 | 30 | # for setting inputs 31 | self.parser.add_argument('--dataroot', type=str,default='dataset/VITON_traindata/') 32 | self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 33 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 34 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') 35 | self.parser.add_argument('--nThreads', default=1, type=int, help='# threads for loading data') 36 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 37 | 38 | # for displays 39 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size') 40 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. 
Requires tensorflow installed') 41 | 42 | # for generator 43 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG') 44 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 45 | self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG') 46 | self.parser.add_argument('--n_blocks_global', type=int, default=4, help='number of residual blocks in the global generator network') 47 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network') 48 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use') 49 | self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs that we only train the outmost local enhancer') 50 | self.parser.add_argument('--tv_weight', type=float, default=0.1, help='weight for TV loss') 51 | 52 | self.initialized = True 53 | 54 | def parse(self, save=True): 55 | if not self.initialized: 56 | self.initialize() 57 | self.opt = self.parser.parse_args() 58 | self.opt.isTrain = self.isTrain # train or test 59 | 60 | str_ids = self.opt.gpu_ids.split(',') 61 | self.opt.gpu_ids = [] 62 | for str_id in str_ids: 63 | id = int(str_id) 64 | if id >= 0: 65 | self.opt.gpu_ids.append(id) 66 | 67 | # set gpu ids 68 | if len(self.opt.gpu_ids) > 0: 69 | torch.cuda.set_device(self.opt.gpu_ids[0]) 70 | 71 | args = vars(self.opt) 72 | 73 | print('------------ Options -------------') 74 | for k, v in sorted(args.items()): 75 | print('%s: %s' % (str(k), str(v))) 76 | print('-------------- End ----------------') 77 | 78 | # save to the disk 79 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 80 | util.mkdirs(expr_dir) 81 | if save and not self.opt.continue_train: 82 | file_name = os.path.join(expr_dir, 'opt.txt') 83 | with open(file_name, 'wt') as opt_file: 84 | opt_file.write('------------ Options -------------\n') 85 | for k, v in sorted(args.items()): 86 | opt_file.write('%s: %s\n' % (str(k), str(v))) 87 | opt_file.write('-------------- End ----------------\n') 88 | return self.opt 89 | -------------------------------------------------------------------------------- /PF-AFN_train/options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TrainOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | # for displays 7 | self.parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',help='job launcher') 8 | self.parser.add_argument('--local_rank', type=int, default=0) 9 | 10 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 11 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 12 | self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results') 13 | self.parser.add_argument('--save_epoch_freq', type=int, default=20, help='frequency of saving checkpoints at the end of epochs') 14 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 15 | self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each 
iteration') 16 | 17 | # for training 18 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 19 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location') 20 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 21 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 22 | self.parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate') 23 | self.parser.add_argument('--niter_decay', type=int, default=50, help='# of iter to linearly decay learning rate to zero') 24 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 25 | self.parser.add_argument('--lr', type=float, default=0.00005, help='initial learning rate for adam') 26 | self.parser.add_argument('--PFAFN_warp_checkpoint', type=str, help='load the pretrained model from the specified location') 27 | self.parser.add_argument('--PFAFN_gen_checkpoint', type=str, help='load the pretrained model from the specified location') 28 | self.parser.add_argument('--PBAFN_warp_checkpoint', type=str, help='load the pretrained model from the specified location') 29 | self.parser.add_argument('--PBAFN_gen_checkpoint', type=str, help='load the pretrained model from the specified location') 30 | # for discriminators 31 | self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use') 32 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') 33 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 34 | self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') 35 | self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss') 36 | self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss') 37 | self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') 38 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images') 39 | 40 | self.isTrain = True 41 | -------------------------------------------------------------------------------- /PF-AFN_train/runs/readme.txt: -------------------------------------------------------------------------------- 1 | The tensorboard logs will be saved here. 2 | -------------------------------------------------------------------------------- /PF-AFN_train/sample/readme.txt: -------------------------------------------------------------------------------- 1 | The images during training will be saved here. 
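(Specifically, the training scripts below write a horizontal strip of intermediate results to sample/<name>/<step>.jpg every 1000 steps, where <name> is the experiment name passed via --name.)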
2 | -------------------------------------------------------------------------------- /PF-AFN_train/scripts/train_PBAFN_e2e.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4736 train_PBAFN_e2e.py --name PBAFN_e2e \ 2 | --PBAFN_warp_checkpoint 'checkpoints/PBAFN_stage1/PBAFN_warp_epoch_101.pth' --resize_or_crop None --verbose --tf_log --batchSize 4 --num_gpus 8 --label_nc 14 --launcher pytorch 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /PF-AFN_train/scripts/train_PBAFN_stage1.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=7129 train_PBAFN_stage1.py --name PBAFN_stage1 \ 2 | --resize_or_crop None --verbose --tf_log --batchSize 4 --num_gpus 8 --label_nc 14 --launcher pytorch 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /PF-AFN_train/scripts/train_PFAFN_e2e.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=7129 train_PFAFN_e2e.py --name PFAFN_e2e \ 2 | --PFAFN_warp_checkpoint 'checkpoints/PFAFN_stage1/PFAFN_warp_epoch_201.pth' \ 3 | --PBAFN_warp_checkpoint 'checkpoints/PBAFN_e2e/PBAFN_warp_epoch_101.pth' --PBAFN_gen_checkpoint 'checkpoints/PBAFN_e2e/PBAFN_gen_epoch_101.pth' \ 4 | --resize_or_crop None --verbose --tf_log --batchSize 4 --num_gpus 8 --label_nc 14 --launcher pytorch 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /PF-AFN_train/scripts/train_PFAFN_stage1.sh: -------------------------------------------------------------------------------- 1 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4703 train_PFAFN_stage1.py --name PFAFN_stage1 \ 2 | --PBAFN_warp_checkpoint 'checkpoints/PBAFN_e2e/PBAFN_warp_epoch_101.pth' --PBAFN_gen_checkpoint 'checkpoints/PBAFN_e2e/PBAFN_gen_epoch_101.pth' \ 3 | --lr 0.00003 --niter 100 --niter_decay 100 --resize_or_crop None --verbose --tf_log --batchSize 4 --num_gpus 8 --label_nc 14 --launcher pytorch 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /PF-AFN_train/train_PBAFN_e2e.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from models.networks import ResUnetGenerator, VGGLoss, save_checkpoint, load_checkpoint_parallel 4 | from models.afwm import TVLoss, AFWM 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import os 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data.distributed import DistributedSampler 12 | from tensorboardX import SummaryWriter 13 | import cv2 14 | import datetime 15 | 16 | opt = TrainOptions().parse() 17 | path = 'runs/'+opt.name 18 | os.makedirs(path,exist_ok=True) 19 | 20 | def CreateDataset(opt): 21 | from data.aligned_dataset import AlignedDataset 22 | dataset = AlignedDataset() 23 | print("dataset [%s] was created" % (dataset.name())) 24 | dataset.initialize(opt) 25 | return dataset 26 | 27 | os.makedirs('sample',exist_ok=True) 28 | opt = TrainOptions().parse() 29 | iter_path = 
os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 30 | 31 | torch.cuda.set_device(opt.local_rank) 32 | torch.distributed.init_process_group( 33 | 'nccl', 34 | init_method='env://' 35 | ) 36 | device = torch.device(f'cuda:{opt.local_rank}') 37 | 38 | start_epoch, epoch_iter = 1, 0 39 | 40 | train_data = CreateDataset(opt) 41 | train_sampler = DistributedSampler(train_data) 42 | train_loader = DataLoader(train_data, batch_size=opt.batchSize, shuffle=False, 43 | num_workers=4, pin_memory=True, sampler=train_sampler) 44 | dataset_size = len(train_loader) 45 | 46 | warp_model = AFWM(opt, 45) 47 | print(warp_model) 48 | warp_model.train() 49 | warp_model.cuda() 50 | load_checkpoint_parallel(warp_model, opt.PBAFN_warp_checkpoint) 51 | 52 | gen_model = ResUnetGenerator(8, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d) 53 | print(gen_model) 54 | gen_model.train() 55 | gen_model.cuda() 56 | 57 | warp_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(warp_model).to(device) 58 | gen_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(gen_model).to(device) 59 | 60 | if opt.isTrain and len(opt.gpu_ids): 61 | model = torch.nn.parallel.DistributedDataParallel(warp_model, device_ids=[opt.local_rank]) 62 | model_gen = torch.nn.parallel.DistributedDataParallel(gen_model, device_ids=[opt.local_rank]) 63 | 64 | criterionL1 = nn.L1Loss() 65 | criterionVGG = VGGLoss() 66 | # optimizer 67 | params_warp = [p for p in model.parameters()] 68 | params_gen = [p for p in model_gen.parameters()] 69 | optimizer_warp = torch.optim.Adam(params_warp, lr=0.2*opt.lr, betas=(opt.beta1, 0.999)) 70 | optimizer_gen = torch.optim.Adam(params_gen, lr=opt.lr, betas=(opt.beta1, 0.999)) 71 | 72 | total_steps = (start_epoch-1) * dataset_size + epoch_iter 73 | 74 | step = 0 75 | step_per_batch = dataset_size 76 | 77 | if opt.local_rank == 0: 78 | writer = SummaryWriter(path) 79 | 80 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): 81 | epoch_start_time = time.time() 82 | if epoch != start_epoch: 83 | epoch_iter = epoch_iter % dataset_size 84 | 85 | train_sampler.set_epoch(epoch) 86 | 87 | for i, data in enumerate(train_loader): 88 | 89 | iter_start_time = time.time() 90 | 91 | total_steps += 1 92 | epoch_iter += 1 93 | save_fake = True 94 | 95 | t_mask = torch.FloatTensor((data['label'].cpu().numpy()==7).astype(np.float)) 96 | data['label'] = data['label']*(1-t_mask)+t_mask*4 97 | edge = data['edge'] 98 | pre_clothes_edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int)) 99 | clothes = data['color'] 100 | clothes = clothes * pre_clothes_edge 101 | person_clothes_edge = torch.FloatTensor((data['label'].cpu().numpy()==4).astype(np.int)) 102 | real_image = data['image'] 103 | person_clothes = real_image*person_clothes_edge 104 | pose = data['pose'] 105 | size = data['label'].size() 106 | oneHot_size1 = (size[0], 25, size[2], size[3]) 107 | densepose = torch.cuda.FloatTensor(torch.Size(oneHot_size1)).zero_() 108 | densepose = densepose.scatter_(1,data['densepose'].data.long().cuda(),1.0) 109 | densepose_fore = data['densepose']/24.0 110 | face_mask = torch.FloatTensor((data['label'].cpu().numpy()==1).astype(np.int))+torch.FloatTensor((data['label'].cpu().numpy()==12).astype(np.int)) 111 | other_clothes_mask = torch.FloatTensor((data['label'].cpu().numpy()==5).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy()==6).astype(np.int))\ 112 | + torch.FloatTensor((data['label'].cpu().numpy()==8).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy()==9).astype(np.int))\ 113 | + 
torch.FloatTensor((data['label'].cpu().numpy()==10).astype(np.int)) 114 | face_img = face_mask * real_image 115 | other_clothes_img = other_clothes_mask * real_image 116 | preserve_region = face_img + other_clothes_img 117 | preserve_mask = torch.cat([face_mask, other_clothes_mask],1) 118 | concat = torch.cat([preserve_mask.cuda(), densepose, pose.cuda()],1) 119 | arm_mask = torch.FloatTensor((data['label'].cpu().numpy()==11).astype(np.float)) + torch.FloatTensor((data['label'].cpu().numpy()==13).astype(np.float)) 120 | hand_mask = torch.FloatTensor((data['densepose'].cpu().numpy()==3).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy()==4).astype(np.int)) 121 | hand_mask = arm_mask*hand_mask 122 | hand_img = hand_mask*real_image 123 | dense_preserve_mask = torch.FloatTensor((data['densepose'].cpu().numpy()==15).astype(np.int))+torch.FloatTensor((data['densepose'].cpu().numpy()==16).astype(np.int))\ 124 | +torch.FloatTensor((data['densepose'].cpu().numpy()==17).astype(np.int))+torch.FloatTensor((data['densepose'].cpu().numpy()==18).astype(np.int))\ 125 | +torch.FloatTensor((data['densepose'].cpu().numpy()==19).astype(np.int))+torch.FloatTensor((data['densepose'].cpu().numpy()==20).astype(np.int))\ 126 | +torch.FloatTensor((data['densepose'].cpu().numpy()==21).astype(np.int))+torch.FloatTensor((data['densepose'].cpu().numpy()==22)) 127 | dense_preserve_mask = dense_preserve_mask.cuda()*(1-person_clothes_edge.cuda()) 128 | preserve_region = face_img + other_clothes_img +hand_img 129 | 130 | flow_out = model(concat.cuda(), clothes.cuda(), pre_clothes_edge.cuda()) 131 | warped_cloth, last_flow, _1, _2, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all = flow_out 132 | 133 | epsilon = 0.001 134 | loss_smooth = sum([TVLoss(x) for x in delta_list]) 135 | warp_loss = 0 136 | 137 | for num in range(5): 138 | cur_person_clothes = F.interpolate(person_clothes, scale_factor=0.5**(4-num), mode='bilinear') 139 | cur_person_clothes_edge = F.interpolate(person_clothes_edge, scale_factor=0.5**(4-num), mode='bilinear') 140 | loss_l1 = criterionL1(x_all[num], cur_person_clothes.cuda()) 141 | loss_vgg = criterionVGG(x_all[num], cur_person_clothes.cuda()) 142 | loss_edge = criterionL1(x_edge_all[num], cur_person_clothes_edge.cuda()) 143 | b,c,h,w = delta_x_all[num].shape 144 | loss_flow_x = (delta_x_all[num].pow(2) + epsilon*epsilon).pow(0.45) 145 | loss_flow_x = torch.sum(loss_flow_x) / (b*c*h*w) 146 | loss_flow_y = (delta_y_all[num].pow(2) + epsilon*epsilon).pow(0.45) 147 | loss_flow_y = torch.sum(loss_flow_y) / (b*c*h*w) 148 | loss_second_smooth = loss_flow_x + loss_flow_y 149 | warp_loss = warp_loss + (num+1) * loss_l1 + (num+1) * 0.2 * loss_vgg + (num+1) * 2 * loss_edge + (num+1) * 6 * loss_second_smooth 150 | 151 | warp_loss = 0.01 * loss_smooth + warp_loss 152 | 153 | if opt.local_rank == 0: 154 | writer.add_scalar('warp_loss', warp_loss, step) 155 | 156 | warped_prod_edge = x_edge_all[4] 157 | gen_inputs = torch.cat([preserve_region.cuda(), warped_cloth, warped_prod_edge, dense_preserve_mask], 1) 158 | 159 | gen_outputs = model_gen(gen_inputs) 160 | p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1) 161 | p_rendered = torch.tanh(p_rendered) 162 | m_composite = torch.sigmoid(m_composite) 163 | m_composite1 = m_composite * warped_prod_edge 164 | m_composite = person_clothes_edge.cuda()*m_composite1 165 | p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite) 166 | 167 | loss_mask_l1 = torch.mean(torch.abs(1 - m_composite)) 168 | loss_l1 = 
criterionL1(p_tryon, real_image.cuda()) 169 | loss_vgg = criterionVGG(p_tryon,real_image.cuda()) 170 | bg_loss_l1 = criterionL1(p_rendered, real_image.cuda()) 171 | bg_loss_vgg = criterionVGG(p_rendered, real_image.cuda()) 172 | gen_loss = (loss_l1 * 5 + loss_vgg + bg_loss_l1 * 5 + bg_loss_vgg + loss_mask_l1) 173 | 174 | 175 | if opt.local_rank == 0: 176 | writer.add_scalar('gen_loss', gen_loss, step) 177 | 178 | loss_all = 0.5 * warp_loss + 1.0 * gen_loss 179 | 180 | if opt.local_rank == 0: 181 | writer.add_scalar('loss_all', loss_all, step) 182 | 183 | optimizer_warp.zero_grad() 184 | optimizer_gen.zero_grad() 185 | loss_all.backward() 186 | optimizer_warp.step() 187 | optimizer_gen.step() 188 | 189 | ############## Display results and errors ########## 190 | path = 'sample/'+opt.name 191 | os.makedirs(path,exist_ok=True) 192 | if step % 1000 == 0: 193 | if opt.local_rank == 0: 194 | a = real_image.float().cuda() 195 | b = person_clothes.cuda() 196 | c = clothes.cuda() 197 | d = torch.cat([densepose_fore.cuda(),densepose_fore.cuda(),densepose_fore.cuda()],1) 198 | e = warped_cloth 199 | f = torch.cat([warped_prod_edge,warped_prod_edge,warped_prod_edge],1) 200 | g = preserve_region.cuda() 201 | h = torch.cat([dense_preserve_mask,dense_preserve_mask,dense_preserve_mask],1) 202 | i = p_rendered 203 | j = torch.cat([m_composite1,m_composite1,m_composite1],1) 204 | k = p_tryon 205 | combine = torch.cat([a[0],b[0],c[0],d[0],e[0],f[0],g[0],h[0],i[0],j[0],k[0]], 2).squeeze() 206 | cv_img = (combine.permute(1,2,0).detach().cpu().numpy()+1)/2 207 | writer.add_image('combine', (combine.data + 1) / 2.0, step) 208 | rgb = (cv_img*255).astype(np.uint8) 209 | bgr = cv2.cvtColor(rgb,cv2.COLOR_RGB2BGR) 210 | cv2.imwrite('sample/'+opt.name+'/'+str(step)+'.jpg',bgr) 211 | 212 | step += 1 213 | iter_end_time = time.time() 214 | iter_delta_time = iter_end_time - iter_start_time 215 | step_delta = (step_per_batch-step%step_per_batch) + step_per_batch*(opt.niter + opt.niter_decay-epoch) 216 | eta = iter_delta_time*step_delta 217 | eta = str(datetime.timedelta(seconds=int(eta))) 218 | time_stamp = datetime.datetime.now() 219 | now = time_stamp.strftime('%Y.%m.%d-%H:%M:%S') 220 | 221 | if step % 100 == 0: 222 | if opt.local_rank == 0: 223 | print('{}:{}:[step-{}]--[loss-{:.6f}]--[loss-{:.6f}]--[ETA-{}]'.format(now, epoch_iter, step, warp_loss, gen_loss, eta)) 224 | 225 | if epoch_iter >= dataset_size: 226 | break 227 | 228 | iter_end_time = time.time() 229 | if opt.local_rank == 0: 230 | print('End of epoch %d / %d \t Time Taken: %d sec' % 231 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 232 | 233 | ### save model for this epoch 234 | if epoch % opt.save_epoch_freq == 0: 235 | if opt.local_rank == 0: 236 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 237 | save_checkpoint(model.module, os.path.join(opt.checkpoints_dir, opt.name, 'PBAFN_warp_epoch_%03d.pth' % (epoch+1))) 238 | save_checkpoint(model_gen.module, os.path.join(opt.checkpoints_dir, opt.name, 'PBAFN_gen_epoch_%03d.pth' % (epoch+1))) 239 | 240 | if epoch > opt.niter: 241 | model.module.update_learning_rate_warp(optimizer_warp) 242 | model.module.update_learning_rate(optimizer_gen) 243 | -------------------------------------------------------------------------------- /PF-AFN_train/train_PBAFN_stage1.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from models.networks import 
VGGLoss,save_checkpoint 4 | from models.afwm import TVLoss,AFWM 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import os 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data.distributed import DistributedSampler 12 | from tensorboardX import SummaryWriter 13 | import cv2 14 | import datetime 15 | 16 | opt = TrainOptions().parse() 17 | path = 'runs/'+opt.name 18 | os.makedirs(path,exist_ok=True) 19 | 20 | def CreateDataset(opt): 21 | from data.aligned_dataset import AlignedDataset 22 | dataset = AlignedDataset() 23 | print("dataset [%s] was created" % (dataset.name())) 24 | dataset.initialize(opt) 25 | return dataset 26 | 27 | os.makedirs('sample',exist_ok=True) 28 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 29 | 30 | torch.cuda.set_device(opt.local_rank) 31 | torch.distributed.init_process_group( 32 | 'nccl', 33 | init_method='env://' 34 | ) 35 | device = torch.device(f'cuda:{opt.local_rank}') 36 | 37 | start_epoch, epoch_iter = 1, 0 38 | 39 | train_data = CreateDataset(opt) 40 | train_sampler = DistributedSampler(train_data) 41 | train_loader = DataLoader(train_data, batch_size=opt.batchSize, shuffle=False, 42 | num_workers=4, pin_memory=True, sampler=train_sampler) 43 | dataset_size = len(train_loader) 44 | print('#training images = %d' % dataset_size) 45 | 46 | warp_model = AFWM(opt, 45) 47 | print(warp_model) 48 | warp_model.train() 49 | warp_model.cuda() 50 | warp_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(warp_model).to(device) 51 | 52 | if opt.isTrain and len(opt.gpu_ids): 53 | model = torch.nn.parallel.DistributedDataParallel(warp_model, device_ids=[opt.local_rank]) 54 | 55 | criterionL1 = nn.L1Loss() 56 | criterionVGG = VGGLoss() 57 | 58 | params_warp = [p for p in model.parameters()] 59 | optimizer_warp = torch.optim.Adam(params_warp, lr=opt.lr, betas=(opt.beta1, 0.999)) 60 | 61 | total_steps = (start_epoch-1) * dataset_size + epoch_iter 62 | step = 0 63 | step_per_batch = dataset_size 64 | 65 | if opt.local_rank == 0: 66 | writer = SummaryWriter(path) 67 | 68 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): 69 | epoch_start_time = time.time() 70 | if epoch != start_epoch: 71 | epoch_iter = epoch_iter % dataset_size 72 | 73 | train_sampler.set_epoch(epoch) 74 | 75 | for i, data in enumerate(train_loader): 76 | iter_start_time = time.time() 77 | 78 | total_steps += 1 79 | epoch_iter += 1 80 | save_fake = True 81 | 82 | t_mask = torch.FloatTensor((data['label'].cpu().numpy()==7).astype(np.float)) 83 | data['label'] = data['label']*(1-t_mask)+t_mask*4 84 | edge = data['edge'] 85 | pre_clothes_edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int)) 86 | clothes = data['color'] 87 | clothes = clothes * pre_clothes_edge 88 | person_clothes_edge = torch.FloatTensor((data['label'].cpu().numpy()==4).astype(np.int)) 89 | real_image = data['image'] 90 | person_clothes = real_image * person_clothes_edge 91 | pose = data['pose'] 92 | size = data['label'].size() 93 | oneHot_size1 = (size[0], 25, size[2], size[3]) 94 | densepose = torch.cuda.FloatTensor(torch.Size(oneHot_size1)).zero_() 95 | densepose = densepose.scatter_(1,data['densepose'].data.long().cuda(),1.0) 96 | densepose_fore = data['densepose']/24.0 97 | face_mask = torch.FloatTensor((data['label'].cpu().numpy()==1).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy()==12).astype(np.int)) 98 | other_clothes_mask = 
torch.FloatTensor((data['label'].cpu().numpy()==5).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy()==6).astype(np.int)) + \ 99 | torch.FloatTensor((data['label'].cpu().numpy()==8).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy()==9).astype(np.int)) + \ 100 | torch.FloatTensor((data['label'].cpu().numpy()==10).astype(np.int)) 101 | preserve_mask = torch.cat([face_mask,other_clothes_mask],1) 102 | concat = torch.cat([preserve_mask.cuda(),densepose,pose.cuda()],1) 103 | 104 | flow_out = model(concat.cuda(), clothes.cuda(), pre_clothes_edge.cuda()) 105 | warped_cloth, last_flow, _1, _2, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all = flow_out 106 | warped_prod_edge = x_edge_all[4] 107 | 108 | epsilon = 0.001 109 | loss_smooth = sum([TVLoss(x) for x in delta_list]) 110 | loss_all = 0 111 | 112 | for num in range(5): 113 | cur_person_clothes = F.interpolate(person_clothes, scale_factor=0.5**(4-num), mode='bilinear') 114 | cur_person_clothes_edge = F.interpolate(person_clothes_edge, scale_factor=0.5**(4-num), mode='bilinear') 115 | loss_l1 = criterionL1(x_all[num], cur_person_clothes.cuda()) 116 | loss_vgg = criterionVGG(x_all[num], cur_person_clothes.cuda()) 117 | loss_edge = criterionL1(x_edge_all[num], cur_person_clothes_edge.cuda()) 118 | b,c,h,w = delta_x_all[num].shape 119 | loss_flow_x = (delta_x_all[num].pow(2)+ epsilon*epsilon).pow(0.45) 120 | loss_flow_x = torch.sum(loss_flow_x)/(b*c*h*w) 121 | loss_flow_y = (delta_y_all[num].pow(2)+ epsilon*epsilon).pow(0.45) 122 | loss_flow_y = torch.sum(loss_flow_y)/(b*c*h*w) 123 | loss_second_smooth = loss_flow_x + loss_flow_y 124 | loss_all = loss_all + (num+1) * loss_l1 + (num + 1) * 0.2 * loss_vgg + (num+1) * 2 * loss_edge + (num + 1) * 6 * loss_second_smooth 125 | 126 | loss_all = 0.01 * loss_smooth + loss_all 127 | 128 | if opt.local_rank == 0: 129 | writer.add_scalar('loss_all', loss_all, step) 130 | 131 | optimizer_warp.zero_grad() 132 | loss_all.backward() 133 | optimizer_warp.step() 134 | ############## Display results and errors ########## 135 | 136 | path = 'sample/'+opt.name 137 | os.makedirs(path,exist_ok=True) 138 | if step % 1000 == 0: 139 | if opt.local_rank == 0: 140 | a = real_image.float().cuda() 141 | b = person_clothes.cuda() 142 | c = clothes.cuda() 143 | d = torch.cat([densepose_fore.cuda(),densepose_fore.cuda(),densepose_fore.cuda()],1) 144 | e = warped_cloth 145 | f = torch.cat([warped_prod_edge,warped_prod_edge,warped_prod_edge],1) 146 | combine = torch.cat([a[0],b[0],c[0],d[0],e[0],f[0]], 2).squeeze() 147 | cv_img=(combine.permute(1,2,0).detach().cpu().numpy()+1)/2 148 | writer.add_image('combine', (combine.data + 1) / 2.0, step) 149 | rgb=(cv_img*255).astype(np.uint8) 150 | bgr=cv2.cvtColor(rgb,cv2.COLOR_RGB2BGR) 151 | cv2.imwrite('sample/'+opt.name+'/'+str(step)+'.jpg',bgr) 152 | 153 | step += 1 154 | iter_end_time = time.time() 155 | iter_delta_time = iter_end_time - iter_start_time 156 | step_delta = (step_per_batch-step%step_per_batch) + step_per_batch*(opt.niter + opt.niter_decay-epoch) 157 | eta = iter_delta_time*step_delta 158 | eta = str(datetime.timedelta(seconds=int(eta))) 159 | time_stamp = datetime.datetime.now() 160 | now = time_stamp.strftime('%Y.%m.%d-%H:%M:%S') 161 | if step % 100 == 0: 162 | if opt.local_rank == 0: 163 | print('{}:{}:[step-{}]--[loss-{:.6f}]--[ETA-{}]'.format(now, epoch_iter,step, loss_all,eta)) 164 | 165 | if epoch_iter >= dataset_size: 166 | break 167 | 168 | # end of epoch 169 | iter_end_time = time.time() 170 | if opt.local_rank == 0: 171 | 
print('End of epoch %d / %d \t Time Taken: %d sec' % 172 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 173 | 174 | ### save model for this epoch 175 | if epoch % opt.save_epoch_freq == 0: 176 | if opt.local_rank == 0: 177 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 178 | save_checkpoint(model.module, os.path.join(opt.checkpoints_dir, opt.name, 'PBAFN_warp_epoch_%03d.pth' % (epoch+1))) 179 | 180 | if epoch > opt.niter: 181 | model.module.update_learning_rate(optimizer_warp) 182 | -------------------------------------------------------------------------------- /PF-AFN_train/train_PFAFN_e2e.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from models.networks import ResUnetGenerator, VGGLoss, save_checkpoint, load_checkpoint_parallel 4 | from models.afwm import TVLoss, AFWM 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import os 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data.distributed import DistributedSampler 12 | from tensorboardX import SummaryWriter 13 | import datetime 14 | import cv2 15 | 16 | opt = TrainOptions().parse() 17 | path = 'runs/' + opt.name 18 | os.makedirs(path, exist_ok=True) 19 | 20 | 21 | def CreateDataset(opt): 22 | from data.aligned_dataset import AlignedDataset 23 | dataset = AlignedDataset() 24 | print("dataset [%s] was created" % (dataset.name())) 25 | dataset.initialize(opt) 26 | return dataset 27 | 28 | 29 | os.makedirs('sample', exist_ok=True) 30 | opt = TrainOptions().parse() 31 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 32 | 33 | torch.cuda.set_device(opt.local_rank) 34 | torch.distributed.init_process_group( 35 | 'nccl', 36 | init_method='env://' 37 | ) 38 | device = torch.device(f'cuda:{opt.local_rank}') 39 | 40 | start_epoch, epoch_iter = 1, 0 41 | 42 | train_data = CreateDataset(opt) 43 | train_sampler = DistributedSampler(train_data) 44 | train_loader = DataLoader(train_data, batch_size=opt.batchSize, shuffle=False, 45 | num_workers=4, pin_memory=True, sampler=train_sampler) 46 | dataset_size = len(train_loader) 47 | 48 | PF_warp_model = AFWM(opt, 3) 49 | print(PF_warp_model) 50 | PF_warp_model.train() 51 | PF_warp_model.cuda() 52 | load_checkpoint_parallel(PF_warp_model, opt.PFAFN_warp_checkpoint) 53 | 54 | PF_gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d) 55 | print(PF_gen_model) 56 | PF_gen_model.train() 57 | PF_gen_model.cuda() 58 | 59 | PB_warp_model = AFWM(opt, 45) 60 | print(PB_warp_model) 61 | PB_warp_model.eval() 62 | PB_warp_model.cuda() 63 | load_checkpoint_parallel(PB_warp_model, opt.PBAFN_warp_checkpoint) 64 | 65 | PB_gen_model = ResUnetGenerator(8, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d) 66 | print(PB_gen_model) 67 | PB_gen_model.eval() 68 | PB_gen_model.cuda() 69 | load_checkpoint_parallel(PB_gen_model, opt.PBAFN_gen_checkpoint) 70 | 71 | PF_warp_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(PF_warp_model).to(device) 72 | PF_gen_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(PF_gen_model).to(device) 73 | 74 | if opt.isTrain and len(opt.gpu_ids): 75 | PF_warp_model = torch.nn.parallel.DistributedDataParallel(PF_warp_model, device_ids=[opt.local_rank]) 76 | PF_gen_model = torch.nn.parallel.DistributedDataParallel(PF_gen_model, device_ids=[opt.local_rank]) 77 | PB_warp_model = 
torch.nn.parallel.DistributedDataParallel(PB_warp_model, device_ids=[opt.local_rank]) 78 | PB_gen_model = torch.nn.parallel.DistributedDataParallel(PB_gen_model, device_ids=[opt.local_rank]) 79 | 80 | criterionL1 = nn.L1Loss() 81 | criterionVGG = VGGLoss() 82 | criterionL2 = nn.MSELoss('sum') 83 | 84 | params_warp = [p for p in PF_warp_model.parameters()] 85 | params_gen = [p for p in PF_gen_model.parameters()] 86 | optimizer_warp = torch.optim.Adam(params_warp, lr=0.2 * opt.lr, betas=(opt.beta1, 0.999)) 87 | optimizer_gen = torch.optim.Adam(params_gen, lr=opt.lr, betas=(opt.beta1, 0.999)) 88 | 89 | total_steps = (start_epoch - 1) * dataset_size + epoch_iter 90 | 91 | if opt.local_rank == 0: 92 | writer = SummaryWriter(path) 93 | 94 | step = 0 95 | step_per_batch = dataset_size 96 | 97 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): 98 | epoch_start_time = time.time() 99 | if epoch != start_epoch: 100 | epoch_iter = epoch_iter % dataset_size 101 | 102 | train_sampler.set_epoch(epoch) 103 | 104 | for i, data in enumerate(train_loader): 105 | 106 | iter_start_time = time.time() 107 | 108 | total_steps += 1 109 | epoch_iter += 1 110 | save_fake = True 111 | 112 | t_mask = torch.FloatTensor((data['label'].cpu().numpy() == 7).astype(np.float)) 113 | data['label'] = data['label'] * (1 - t_mask) + t_mask * 4 114 | edge = data['edge'] 115 | pre_clothes_edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int)) 116 | clothes = data['color'] 117 | clothes = clothes * pre_clothes_edge 118 | edge_un = data['edge_un'] 119 | pre_clothes_edge_un = torch.FloatTensor((edge_un.detach().numpy() > 0.5).astype(np.int)) 120 | clothes_un = data['color_un'] 121 | clothes_un = clothes_un * pre_clothes_edge_un 122 | person_clothes_edge = torch.FloatTensor((data['label'].cpu().numpy() == 4).astype(np.int)) 123 | real_image = data['image'] 124 | person_clothes = real_image * person_clothes_edge 125 | pose = data['pose'] 126 | size = data['label'].size() 127 | oneHot_size1 = (size[0], 25, size[2], size[3]) 128 | densepose = torch.cuda.FloatTensor(torch.Size(oneHot_size1)).zero_() 129 | densepose = densepose.scatter_(1, data['densepose'].data.long().cuda(), 1.0) 130 | densepose_fore = data['densepose'] / 24 131 | face_mask = torch.FloatTensor((data['label'].cpu().numpy() == 1).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() == 12).astype(np.int)) 132 | other_clothes_mask = torch.FloatTensor((data['label'].cpu().numpy() == 5).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() == 6).astype(np.int)) \ 133 | + torch.FloatTensor((data['label'].cpu().numpy() == 8).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() == 9).astype(np.int)) \ 134 | + torch.FloatTensor((data['label'].cpu().numpy() == 10).astype(np.int)) 135 | face_img = face_mask * real_image 136 | other_clothes_img = other_clothes_mask * real_image 137 | preserve_mask = torch.cat([face_mask, other_clothes_mask], 1) 138 | 139 | concat_un = torch.cat([preserve_mask.cuda(), densepose, pose.cuda()], 1) 140 | flow_out_un = PB_warp_model(concat_un.cuda(), clothes_un.cuda(), pre_clothes_edge_un.cuda()) 141 | warped_cloth_un, last_flow_un, cond_un_all, flow_un_all, delta_list_un, x_all_un, x_edge_all_un, delta_x_all_un, delta_y_all_un = flow_out_un 142 | warped_prod_edge_un = F.grid_sample(pre_clothes_edge_un.cuda(), last_flow_un.permute(0, 2, 3, 1), 143 | mode='bilinear', padding_mode='zeros') 144 | 145 | flow_out_sup = PB_warp_model(concat_un.cuda(), clothes.cuda(), pre_clothes_edge.cuda()) 
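# Distillation pass: the parser-based teacher (PB_warp_model) also warps the *paired* clothes here;
# its intermediate features (cond_sup_all) and flows (flow_sup_all) unpacked below serve as
# supervision targets for the parser-free student warp module, and are only applied to samples
# where the teacher's warp is closer to the ground-truth person_clothes than the student's
# (see the per-sample `weight` computed further down).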
146 | warped_cloth_sup, last_flow_sup, cond_sup_all, flow_sup_all, delta_list_sup, x_all_sup, x_edge_all_sup, delta_x_all_sup, delta_y_all_sup = flow_out_sup 147 | 148 | arm_mask = torch.FloatTensor((data['label'].cpu().numpy() == 11).astype(np.float)) + torch.FloatTensor((data['label'].cpu().numpy() == 13).astype(np.float)) 149 | hand_mask = torch.FloatTensor((data['densepose'].cpu().numpy() == 3).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 4).astype(np.int)) 150 | dense_preserve_mask = torch.FloatTensor((data['densepose'].cpu().numpy() == 15).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 16).astype(np.int)) \ 151 | + torch.FloatTensor((data['densepose'].cpu().numpy() == 17).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 18).astype(np.int)) \ 152 | + torch.FloatTensor((data['densepose'].cpu().numpy() == 19).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 20).astype(np.int)) \ 153 | + torch.FloatTensor((data['densepose'].cpu().numpy() == 21).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 22)) 154 | hand_img = (arm_mask * hand_mask) * real_image 155 | dense_preserve_mask = dense_preserve_mask.cuda() * (1 - warped_prod_edge_un) 156 | preserve_region = face_img + other_clothes_img + hand_img 157 | 158 | gen_inputs_un = torch.cat([preserve_region.cuda(), warped_cloth_un, warped_prod_edge_un, dense_preserve_mask], 1) 159 | gen_outputs_un = PB_gen_model(gen_inputs_un) 160 | p_rendered_un, m_composite_un = torch.split(gen_outputs_un, [3, 1], 1) 161 | p_rendered_un = torch.tanh(p_rendered_un) 162 | m_composite_un = torch.sigmoid(m_composite_un) 163 | m_composite_un = m_composite_un * warped_prod_edge_un 164 | p_tryon_un = warped_cloth_un * m_composite_un + p_rendered_un * (1 - m_composite_un) 165 | 166 | flow_out = PF_warp_model(p_tryon_un.detach(), clothes.cuda(), pre_clothes_edge.cuda()) 167 | warped_cloth, last_flow, cond_all, flow_all, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all = flow_out 168 | warped_prod_edge = x_edge_all[4] 169 | 170 | epsilon = 0.001 171 | loss_smooth = sum([TVLoss(x) for x in delta_list]) 172 | loss_warp = 0 173 | loss_fea_sup_all = 0 174 | loss_flow_sup_all = 0 175 | 176 | l1_loss_batch = torch.abs(warped_cloth_sup.detach() - person_clothes.cuda()) 177 | l1_loss_batch = l1_loss_batch.reshape(opt.batchSize, 3 * 256 * 192) 178 | l1_loss_batch = l1_loss_batch.sum(dim=1) / (3 * 256 * 192) 179 | l1_loss_batch_pred = torch.abs(warped_cloth.detach() - person_clothes.cuda()) 180 | l1_loss_batch_pred = l1_loss_batch_pred.reshape(opt.batchSize, 3 * 256 * 192) 181 | l1_loss_batch_pred = l1_loss_batch_pred.sum(dim=1) / (3 * 256 * 192) 182 | weight = (l1_loss_batch < l1_loss_batch_pred).float() 183 | num_all = len(np.where(weight.cpu().numpy() > 0)[0]) 184 | if num_all == 0: 185 | num_all = 1 186 | 187 | for num in range(5): 188 | cur_person_clothes = F.interpolate(person_clothes, scale_factor=0.5 ** (4 - num), mode='bilinear') 189 | cur_person_clothes_edge = F.interpolate(person_clothes_edge, scale_factor=0.5 ** (4 - num), mode='bilinear') 190 | loss_l1 = criterionL1(x_all[num], cur_person_clothes.cuda()) 191 | loss_vgg = criterionVGG(x_all[num], cur_person_clothes.cuda()) 192 | loss_edge = criterionL1(x_edge_all[num], cur_person_clothes_edge.cuda()) 193 | b, c, h, w = delta_x_all[num].shape 194 | loss_flow_x = (delta_x_all[num].pow(2) + epsilon * epsilon).pow(0.45) 195 | loss_flow_x = torch.sum(loss_flow_x) / (b * c * h * w) 196 | loss_flow_y 
= (delta_y_all[num].pow(2) + epsilon * epsilon).pow(0.45) 197 | loss_flow_y = torch.sum(loss_flow_y) / (b * c * h * w) 198 | loss_second_smooth = loss_flow_x + loss_flow_y 199 | b1, c1, h1, w1 = cond_all[num].shape 200 | weight_all = weight.reshape(-1, 1, 1, 1).repeat(1, 256, h1, w1) 201 | cond_sup_loss = ((cond_sup_all[num].detach() - cond_all[num]) ** 2 * weight_all).sum() / (256 * h1 * w1 * num_all) 202 | loss_fea_sup_all = loss_fea_sup_all + (5 - num) * 0.04 * cond_sup_loss 203 | loss_warp = loss_warp + (num + 1) * loss_l1 + (num + 1) * 0.2 * loss_vgg + (num + 1) * 2 * loss_edge + (num + 1) * 6 * loss_second_smooth + (5 - num) * 0.04 * cond_sup_loss 204 | if num >= 2: 205 | b1, c1, h1, w1 = flow_all[num].shape 206 | weight_all = weight.reshape(-1, 1, 1).repeat(1, h1, w1) 207 | flow_sup_loss = (torch.norm(flow_sup_all[num].detach() - flow_all[num], p=2, dim=1) * weight_all).sum() / (h1 * w1 * num_all) 208 | loss_flow_sup_all = loss_flow_sup_all + (num + 1) * 1 * flow_sup_loss 209 | loss_warp = loss_warp + (num + 1) * 1 * flow_sup_loss 210 | 211 | loss_warp = 0.01 * loss_smooth + loss_warp 212 | 213 | if opt.local_rank == 0: 214 | writer.add_scalar('loss_warp', loss_warp, step) 215 | writer.add_scalar('loss_fea_sup_all', loss_fea_sup_all, step) 216 | writer.add_scalar('loss_flow_sup_all', loss_flow_sup_all, step) 217 | 218 | skin_mask = warped_prod_edge_un.detach() * (1 - person_clothes_edge.cuda()) 219 | gen_inputs = torch.cat([p_tryon_un.detach(), warped_cloth, warped_prod_edge], 1) 220 | gen_outputs = PF_gen_model(gen_inputs) 221 | p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1) 222 | p_rendered = torch.tanh(p_rendered) 223 | m_composite = torch.sigmoid(m_composite) 224 | m_composite1 = m_composite * warped_prod_edge 225 | m_composite = person_clothes_edge.cuda() * m_composite1 226 | p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite) 227 | 228 | loss_mask_l1 = torch.mean(torch.abs(1 - m_composite)) 229 | loss_l1_skin = criterionL1(p_rendered * skin_mask, skin_mask * real_image.cuda()) 230 | loss_vgg_skin = criterionVGG(p_rendered * skin_mask, skin_mask * real_image.cuda()) 231 | loss_l1 = criterionL1(p_tryon, real_image.cuda()) 232 | loss_vgg = criterionVGG(p_tryon, real_image.cuda()) 233 | bg_loss_l1 = criterionL1(p_rendered, real_image.cuda()) 234 | bg_loss_vgg = criterionVGG(p_rendered, real_image.cuda()) 235 | 236 | if epoch < opt.niter: 237 | loss_gen = (loss_l1 * 5 + loss_l1_skin * 30 + loss_vgg + loss_vgg_skin * 2 + bg_loss_l1 * 5 + bg_loss_vgg + 1 * loss_mask_l1) 238 | else: 239 | loss_gen = (loss_l1 * 5 + loss_l1_skin * 60 + loss_vgg + loss_vgg_skin * 4 + bg_loss_l1 * 5 + bg_loss_vgg + 1 * loss_mask_l1) 240 | 241 | loss_all = 0.25 * loss_warp + loss_gen 242 | 243 | if opt.local_rank == 0: 244 | writer.add_scalar('loss_gen', loss_gen, step) 245 | 246 | optimizer_warp.zero_grad() 247 | optimizer_gen.zero_grad() 248 | loss_all.backward() 249 | optimizer_warp.step() 250 | optimizer_gen.step() 251 | 252 | ############## Display results and errors ########## 253 | path = 'sample/' + opt.name 254 | os.makedirs(path, exist_ok=True) 255 | ### display output images 256 | if step % 1000 == 0: 257 | if opt.local_rank == 0: 258 | a = real_image.float().cuda() 259 | b = p_tryon_un.detach() 260 | c = clothes.cuda() 261 | d = person_clothes.cuda() 262 | e = torch.cat([skin_mask.cuda(), skin_mask.cuda(), skin_mask.cuda()], 1) 263 | f = warped_cloth 264 | g = p_rendered 265 | h = torch.cat([m_composite1, m_composite1, m_composite1], 1) 266 | i = p_tryon 267 | 
combine = torch.cat([a[0], b[0], c[0], d[0], e[0], f[0], g[0], h[0], i[0]], 2).squeeze() 268 | cv_img = (combine.permute(1, 2, 0).detach().cpu().numpy() + 1) / 2 269 | writer.add_image('combine', (combine.data + 1) / 2.0, step) 270 | rgb = (cv_img * 255).astype(np.uint8) 271 | bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) 272 | cv2.imwrite('sample/' + opt.name + '/' + str(step) + '.jpg', bgr) 273 | 274 | step += 1 275 | iter_end_time = time.time() 276 | iter_delta_time = iter_end_time - iter_start_time 277 | step_delta = (step_per_batch - step % step_per_batch) + step_per_batch * (opt.niter + opt.niter_decay - epoch) 278 | eta = iter_delta_time * step_delta 279 | eta = str(datetime.timedelta(seconds=int(eta))) 280 | time_stamp = datetime.datetime.now() 281 | now = time_stamp.strftime('%Y.%m.%d-%H:%M:%S') 282 | 283 | if step % 100 == 0: 284 | if opt.local_rank == 0: 285 | print('{}:{}:[step-{}]--[loss-{:.6f}]--[loss-{:.6f}]--[ETA-{}]'.format(now, epoch_iter, step, loss_gen, loss_warp, eta)) 286 | 287 | if epoch_iter >= dataset_size: 288 | break 289 | 290 | # end of epoch 291 | iter_end_time = time.time() 292 | if opt.local_rank == 0: 293 | print('End of epoch %d / %d \t Time Taken: %d sec' % 294 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 295 | 296 | if epoch % opt.save_epoch_freq == 0: 297 | if opt.local_rank == 0: 298 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 299 | save_checkpoint(PF_warp_model.module, 300 | os.path.join(opt.checkpoints_dir, opt.name, 'PFAFN_warp_epoch_%03d.pth' % (epoch + 1))) 301 | save_checkpoint(PF_gen_model.module, 302 | os.path.join(opt.checkpoints_dir, opt.name, 'PFAFN_gen_epoch_%03d.pth' % (epoch + 1))) 303 | 304 | if epoch > opt.niter: 305 | PF_warp_model.module.update_learning_rate_warp(optimizer_warp) 306 | PF_warp_model.module.update_learning_rate(optimizer_gen) 307 | -------------------------------------------------------------------------------- /PF-AFN_train/train_PFAFN_stage1.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from models.networks import ResUnetGenerator, VGGLoss, save_checkpoint, load_checkpoint_part_parallel, \ 4 | load_checkpoint_parallel 5 | from models.afwm import TVLoss, AFWM 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import os 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import DataLoader 12 | from torch.utils.data.distributed import DistributedSampler 13 | from tensorboardX import SummaryWriter 14 | import cv2 15 | import datetime 16 | 17 | opt = TrainOptions().parse() 18 | path = 'runs/' + opt.name 19 | os.makedirs(path, exist_ok=True) 20 | 21 | 22 | def CreateDataset(opt): 23 | from data.aligned_dataset import AlignedDataset 24 | dataset = AlignedDataset() 25 | print("dataset [%s] was created" % (dataset.name())) 26 | dataset.initialize(opt) 27 | return dataset 28 | 29 | 30 | os.makedirs('sample', exist_ok=True) 31 | opt = TrainOptions().parse() 32 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 33 | 34 | torch.cuda.set_device(opt.local_rank) 35 | torch.distributed.init_process_group( 36 | 'nccl', 37 | init_method='env://' 38 | ) 39 | device = torch.device(f'cuda:{opt.local_rank}') 40 | 41 | start_epoch, epoch_iter = 1, 0 42 | 43 | train_data = CreateDataset(opt) 44 | train_sampler = DistributedSampler(train_data) 45 | train_loader = DataLoader(train_data, batch_size=opt.batchSize, 
shuffle=False, 46 | num_workers=4, pin_memory=True, sampler=train_sampler) 47 | dataset_size = len(train_loader) 48 | print('#training images = %d' % dataset_size) 49 | 50 | PF_warp_model = AFWM(opt, 3) 51 | print(PF_warp_model) 52 | PF_warp_model.train() 53 | PF_warp_model.cuda() 54 | load_checkpoint_part_parallel(PF_warp_model, opt.PBAFN_warp_checkpoint) 55 | 56 | PB_warp_model = AFWM(opt, 45) 57 | print(PB_warp_model) 58 | PB_warp_model.eval() 59 | PB_warp_model.cuda() 60 | load_checkpoint_parallel(PB_warp_model, opt.PBAFN_warp_checkpoint) 61 | 62 | PB_gen_model = ResUnetGenerator(8, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d) 63 | print(PB_gen_model) 64 | PB_gen_model.eval() 65 | PB_gen_model.cuda() 66 | load_checkpoint_parallel(PB_gen_model, opt.PBAFN_gen_checkpoint) 67 | 68 | PF_warp_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(PF_warp_model).to(device) 69 | 70 | if opt.isTrain and len(opt.gpu_ids): 71 | PF_warp_model = torch.nn.parallel.DistributedDataParallel(PF_warp_model, device_ids=[opt.local_rank]) 72 | PB_warp_model = torch.nn.parallel.DistributedDataParallel(PB_warp_model, device_ids=[opt.local_rank]) 73 | PB_gen_model = torch.nn.parallel.DistributedDataParallel(PB_gen_model, device_ids=[opt.local_rank]) 74 | 75 | criterionL1 = nn.L1Loss() 76 | criterionVGG = VGGLoss() 77 | criterionL2 = nn.MSELoss('sum') 78 | 79 | # optimizer 80 | params = [p for p in PF_warp_model.parameters()] 81 | optimizer = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) 82 | 83 | params_part = [] 84 | for name, param in PF_warp_model.named_parameters(): 85 | if 'cond_' in name or 'aflow_net.netRefine' in name: 86 | params_part.append(param) 87 | optimizer_part = torch.optim.Adam(params_part, lr=opt.lr, betas=(opt.beta1, 0.999)) 88 | 89 | total_steps = (start_epoch - 1) * dataset_size + epoch_iter 90 | 91 | if opt.local_rank == 0: 92 | writer = SummaryWriter(path) 93 | 94 | step = 0 95 | step_per_batch = dataset_size 96 | 97 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): 98 | epoch_start_time = time.time() 99 | if epoch != start_epoch: 100 | epoch_iter = epoch_iter % dataset_size 101 | 102 | train_sampler.set_epoch(epoch) 103 | 104 | for i, data in enumerate(train_loader): 105 | 106 | iter_start_time = time.time() 107 | 108 | total_steps += 1 109 | epoch_iter += 1 110 | save_fake = True 111 | 112 | t_mask = torch.FloatTensor((data['label'].cpu().numpy() == 7).astype(np.float)) 113 | data['label'] = data['label'] * (1 - t_mask) + t_mask * 4 114 | edge = data['edge'] 115 | pre_clothes_edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int)) 116 | clothes = data['color'] 117 | clothes = clothes * pre_clothes_edge 118 | edge_un = data['edge_un'] 119 | pre_clothes_edge_un = torch.FloatTensor((edge_un.detach().numpy() > 0.5).astype(np.int)) 120 | clothes_un = data['color_un'] 121 | clothes_un = clothes_un * pre_clothes_edge_un 122 | person_clothes_edge = torch.FloatTensor((data['label'].cpu().numpy() == 4).astype(np.int)) 123 | real_image = data['image'] 124 | person_clothes = real_image * person_clothes_edge 125 | pose = data['pose'] 126 | size = data['label'].size() 127 | oneHot_size1 = (size[0], 25, size[2], size[3]) 128 | densepose = torch.cuda.FloatTensor(torch.Size(oneHot_size1)).zero_() 129 | densepose = densepose.scatter_(1, data['densepose'].data.long().cuda(), 1.0) 130 | densepose_fore = data['densepose'] / 24 131 | face_mask = torch.FloatTensor((data['label'].cpu().numpy() == 1).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() 
== 12).astype(np.int)) 132 | other_clothes_mask = torch.FloatTensor((data['label'].cpu().numpy() == 5).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() == 6).astype(np.int)) \ 133 | + torch.FloatTensor((data['label'].cpu().numpy() == 8).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() == 9).astype(np.int)) \ 134 | + torch.FloatTensor((data['label'].cpu().numpy() == 10).astype(np.int)) 135 | face_img = face_mask * real_image 136 | other_clothes_img = other_clothes_mask * real_image 137 | preserve_mask = torch.cat([face_mask, other_clothes_mask], 1) 138 | 139 | concat_un = torch.cat([preserve_mask.cuda(), densepose, pose.cuda()], 1) 140 | flow_out_un = PB_warp_model(concat_un.cuda(), clothes_un.cuda(), pre_clothes_edge_un.cuda()) 141 | warped_cloth_un, last_flow_un, cond_un_all, flow_un_all, delta_list_un, x_all_un, x_edge_all_un, delta_x_all_un, delta_y_all_un = flow_out_un 142 | warped_prod_edge_un = F.grid_sample(pre_clothes_edge_un.cuda(), last_flow_un.permute(0, 2, 3, 1), 143 | mode='bilinear', padding_mode='zeros') 144 | 145 | flow_out_sup = PB_warp_model(concat_un.cuda(), clothes.cuda(), pre_clothes_edge.cuda()) 146 | warped_cloth_sup, last_flow_sup, cond_sup_all, flow_sup_all, delta_list_sup, x_all_sup, x_edge_all_sup, delta_x_all_sup, delta_y_all_sup = flow_out_sup 147 | 148 | arm_mask = torch.FloatTensor((data['label'].cpu().numpy() == 11).astype(np.float)) + torch.FloatTensor((data['label'].cpu().numpy() == 13).astype(np.float)) 149 | hand_mask = torch.FloatTensor((data['densepose'].cpu().numpy() == 3).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 4).astype(np.int)) 150 | dense_preserve_mask = torch.FloatTensor((data['densepose'].cpu().numpy() == 15).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 16).astype(np.int)) \ 151 | + torch.FloatTensor((data['densepose'].cpu().numpy() == 17).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 18).astype(np.int)) \ 152 | + torch.FloatTensor((data['densepose'].cpu().numpy() == 19).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 20).astype(np.int)) \ 153 | + torch.FloatTensor((data['densepose'].cpu().numpy() == 21).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 22)) 154 | hand_img = (arm_mask * hand_mask) * real_image 155 | dense_preserve_mask = dense_preserve_mask.cuda() * (1 - warped_prod_edge_un) 156 | preserve_region = face_img + other_clothes_img + hand_img 157 | 158 | gen_inputs_un = torch.cat([preserve_region.cuda(), warped_cloth_un, warped_prod_edge_un, dense_preserve_mask], 1) 159 | gen_outputs_un = PB_gen_model(gen_inputs_un) 160 | p_rendered_un, m_composite_un = torch.split(gen_outputs_un, [3, 1], 1) 161 | p_rendered_un = torch.tanh(p_rendered_un) 162 | m_composite_un = torch.sigmoid(m_composite_un) 163 | m_composite_un = m_composite_un * warped_prod_edge_un 164 | p_tryon_un = warped_cloth_un * m_composite_un + p_rendered_un * (1 - m_composite_un) 165 | 166 | flow_out = PF_warp_model(p_tryon_un.detach(), clothes.cuda(), pre_clothes_edge.cuda()) 167 | warped_cloth, last_flow, cond_all, flow_all, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all = flow_out 168 | warped_prod_edge = x_edge_all[4] 169 | 170 | epsilon = 0.001 171 | loss_smooth = sum([TVLoss(x) for x in delta_list]) 172 | loss_all = 0 173 | loss_fea_sup_all = 0 174 | loss_flow_sup_all = 0 175 | 176 | l1_loss_batch = torch.abs(warped_cloth_sup.detach() - person_clothes.cuda()) 177 | l1_loss_batch = 
l1_loss_batch.reshape(opt.batchSize, 3 * 256 * 192) 178 | l1_loss_batch = l1_loss_batch.sum(dim=1) / (3 * 256 * 192) 179 | l1_loss_batch_pred = torch.abs(warped_cloth.detach() - person_clothes.cuda()) 180 | l1_loss_batch_pred = l1_loss_batch_pred.reshape(opt.batchSize, 3 * 256 * 192) 181 | l1_loss_batch_pred = l1_loss_batch_pred.sum(dim=1) / (3 * 256 * 192) 182 | weight = (l1_loss_batch < l1_loss_batch_pred).float() 183 | num_all = len(np.where(weight.cpu().numpy() > 0)[0]) 184 | if num_all == 0: 185 | num_all = 1 186 | 187 | for num in range(5): 188 | cur_person_clothes = F.interpolate(person_clothes, scale_factor=0.5 ** (4 - num), mode='bilinear') 189 | cur_person_clothes_edge = F.interpolate(person_clothes_edge, scale_factor=0.5 ** (4 - num), mode='bilinear') 190 | loss_l1 = criterionL1(x_all[num], cur_person_clothes.cuda()) 191 | loss_vgg = criterionVGG(x_all[num], cur_person_clothes.cuda()) 192 | loss_edge = criterionL1(x_edge_all[num], cur_person_clothes_edge.cuda()) 193 | b, c, h, w = delta_x_all[num].shape 194 | loss_flow_x = (delta_x_all[num].pow(2) + epsilon * epsilon).pow(0.45) 195 | loss_flow_x = torch.sum(loss_flow_x) / (b * c * h * w) 196 | loss_flow_y = (delta_y_all[num].pow(2) + epsilon * epsilon).pow(0.45) 197 | loss_flow_y = torch.sum(loss_flow_y) / (b * c * h * w) 198 | loss_second_smooth = loss_flow_x + loss_flow_y 199 | b1, c1, h1, w1 = cond_all[num].shape 200 | weight_all = weight.reshape(-1, 1, 1, 1).repeat(1, 256, h1, w1) 201 | cond_sup_loss = ((cond_sup_all[num].detach() - cond_all[num]) ** 2 * weight_all).sum() / (256 * h1 * w1 * num_all) 202 | loss_fea_sup_all = loss_fea_sup_all + (5 - num) * 0.04 * cond_sup_loss 203 | loss_all = loss_all + (num + 1) * loss_l1 + (num + 1) * 0.2 * loss_vgg + (num + 1) * 2 * loss_edge + (num + 1) * 6 * loss_second_smooth + (5 - num) * 0.04 * cond_sup_loss 204 | if num >= 2: 205 | b1, c1, h1, w1 = flow_all[num].shape 206 | weight_all = weight.reshape(-1, 1, 1).repeat(1, h1, w1) 207 | flow_sup_loss = (torch.norm(flow_sup_all[num].detach() - flow_all[num], p=2, dim=1) * weight_all).sum() / (h1 * w1 * num_all) 208 | loss_flow_sup_all = loss_flow_sup_all + (num + 1) * 1 * flow_sup_loss 209 | loss_all = loss_all + (num + 1) * 1 * flow_sup_loss 210 | 211 | loss_all = 0.01 * loss_smooth + loss_all 212 | 213 | # sum per device losses 214 | if opt.local_rank == 0: 215 | writer.add_scalar('loss_all', loss_all, step) 216 | writer.add_scalar('loss_fea_sup_all', loss_fea_sup_all, step) 217 | writer.add_scalar('loss_flow_sup_all', loss_flow_sup_all, step) 218 | 219 | if epoch < opt.niter: 220 | optimizer_part.zero_grad() 221 | loss_all.backward() 222 | optimizer_part.step() 223 | else: 224 | optimizer.zero_grad() 225 | loss_all.backward() 226 | optimizer.step() 227 | 228 | ############## Display results and errors ########## 229 | path = 'sample/' + opt.name 230 | os.makedirs(path, exist_ok=True) 231 | ### display output images 232 | if step % 1000 == 0: 233 | if opt.local_rank == 0: 234 | a = real_image.float().cuda() 235 | b = p_tryon_un.detach() 236 | c = clothes.cuda() 237 | d = person_clothes.cuda() 238 | e = torch.cat([person_clothes_edge.cuda(), person_clothes_edge.cuda(), person_clothes_edge.cuda()], 1) 239 | f = torch.cat([densepose_fore.cuda(), densepose_fore.cuda(), densepose_fore.cuda()], 1) 240 | g = warped_cloth 241 | h = torch.cat([warped_prod_edge, warped_prod_edge, warped_prod_edge], 1) 242 | combine = torch.cat([a[0], b[0], c[0], d[0], e[0], f[0], g[0], h[0]], 2).squeeze() 243 | cv_img = (combine.permute(1, 2, 
0).detach().cpu().numpy() + 1) / 2 244 | writer.add_image('combine', (combine.data + 1) / 2.0, step) 245 | rgb = (cv_img * 255).astype(np.uint8) 246 | bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR) 247 | cv2.imwrite('sample/' + opt.name + '/' + str(step) + '.jpg', bgr) 248 | 249 | step += 1 250 | iter_end_time = time.time() 251 | iter_delta_time = iter_end_time - iter_start_time 252 | step_delta = (step_per_batch - step % step_per_batch) + step_per_batch * (opt.niter + opt.niter_decay - epoch) 253 | eta = iter_delta_time * step_delta 254 | eta = str(datetime.timedelta(seconds=int(eta))) 255 | time_stamp = datetime.datetime.now() 256 | now = time_stamp.strftime('%Y.%m.%d-%H:%M:%S') 257 | if step % 100 == 0: 258 | if opt.local_rank == 0: 259 | print('{}:{}:[step-{}]--[loss-{:.6f}]--[loss-{:.6f}]--[loss-{:.6f}]--[ETA-{}]'.format(now, epoch_iter, step, loss_all, loss_fea_sup_all, loss_flow_sup_all, eta)) 260 | 261 | if epoch_iter >= dataset_size: 262 | break 263 | 264 | # end of epoch 265 | iter_end_time = time.time() 266 | if opt.local_rank == 0: 267 | print('End of epoch %d / %d \t Time Taken: %d sec' % 268 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 269 | 270 | ### save model for this epoch 271 | if epoch % opt.save_epoch_freq == 0: 272 | if opt.local_rank == 0: 273 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 274 | save_checkpoint(PF_warp_model.module, 275 | os.path.join(opt.checkpoints_dir, opt.name, 'PFAFN_warp_epoch_%03d.pth' % (epoch + 1))) 276 | 277 | if epoch > opt.niter: 278 | PF_warp_model.module.update_learning_rate(optimizer) 279 | -------------------------------------------------------------------------------- /PF-AFN_train/util/__init__.py: -------------------------------------------------------------------------------- 1 | # util_init -------------------------------------------------------------------------------- /PF-AFN_train/util/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/util/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/util/__pycache__/image_pool.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/util/__pycache__/image_pool.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/util/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/util/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /PF-AFN_train/util/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/util/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /PF-AFN_train/util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.autograd import Variable 4 | class ImagePool(): 5 | def __init__(self, pool_size): 6 | self.pool_size = pool_size 7 | if self.pool_size > 0: 8 | self.num_imgs = 0 9 | self.images = [] 10 | 11 | def query(self, images): 12 | if self.pool_size == 0: 13 | return images 14 | return_images = [] 15 | for image in images.data: 16 | image = torch.unsqueeze(image, 0) 17 | if self.num_imgs < self.pool_size: 18 | self.num_imgs = self.num_imgs + 1 19 | self.images.append(image) 20 | return_images.append(image) 21 | else: 22 | p = random.uniform(0, 1) 23 | if p > 0.5: 24 | random_id = random.randint(0, self.pool_size-1) 25 | tmp = self.images[random_id].clone() 26 | self.images[random_id] = image 27 | return_images.append(tmp) 28 | else: 29 | return_images.append(image) 30 | return_images = Variable(torch.cat(return_images, 0)) 31 | return return_images 32 | -------------------------------------------------------------------------------- /PF-AFN_train/util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import numpy as np 7 | import os 8 | 9 | # Converts a Tensor into a Numpy array 10 | # |imtype|: the desired type of the converted numpy array 11 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True): 12 | if isinstance(image_tensor, list): 13 | image_numpy = [] 14 | for i in range(len(image_tensor)): 15 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 16 | return image_numpy 17 | image_numpy = image_tensor.cpu().float().numpy() 18 | #if normalize: 19 | # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 20 | #else: 21 | # image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 22 | image_numpy = (image_numpy + 1) / 2.0 23 | image_numpy = np.clip(image_numpy, 0, 1) 24 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: 25 | image_numpy = image_numpy[:,:,0] 26 | 27 | return image_numpy 28 | 29 | # Converts a one-hot tensor into a colorful label map 30 | def tensor2label(label_tensor, n_label, imtype=np.uint8): 31 | if n_label == 0: 32 | return tensor2im(label_tensor, imtype) 33 | label_tensor = label_tensor.cpu().float() 34 | if label_tensor.size()[0] > 1: 35 | label_tensor = label_tensor.max(0, keepdim=True)[1] 36 | label_tensor = Colorize(n_label)(label_tensor) 37 | #label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 38 | label_numpy = label_tensor.numpy() 39 | label_numpy = label_numpy / 255.0 40 | 41 | return label_numpy 42 | 43 | def save_image(image_numpy, image_path): 44 | image_pil = Image.fromarray(image_numpy) 45 | image_pil.save(image_path) 46 | 47 | def mkdirs(paths): 48 | if isinstance(paths, list) and not isinstance(paths, str): 49 | for path in paths: 50 | mkdir(path) 51 | else: 52 | mkdir(paths) 53 | 54 | def mkdir(path): 55 | if not os.path.exists(path): 56 | os.makedirs(path) 57 | 58 | ############################################################################### 59 | # Code from 60 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 61 | # Modified so it complies with 
the Citscape label map colors 62 | ############################################################################### 63 | def uint82bin(n, count=8): 64 | """returns the binary of integer n, count refers to amount of bits""" 65 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)]) 66 | 67 | def labelcolormap(N): 68 | if N == 35: # cityscape 69 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81), 70 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153), 71 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0), 72 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70), 73 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)], 74 | dtype=np.uint8) 75 | else: 76 | cmap = np.zeros((N, 3), dtype=np.uint8) 77 | for i in range(N): 78 | r, g, b = 0, 0, 0 79 | id = i 80 | for j in range(7): 81 | str_id = uint82bin(id) 82 | r = r ^ (np.uint8(str_id[-1]) << (7-j)) 83 | g = g ^ (np.uint8(str_id[-2]) << (7-j)) 84 | b = b ^ (np.uint8(str_id[-3]) << (7-j)) 85 | id = id >> 3 86 | cmap[i, 0] = r 87 | cmap[i, 1] = g 88 | cmap[i, 2] = b 89 | return cmap 90 | 91 | class Colorize(object): 92 | def __init__(self, n=35): 93 | self.cmap = labelcolormap(n) 94 | self.cmap = torch.from_numpy(self.cmap[:n]) 95 | 96 | def __call__(self, gray_image): 97 | size = gray_image.size() 98 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 99 | 100 | for label in range(0, len(self.cmap)): 101 | mask = (label == gray_image[0]).cpu() 102 | color_image[0][mask] = self.cmap[label][0] 103 | color_image[1][mask] = self.cmap[label][1] 104 | color_image[2][mask] = self.cmap[label][2] 105 | 106 | return color_image 107 | -------------------------------------------------------------------------------- /PFAFN_supp.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PFAFN_supp.pdf -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Parser-Free Virtual Try-on via Distilling Appearance Flows, CVPR 2021 2 | Official code for CVPR 2021 paper 'Parser-Free Virtual Try-on via Distilling Appearance Flows' 3 | 4 | 5 | **The training code has been released.** 6 | 7 | ![image](https://github.com/geyuying/PF-AFN/blob/main/show/compare_both.jpg?raw=true) 8 | 9 | [[Paper]](https://openaccess.thecvf.com/content/CVPR2021/papers/Ge_Parser-Free_Virtual_Try-On_via_Distilling_Appearance_Flows_CVPR_2021_paper.pdf) [[Supplementary Material]](https://github.com/geyuying/PF-AFN/blob/main/PFAFN_supp.pdf) 10 | 11 | [[Checkpoints for Test]](https://drive.google.com/file/d/1_a0AiN8Y_d_9TNDhHIcRlERz3zptyYWV/view?usp=sharing) 12 | 13 | [[Training_Data]](https://drive.google.com/file/d/1Uc0DTTkSfCPXDhd4CMx2TQlzlC6bDolK/view?usp=sharing) 14 | [[Test_Data]](https://drive.google.com/file/d/1Y7uV0gomwWyxCvvH8TIbY7D9cTAUy6om/view?usp=sharing) 15 | 16 | [[VGG_Model]](https://drive.google.com/file/d/1Mw24L52FfOT9xXm3I1GL8btn7vttsHd9/view?usp=sharing) 17 | 18 | ## Our Environment 19 | anaconda3 20 | 21 | pytorch 1.1.0 22 | 23 | torchvision 0.3.0 24 | 25 | cuda 9.0 26 | 27 | cupy 6.0.0 28 | 29 | opencv-python 4.5.1 30 | 31 | 8 GTX1080 GPU for training; 1 GTX1080 
GPU for test 32 | 33 | python 3.6 34 | 35 | ## Installation 36 | conda create -n tryon python=3.6 37 | 38 | source activate tryon or conda activate tryon 39 | 40 | conda install pytorch=1.1.0 torchvision=0.3.0 cudatoolkit=9.0 -c pytorch 41 | 42 | conda install cupy or pip install cupy==6.0.0 43 | 44 | pip install opencv-python 45 | 46 | git clone https://github.com/geyuying/PF-AFN.git 47 | 48 | cd PF-AFN 49 | 50 | ## Training on VITON dataset 51 | 1. cd PF-AFN_train 52 | 2. Download the VITON training set from [VITON_train](https://drive.google.com/file/d/1Uc0DTTkSfCPXDhd4CMx2TQlzlC6bDolK/view?usp=sharing) and put the folder "VITON_traindata" under the folder "dataset". 53 | 3. Download the VGG_19 model from [VGG_Model](https://drive.google.com/file/d/1Mw24L52FfOT9xXm3I1GL8btn7vttsHd9/view?usp=sharing) and put "vgg19-dcbb9e9d.pth" under the folder "models". 54 | 4. First train the parser-based network PBAFN. Run **scripts/train_PBAFN_stage1.sh** (see the launch sketch after the Dataset section below). After the parser-based warping module is trained, run **scripts/train_PBAFN_e2e.sh**. 55 | 5. After training the parser-based network PBAFN, train the parser-free network PFAFN. Run **scripts/train_PFAFN_stage1.sh**. After the parser-free warping module is trained, run **scripts/train_PFAFN_e2e.sh**. 56 | 6. Following the above instructions with the provided training code, the [[trained PF-AFN]](https://drive.google.com/file/d/1Pz2kA65N4Ih9w6NFYBDmdtVdB-nrrdc3/view?usp=sharing) achieves FID 9.92 on the VITON test set with test_pairs.txt (you can find it at https://github.com/minar09/cp-vton-plus/blob/master/data/test_pairs.txt). 57 | 58 | ## Run the demo 59 | 1. cd PF-AFN_test 60 | 2. First, you need to download the checkpoints from [checkpoints](https://drive.google.com/file/d/1_a0AiN8Y_d_9TNDhHIcRlERz3zptyYWV/view?usp=sharing) and put the folder "PFAFN" under the folder "checkpoints". The folder "checkpoints/PFAFN" should contain "warp_model_final.pth" and "gen_model_final.pth". 61 | 3. The "dataset" folder contains the demo images for test, where the "test_img" folder contains the person images, the "test_clothes" folder contains the clothes images, and the "test_edge" folder contains edges extracted from the clothes images with the built-in function in python (we saved the extracted edges from the clothes images for convenience). 'demo.txt' records the test pairs. 62 | 4. During test, a person image, a clothes image and its extracted edge are fed into the network to generate the try-on image. **No human parsing results or human pose estimation results are needed for test.** 63 | 5. To test with the saved model, run **test.sh** and the results will be saved in the folder "results". 64 | 6. **To reproduce our results from the saved model, your test environment should be the same as our test environment, especially for the version of cupy.** 65 | 66 | ![image](https://github.com/geyuying/PF-AFN/blob/main/show/compare.jpg?raw=true) 67 | ## Dataset 68 | 1. [VITON](https://github.com/xthan/VITON) contains a training set of 14,221 image pairs and a test set of 2,032 image pairs, each of which has a front-view woman photo and a top clothing image with the resolution 256 x 192. Our saved model is trained on the VITON training set and tested on the VITON test set. 69 | 2. To train from scratch on the VITON training set, you can download [VITON_train](https://drive.google.com/file/d/1Uc0DTTkSfCPXDhd4CMx2TQlzlC6bDolK/view?usp=sharing). 70 | 3. To test our saved model on the complete VITON test set, you can download [VITON_test](https://drive.google.com/file/d/1Y7uV0gomwWyxCvvH8TIbY7D9cTAUy6om/view?usp=sharing).
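The shell scripts under **scripts/** are not reproduced in this listing. Because the training code initializes the NCCL process group with init_method='env://' and reads a --local_rank argument, each script presumably wraps a torch.distributed.launch invocation run from inside PF-AFN_train; a minimal sketch (the experiment name, batch size and GPU count below are placeholders, and the actual scripts may pass further options):

python -m torch.distributed.launch --nproc_per_node=8 train_PBAFN_stage1.py --name PBAFN_stage1 --batchSize 4

The same pattern would apply to the other entry points (train_PBAFN_e2e.py, train_PFAFN_stage1.py, train_PFAFN_e2e.py); as their load_checkpoint_parallel calls above show, the parser-free scripts additionally expect the parser-based checkpoints via options such as PBAFN_warp_checkpoint and PBAFN_gen_checkpoint.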
71 | 72 | ## License 73 | The use of this code is RESTRICTED to non-commercial research and educational purposes. 74 | 75 | ## Acknowledgement 76 | Our code is based on the implementation of "Clothflow: A flow-based model for clothed person generation" (see the citation below), including the implementation of the feature pyramid networks (FPN) and the ResUnetGenerator, and the adaptation of the cascaded structure to predict the appearance flows. If you use our code, please also cite their work as below. 77 | 78 | 79 | ## Citation 80 | If our code is helpful to your work, please cite: 81 | ``` 82 | @article{ge2021parser, 83 | title={Parser-Free Virtual Try-on via Distilling Appearance Flows}, 84 | author={Ge, Yuying and Song, Yibing and Zhang, Ruimao and Ge, Chongjian and Liu, Wei and Luo, Ping}, 85 | journal={arXiv preprint arXiv:2103.04559}, 86 | year={2021} 87 | } 88 | ``` 89 | ``` 90 | @inproceedings{han2019clothflow, 91 | title={Clothflow: A flow-based model for clothed person generation}, 92 | author={Han, Xintong and Hu, Xiaojun and Huang, Weilin and Scott, Matthew R}, 93 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 94 | pages={10471--10480}, 95 | year={2019} 96 | } 97 | ``` 98 | -------------------------------------------------------------------------------- /show/compare.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/show/compare.jpg -------------------------------------------------------------------------------- /show/compare_both.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/show/compare_both.jpg --------------------------------------------------------------------------------
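A closing reference on the warping objective implemented in train_PBAFN_stage1.py above (and reused, with extra terms, by the other training scripts); this is a direct transcription of the code, with Δx, Δy, δ standing for delta_x_all, delta_y_all and delta_list, pyramid levels i = 1..5 ordered coarse to fine, and ε = 0.001:

$$\mathcal{L}_{\mathrm{warp}} = 0.01\sum_{k}\mathrm{TV}(\delta_k) \;+\; \sum_{i=1}^{5} i\Big[\mathcal{L}_{L1}^{(i)} + 0.2\,\mathcal{L}_{\mathrm{VGG}}^{(i)} + 2\,\mathcal{L}_{\mathrm{edge}}^{(i)} + 6\big(\overline{(\Delta_x^{(i)\,2}+\varepsilon^2)^{0.45}} + \overline{(\Delta_y^{(i)\,2}+\varepsilon^2)^{0.45}}\big)\Big]$$

Here the overline denotes the mean over all tensor elements, the L1, VGG and edge terms compare the level-i warped clothes and warped edge against the correspondingly downsampled person clothes, and TV(δ) is the first-order smoothness penalty on the predicted flow offsets. train_PFAFN_stage1.py and the e2e scripts add the feature- and flow-distillation terms (cond_sup_loss, flow_sup_loss) on top of this, weighted per pyramid level and masked to the samples where the teacher's warp beats the student's.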