├── PF-AFN_test
├── checkpoints
│ └── .md
├── data
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── aligned_dataset_test.cpython-36.pyc
│ │ ├── aligned_dataset_test.cpython-37.pyc
│ │ ├── base_data_loader.cpython-36.pyc
│ │ ├── base_data_loader.cpython-37.pyc
│ │ ├── base_dataset.cpython-36.pyc
│ │ ├── base_dataset.cpython-37.pyc
│ │ ├── custom_dataset_data_loader_test.cpython-36.pyc
│ │ ├── custom_dataset_data_loader_test.cpython-37.pyc
│ │ ├── data_loader_test.cpython-36.pyc
│ │ ├── data_loader_test.cpython-37.pyc
│ │ ├── image_folder.cpython-36.pyc
│ │ └── image_folder.cpython-37.pyc
│ ├── aligned_dataset_test.py
│ ├── base_data_loader.py
│ ├── base_dataset.py
│ ├── custom_dataset_data_loader_test.py
│ ├── data_loader_test.py
│ └── image_folder.py
├── dataset
│ ├── test_clothes
│ │ ├── 003434_1.jpg
│ │ ├── 006026_1.jpg
│ │ ├── 010567_1.jpg
│ │ ├── 014396_1.jpg
│ │ ├── 017575_1.jpg
│ │ └── 019119_1.jpg
│ ├── test_edge
│ │ ├── 003434_1.jpg
│ │ ├── 006026_1.jpg
│ │ ├── 010567_1.jpg
│ │ ├── 014396_1.jpg
│ │ ├── 017575_1.jpg
│ │ └── 019119_1.jpg
│ └── test_img
│ │ ├── 000066_0.jpg
│ │ ├── 004912_0.jpg
│ │ ├── 005510_0.jpg
│ │ ├── 014834_0.jpg
│ │ ├── 015794_0.jpg
│ │ └── 016962_0.jpg
├── demo.txt
├── models
│ ├── __pycache__
│ │ ├── afwm.cpython-36.pyc
│ │ ├── afwm.cpython-37.pyc
│ │ ├── networks.cpython-36.pyc
│ │ └── networks.cpython-37.pyc
│ ├── afwm.py
│ ├── correlation
│ │ ├── README.md
│ │ ├── __pycache__
│ │ │ ├── correlation.cpython-36.pyc
│ │ │ └── correlation.cpython-37.pyc
│ │ └── correlation.py
│ └── networks.py
├── options
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── base_options.cpython-36.pyc
│ │ ├── base_options.cpython-37.pyc
│ │ ├── test_options.cpython-36.pyc
│ │ └── test_options.cpython-37.pyc
│ ├── base_options.py
│ └── test_options.py
├── results
│ └── demo
│ │ └── PFAFN
│ │ │ ├── 0.jpg
│ │ │ ├── 1.jpg
│ │ │ ├── 2.jpg
│ │ │ ├── 3.jpg
│ │ │ ├── 4.jpg
│ │ │ └── 5.jpg
├── test.py
├── test.sh
└── util
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── util.cpython-36.pyc
│ │ └── util.cpython-37.pyc
│ ├── image_pool.py
│ └── util.py
├── PF-AFN_train
├── checkpoints
│ └── readme.txt
├── data
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── aligned_dataset.cpython-36.pyc
│ │ ├── aligned_dataset.cpython-37.pyc
│ │ ├── aligned_dataset_fake.cpython-36.pyc
│ │ ├── aligned_dataset_test.cpython-36.pyc
│ │ ├── base_data_loader.cpython-36.pyc
│ │ ├── base_dataset.cpython-36.pyc
│ │ ├── base_dataset.cpython-37.pyc
│ │ ├── custom_dataset_data_loader.cpython-36.pyc
│ │ ├── custom_dataset_data_loader_test.cpython-36.pyc
│ │ ├── data_loader.cpython-36.pyc
│ │ ├── data_loader_test.cpython-36.pyc
│ │ ├── image_folder.cpython-36.pyc
│ │ └── image_folder.cpython-37.pyc
│ ├── aligned_dataset.py
│ ├── base_data_loader.py
│ ├── base_dataset.py
│ ├── custom_dataset_data_loader.py
│ ├── data_loader.py
│ └── image_folder.py
├── dataset
│ └── readme.txt
├── models
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── afwm.cpython-37.pyc
│ │ ├── base_model.cpython-36.pyc
│ │ ├── flow_gmm.cpython-36.pyc
│ │ ├── flow_gmm_add.cpython-36.pyc
│ │ ├── flow_gmm_cor.cpython-36.pyc
│ │ ├── flow_gmm_cor_add.cpython-36.pyc
│ │ ├── flow_gmm_cor_more.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_add.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_feat.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_grid_offset_sep.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_grid_sep.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_new.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_offset.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_revise.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_revise_new.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_revise_new_sep.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_revise_new_sep_all.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_revise_new_sep_all_more.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_revise_new_sep_all_more_heatmap.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_revise_new_sep_all_more_no_refine.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_revise_new_sep_all_more_trans.cpython-36.pyc
│ │ ├── flow_gmm_cor_more_sep.cpython-36.pyc
│ │ ├── flow_gmm_cor_sep.cpython-36.pyc
│ │ ├── flow_gmm_smooth.cpython-36.pyc
│ │ ├── flow_gmm_vis.cpython-36.pyc
│ │ ├── models.cpython-36.pyc
│ │ ├── networks.cpython-36.pyc
│ │ ├── networks.cpython-37.pyc
│ │ ├── networks_flow.cpython-36.pyc
│ │ ├── pix2pixHD_model.cpython-36.pyc
│ │ └── predict_mask.cpython-36.pyc
│ ├── afwm.py
│ ├── correlation
│ │ ├── README.md
│ │ ├── __pycache__
│ │ │ ├── correlation.cpython-36.pyc
│ │ │ └── correlation.cpython-37.pyc
│ │ └── correlation.py
│ └── networks.py
├── options
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── base_options.cpython-36.pyc
│ │ ├── base_options.cpython-37.pyc
│ │ ├── test_options.cpython-36.pyc
│ │ ├── train_options.cpython-36.pyc
│ │ └── train_options.cpython-37.pyc
│ ├── base_options.py
│ └── train_options.py
├── runs
│ └── readme.txt
├── sample
│ └── readme.txt
├── scripts
│ ├── train_PBAFN_e2e.sh
│ ├── train_PBAFN_stage1.sh
│ ├── train_PFAFN_e2e.sh
│ └── train_PFAFN_stage1.sh
├── train_PBAFN_e2e.py
├── train_PBAFN_stage1.py
├── train_PFAFN_e2e.py
├── train_PFAFN_stage1.py
└── util
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-36.pyc
│ │ ├── __init__.cpython-37.pyc
│ │ ├── image_pool.cpython-36.pyc
│ │ ├── util.cpython-36.pyc
│ │ └── util.cpython-37.pyc
│ ├── image_pool.py
│ └── util.py
├── PFAFN_supp.pdf
├── README.md
└── show
├── compare.jpg
└── compare_both.jpg
/PF-AFN_test/checkpoints/.md:
--------------------------------------------------------------------------------
1 | Please put the downloaded checkpoints folder "PFAFN" under the folder "checkpoints".
2 |
--------------------------------------------------------------------------------
/PF-AFN_test/data/__init__.py:
--------------------------------------------------------------------------------
1 | # data_init
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/aligned_dataset_test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/aligned_dataset_test.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/aligned_dataset_test.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/aligned_dataset_test.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/base_data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/base_data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/base_data_loader.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/base_data_loader.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/base_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/base_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/base_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/base_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/custom_dataset_data_loader_test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/custom_dataset_data_loader_test.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/custom_dataset_data_loader_test.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/custom_dataset_data_loader_test.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/data_loader_test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/data_loader_test.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/data_loader_test.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/data_loader_test.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/image_folder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/image_folder.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/__pycache__/image_folder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/data/__pycache__/image_folder.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/data/aligned_dataset_test.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_params, get_transform
3 | from PIL import Image
4 | import linecache
5 |
6 | class AlignedDataset(BaseDataset):
7 | def initialize(self, opt):
8 | self.opt = opt
9 | self.root = opt.dataroot
10 |
11 | self.fine_height=256
12 | self.fine_width=192
13 |
14 | self.dataset_size = len(open('demo.txt').readlines())
15 |
16 | dir_I = '_img'
17 | self.dir_I = os.path.join(opt.dataroot, opt.phase + dir_I)
18 |
19 | dir_C = '_clothes'
20 | self.dir_C = os.path.join(opt.dataroot, opt.phase + dir_C)
21 |
22 | dir_E = '_edge'
23 | self.dir_E = os.path.join(opt.dataroot, opt.phase + dir_E)
24 |
25 | def __getitem__(self, index):
26 |
27 | file_path ='demo.txt'
28 | im_name, c_name = linecache.getline(file_path, index+1).strip().split()
29 |
30 | I_path = os.path.join(self.dir_I,im_name)
31 | I = Image.open(I_path).convert('RGB')
32 |
33 | params = get_params(self.opt, I.size)
34 | transform = get_transform(self.opt, params)
35 | transform_E = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
36 |
37 | I_tensor = transform(I)
38 |
39 | C_path = os.path.join(self.dir_C,c_name)
40 | C = Image.open(C_path).convert('RGB')
41 | C_tensor = transform(C)
42 |
43 | E_path = os.path.join(self.dir_E,c_name)
44 | E = Image.open(E_path).convert('L')
45 | E_tensor = transform_E(E)
46 |
47 | input_dict = { 'image': I_tensor,'clothes': C_tensor, 'edge': E_tensor}
48 | return input_dict
49 |
50 | def __len__(self):
51 | return self.dataset_size
52 |
53 | def name(self):
54 | return 'AlignedDataset'
55 |
--------------------------------------------------------------------------------
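
Each line of demo.txt names a person image from test_img and a clothes image from test_clothes (the edge map in test_edge shares the clothes file name), and AlignedDataset turns one such pair into normalized tensors. A minimal sketch of reading the first pair, run from inside PF-AFN_test/ so the data package and demo.txt resolve; the SimpleNamespace below is only an illustrative stand-in for the parsed test options, and the printed sizes assume the 192x256 resolution of the bundled demo images:

    from types import SimpleNamespace
    from data.aligned_dataset_test import AlignedDataset

    # Illustrative stand-in for the parsed test options (only the fields read here).
    opt = SimpleNamespace(dataroot='dataset', phase='test', resize_or_crop='none',
                          loadSize=512, fineSize=512, isTrain=False, no_flip=True)

    dataset = AlignedDataset()
    dataset.initialize(opt)            # counts the 6 lines of demo.txt
    item = dataset[0]                  # line 1: 000066_0.jpg  017575_1.jpg
    print(item['image'].shape)         # e.g. torch.Size([3, 256, 192]), values in [-1, 1]
    print(item['clothes'].shape)       # e.g. torch.Size([3, 256, 192]), values in [-1, 1]
    print(item['edge'].shape)          # e.g. torch.Size([1, 256, 192]), values in [0, 1]
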
/PF-AFN_test/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | class BaseDataLoader():
3 | def __init__(self):
4 | pass
5 |
6 | def initialize(self, opt):
7 | self.opt = opt
8 | pass
9 |
10 |     def load_data(self):
11 | return None
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/PF-AFN_test/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import random
6 |
7 | class BaseDataset(data.Dataset):
8 | def __init__(self):
9 | super(BaseDataset, self).__init__()
10 |
11 | def name(self):
12 | return 'BaseDataset'
13 |
14 | def initialize(self, opt):
15 | pass
16 |
17 | def get_params(opt, size):
18 | w, h = size
19 | new_h = h
20 | new_w = w
21 | if opt.resize_or_crop == 'resize_and_crop':
22 | new_h = new_w = opt.loadSize
23 | elif opt.resize_or_crop == 'scale_width_and_crop':
24 | new_w = opt.loadSize
25 | new_h = opt.loadSize * h // w
26 |
27 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
28 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
29 |
30 | flip = 0
31 | return {'crop_pos': (x, y), 'flip': flip}
32 |
33 | def get_transform_resize(opt, params, method=Image.BICUBIC, normalize=True):
34 | transform_list = []
35 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
36 | osize = [256,192]
37 | transform_list.append(transforms.Scale(osize, method))
38 | if 'crop' in opt.resize_or_crop:
39 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
40 |
41 | if opt.resize_or_crop == 'none':
42 | base = float(2 ** opt.n_downsample_global)
43 | if opt.netG == 'local':
44 | base *= (2 ** opt.n_local_enhancers)
45 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
46 |
47 | if opt.isTrain and not opt.no_flip:
48 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
49 |
50 | transform_list += [transforms.ToTensor()]
51 |
52 | if normalize:
53 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
54 | (0.5, 0.5, 0.5))]
55 | return transforms.Compose(transform_list)
56 |
57 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
58 | transform_list = []
59 | if 'resize' in opt.resize_or_crop:
60 | osize = [opt.loadSize, opt.loadSize]
61 | transform_list.append(transforms.Scale(osize, method))
62 | elif 'scale_width' in opt.resize_or_crop:
63 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
64 | osize = [256,192]
65 | transform_list.append(transforms.Scale(osize, method))
66 | if 'crop' in opt.resize_or_crop:
67 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
68 |
69 | if opt.resize_or_crop == 'none':
70 | base = float(16)
71 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
72 |
73 | if opt.isTrain and not opt.no_flip:
74 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
75 |
76 | transform_list += [transforms.ToTensor()]
77 |
78 | if normalize:
79 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
80 | (0.5, 0.5, 0.5))]
81 | return transforms.Compose(transform_list)
82 |
83 | def normalize():
84 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
85 |
86 | def __make_power_2(img, base, method=Image.BICUBIC):
87 | ow, oh = img.size
88 | h = int(round(oh / base) * base)
89 | w = int(round(ow / base) * base)
90 | if (h == oh) and (w == ow):
91 | return img
92 | return img.resize((w, h), method)
93 |
94 | def __scale_width(img, target_width, method=Image.BICUBIC):
95 | ow, oh = img.size
96 | if (ow == target_width):
97 | return img
98 | w = target_width
99 | h = int(target_width * oh / ow)
100 | return img.resize((w, h), method)
101 |
102 | def __crop(img, pos, size):
103 | ow, oh = img.size
104 | x1, y1 = pos
105 | tw = th = size
106 | if (ow > tw or oh > th):
107 | return img.crop((x1, y1, x1 + tw, y1 + th))
108 | return img
109 |
110 | def __flip(img, flip):
111 | if flip:
112 | return img.transpose(Image.FLIP_LEFT_RIGHT)
113 | return img
114 |
--------------------------------------------------------------------------------
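
One behavior worth noting: when opt.resize_or_crop is 'none', get_transform only snaps each side to the nearest multiple of 16 via __make_power_2 before converting to a tensor and normalizing. A minimal sketch of that path, run from inside PF-AFN_test/; the 190x250 input size is an arbitrary illustration:

    from types import SimpleNamespace
    from PIL import Image
    from data.base_dataset import get_params, get_transform

    # Illustrative option namespace with only the fields this code path reads.
    opt = SimpleNamespace(resize_or_crop='none', loadSize=512, fineSize=512,
                          isTrain=False, no_flip=True)
    img = Image.new('RGB', (190, 250))           # width x height, not multiples of 16
    params = get_params(opt, img.size)
    tensor = get_transform(opt, params)(img)
    print(tensor.shape)                          # torch.Size([3, 256, 192]) after snapping
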
/PF-AFN_test/data/custom_dataset_data_loader_test.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 |
4 |
5 | def CreateDataset(opt):
6 | dataset = None
7 | from data.aligned_dataset_test import AlignedDataset
8 | dataset = AlignedDataset()
9 |
10 | print("dataset [%s] was created" % (dataset.name()))
11 | dataset.initialize(opt)
12 | return dataset
13 |
14 | class CustomDatasetDataLoader(BaseDataLoader):
15 | def name(self):
16 | return 'CustomDatasetDataLoader'
17 |
18 | def initialize(self, opt):
19 | BaseDataLoader.initialize(self, opt)
20 | self.dataset = CreateDataset(opt)
21 | self.dataloader = torch.utils.data.DataLoader(
22 | self.dataset,
23 | batch_size=opt.batchSize,
24 | shuffle = False,
25 | num_workers=int(opt.nThreads))
26 |
27 | def load_data(self):
28 | return self.dataloader
29 |
30 | def __len__(self):
31 | return min(len(self.dataset), self.opt.max_dataset_size)
32 |
--------------------------------------------------------------------------------
/PF-AFN_test/data/data_loader_test.py:
--------------------------------------------------------------------------------
1 |
2 | def CreateDataLoader(opt):
3 | from data.custom_dataset_data_loader_test import CustomDatasetDataLoader
4 | data_loader = CustomDatasetDataLoader()
5 | print(data_loader.name())
6 | data_loader.initialize(opt)
7 | return data_loader
8 |
--------------------------------------------------------------------------------
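
CreateDataLoader is the factory the test script is expected to call: it builds the CustomDatasetDataLoader above, which wraps AlignedDataset in a torch DataLoader with shuffling disabled. A minimal end-to-end sketch, again with an illustrative option namespace standing in for the parsed test options and run from inside PF-AFN_test/:

    from types import SimpleNamespace
    from data.data_loader_test import CreateDataLoader

    opt = SimpleNamespace(dataroot='dataset', phase='test', resize_or_crop='none',
                          loadSize=512, fineSize=512, isTrain=False, no_flip=True,
                          batchSize=1, nThreads=1, max_dataset_size=float('inf'))

    data_loader = CreateDataLoader(opt)          # prints the loader and dataset names
    dataset = data_loader.load_data()            # a torch.utils.data.DataLoader
    print('#test pairs =', len(data_loader))     # 6 with the bundled demo.txt

    for data in dataset:
        person, clothes, edge = data['image'], data['clothes'], data['edge']
        break
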
/PF-AFN_test/data/image_folder.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import os
4 |
5 | IMG_EXTENSIONS = [
6 | '.jpg', '.JPG', '.jpeg', '.JPEG',
7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
8 | ]
9 |
10 |
11 | def is_image_file(filename):
12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13 |
14 | def make_dataset(dir):
15 | images = []
16 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
17 |
18 | f = dir.split('/')[-1].split('_')[-1]
19 | print (dir, f)
20 | dirs= os.listdir(dir)
21 | for img in dirs:
22 |
23 | path = os.path.join(dir, img)
24 | #print(path)
25 | images.append(path)
26 | return images
27 |
28 | def make_dataset_test(dir):
29 | images = []
30 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
31 |
32 | f = dir.split('/')[-1].split('_')[-1]
33 | for i in range(len([name for name in os.listdir(dir) if os.path.isfile(os.path.join(dir, name))])):
34 | if f == 'label' or f == 'labelref':
35 | img = str(i) + '.png'
36 | else:
37 | img = str(i) + '.jpg'
38 | path = os.path.join(dir, img)
39 | images.append(path)
40 | return images
41 |
42 | def default_loader(path):
43 | return Image.open(path).convert('RGB')
44 |
45 |
46 | class ImageFolder(data.Dataset):
47 |
48 | def __init__(self, root, transform=None, return_paths=False,
49 | loader=default_loader):
50 | imgs = make_dataset(root)
51 | if len(imgs) == 0:
52 | raise(RuntimeError("Found 0 images in: " + root + "\n"
53 | "Supported image extensions are: " +
54 | ",".join(IMG_EXTENSIONS)))
55 |
56 | self.root = root
57 | self.imgs = imgs
58 | self.transform = transform
59 | self.return_paths = return_paths
60 | self.loader = loader
61 |
62 | def __getitem__(self, index):
63 | path = self.imgs[index]
64 | img = self.loader(path)
65 | if self.transform is not None:
66 | img = self.transform(img)
67 | if self.return_paths:
68 | return img, path
69 | else:
70 | return img
71 |
72 | def __len__(self):
73 | return len(self.imgs)
74 |
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_clothes/003434_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/003434_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_clothes/006026_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/006026_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_clothes/010567_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/010567_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_clothes/014396_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/014396_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_clothes/017575_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/017575_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_clothes/019119_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_clothes/019119_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_edge/003434_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/003434_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_edge/006026_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/006026_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_edge/010567_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/010567_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_edge/014396_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/014396_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_edge/017575_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/017575_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_edge/019119_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_edge/019119_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_img/000066_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/000066_0.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_img/004912_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/004912_0.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_img/005510_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/005510_0.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_img/014834_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/014834_0.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_img/015794_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/015794_0.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/dataset/test_img/016962_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/dataset/test_img/016962_0.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/demo.txt:
--------------------------------------------------------------------------------
1 | 000066_0.jpg 017575_1.jpg
2 | 016962_0.jpg 003434_1.jpg
3 | 004912_0.jpg 014396_1.jpg
4 | 005510_0.jpg 006026_1.jpg
5 | 014834_0.jpg 019119_1.jpg
6 | 015794_0.jpg 010567_1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/models/__pycache__/afwm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/__pycache__/afwm.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/models/__pycache__/afwm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/__pycache__/afwm.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/models/__pycache__/networks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/__pycache__/networks.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/models/__pycache__/networks.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/__pycache__/networks.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/models/afwm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from .correlation import correlation
5 |
6 | def apply_offset(offset):
7 |
8 | sizes = list(offset.size()[2:])
9 | grid_list = torch.meshgrid([torch.arange(size, device=offset.device) for size in sizes])
10 | grid_list = reversed(grid_list)
11 |
12 | grid_list = [grid.float().unsqueeze(0) + offset[:, dim, ...]
13 | for dim, grid in enumerate(grid_list)]
14 |
15 | grid_list = [grid / ((size - 1.0) / 2.0) - 1.0
16 | for grid, size in zip(grid_list, reversed(sizes))]
17 |
18 | return torch.stack(grid_list, dim=-1)
19 |
20 |
21 | class ResBlock(nn.Module):
22 | def __init__(self, in_channels):
23 | super(ResBlock, self).__init__()
24 | self.block = nn.Sequential(
25 | nn.BatchNorm2d(in_channels),
26 | nn.ReLU(inplace=True),
27 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
28 | nn.BatchNorm2d(in_channels),
29 | nn.ReLU(inplace=True),
30 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
31 | )
32 |
33 | def forward(self, x):
34 | return self.block(x) + x
35 |
36 |
37 | class DownSample(nn.Module):
38 | def __init__(self, in_channels, out_channels):
39 | super(DownSample, self).__init__()
40 | self.block= nn.Sequential(
41 | nn.BatchNorm2d(in_channels),
42 | nn.ReLU(inplace=True),
43 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
44 | )
45 |
46 | def forward(self, x):
47 | return self.block(x)
48 |
49 |
50 |
51 | class FeatureEncoder(nn.Module):
52 | def __init__(self, in_channels, chns=[64,128,256,256,256]):
53 | super(FeatureEncoder, self).__init__()
54 | self.encoders = []
55 | for i, out_chns in enumerate(chns):
56 | if i == 0:
57 | encoder = nn.Sequential(DownSample(in_channels, out_chns),
58 | ResBlock(out_chns),
59 | ResBlock(out_chns))
60 | else:
61 | encoder = nn.Sequential(DownSample(chns[i-1], out_chns),
62 | ResBlock(out_chns),
63 | ResBlock(out_chns))
64 |
65 | self.encoders.append(encoder)
66 |
67 | self.encoders = nn.ModuleList(self.encoders)
68 |
69 |
70 | def forward(self, x):
71 | encoder_features = []
72 | for encoder in self.encoders:
73 | x = encoder(x)
74 | encoder_features.append(x)
75 | return encoder_features
76 |
77 | class RefinePyramid(nn.Module):
78 | def __init__(self, chns=[64,128,256,256,256], fpn_dim=256):
79 | super(RefinePyramid, self).__init__()
80 | self.chns = chns
81 |
82 | self.adaptive = []
83 | for in_chns in list(reversed(chns)):
84 | adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1)
85 | self.adaptive.append(adaptive_layer)
86 | self.adaptive = nn.ModuleList(self.adaptive)
87 |
88 | self.smooth = []
89 | for i in range(len(chns)):
90 | smooth_layer = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, padding=1)
91 | self.smooth.append(smooth_layer)
92 | self.smooth = nn.ModuleList(self.smooth)
93 |
94 | def forward(self, x):
95 | conv_ftr_list = x
96 |
97 | feature_list = []
98 | last_feature = None
99 | for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))):
100 | feature = self.adaptive[i](conv_ftr)
101 |
102 | if last_feature is not None:
103 | feature = feature + F.interpolate(last_feature, scale_factor=2, mode='nearest')
104 |
105 | feature = self.smooth[i](feature)
106 | last_feature = feature
107 | feature_list.append(feature)
108 |
109 | return tuple(reversed(feature_list))
110 |
111 |
112 | class AFlowNet(nn.Module):
113 | def __init__(self, num_pyramid, fpn_dim=256):
114 | super(AFlowNet, self).__init__()
115 | self.netMain = []
116 | self.netRefine = []
117 | for i in range(num_pyramid):
118 | netMain_layer = torch.nn.Sequential(
119 | torch.nn.Conv2d(in_channels=49, out_channels=128, kernel_size=3, stride=1, padding=1),
120 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
121 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
122 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
123 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
124 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
125 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
126 | )
127 |
128 | netRefine_layer = torch.nn.Sequential(
129 | torch.nn.Conv2d(2 * fpn_dim, out_channels=128, kernel_size=3, stride=1, padding=1),
130 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
131 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
132 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
133 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
134 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
135 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
136 | )
137 | self.netMain.append(netMain_layer)
138 | self.netRefine.append(netRefine_layer)
139 |
140 | self.netMain = nn.ModuleList(self.netMain)
141 | self.netRefine = nn.ModuleList(self.netRefine)
142 |
143 |
144 | def forward(self, x, x_warps, x_conds, warp_feature=True):
145 | last_flow = None
146 |
147 | for i in range(len(x_warps)):
148 | x_warp = x_warps[len(x_warps) - 1 - i]
149 | x_cond = x_conds[len(x_warps) - 1 - i]
150 |
151 | if last_flow is not None and warp_feature:
152 | x_warp_after = F.grid_sample(x_warp, last_flow.detach().permute(0, 2, 3, 1),
153 | mode='bilinear', padding_mode='border')
154 | else:
155 | x_warp_after = x_warp
156 |
157 | tenCorrelation = F.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=x_warp_after, tenSecond=x_cond, intStride=1), negative_slope=0.1, inplace=False)
158 | flow = self.netMain[i](tenCorrelation)
159 | flow = apply_offset(flow)
160 |
161 | if last_flow is not None:
162 | flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border')
163 | else:
164 | flow = flow.permute(0, 3, 1, 2)
165 |
166 | last_flow = flow
167 | x_warp = F.grid_sample(x_warp, flow.permute(0, 2, 3, 1),mode='bilinear', padding_mode='border')
168 | concat = torch.cat([x_warp,x_cond],1)
169 | flow = self.netRefine[i](concat)
170 | flow = apply_offset(flow)
171 | flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border')
172 |
173 | last_flow = F.interpolate(flow, scale_factor=2, mode='bilinear')
174 |
175 | x_warp = F.grid_sample(x, last_flow.permute(0, 2, 3, 1),
176 | mode='bilinear', padding_mode='border')
177 | return x_warp, last_flow,
178 |
179 |
180 | class AFWM(nn.Module):
181 |
182 | def __init__(self, opt, input_nc):
183 | super(AFWM, self).__init__()
184 | num_filters = [64,128,256,256,256]
185 | self.image_features = FeatureEncoder(3, num_filters)
186 | self.cond_features = FeatureEncoder(input_nc, num_filters)
187 | self.image_FPN = RefinePyramid(num_filters)
188 | self.cond_FPN = RefinePyramid(num_filters)
189 | self.aflow_net = AFlowNet(len(num_filters))
190 |
191 | def forward(self, cond_input, image_input):
192 | cond_pyramids = self.cond_FPN(self.cond_features(cond_input)) # maybe use nn.Sequential
193 | image_pyramids = self.image_FPN(self.image_features(image_input))
194 |
195 | x_warp, last_flow = self.aflow_net(image_input, image_pyramids, cond_pyramids)
196 |
197 | return x_warp, last_flow
198 |
199 |
--------------------------------------------------------------------------------
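
AFWM builds two five-level feature pyramids, one for the condition input and one for the image to be warped, and AFlowNet estimates a coarse-to-fine cascade of appearance flows that warp the image toward the condition. A minimal shape-level sketch, run from inside PF-AFN_test/; it assumes a CUDA device, since the correlation layer below is a CuPy/CUDA kernel, and uses random tensors at the 256x192 demo resolution purely for illustration:

    import torch
    from models.afwm import AFWM

    warp_model = AFWM(opt=None, input_nc=3).cuda().eval()   # opt is not used in this module

    cond  = torch.randn(1, 3, 256, 192).cuda()   # condition branch input
    image = torch.randn(1, 3, 256, 192).cuda()   # image that gets warped toward the condition

    with torch.no_grad():
        x_warp, last_flow = warp_model(cond, image)
    print(x_warp.shape)      # torch.Size([1, 3, 256, 192]): warped image
    print(last_flow.shape)   # torch.Size([1, 2, 256, 192]): final appearance flow
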
/PF-AFN_test/models/correlation/README.md:
--------------------------------------------------------------------------------
1 | This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you make use of or modify this particular implementation, please acknowledge it appropriately.
--------------------------------------------------------------------------------
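
The kernels in correlation.py compute a local cost volume over a 7x7 displacement window, which is where the 49 input channels of AFlowNet's netMain come from. A minimal sketch of the call, run from inside PF-AFN_test/ (it assumes a CUDA device; CuPy compiles the kernels on first use, and the feature-map sizes are illustrative):

    import torch
    from models.correlation import correlation

    a = torch.randn(1, 256, 32, 24).cuda()     # "first" feature map
    b = torch.randn(1, 256, 32, 24).cuda()     # "second" feature map
    cost = correlation.FunctionCorrelation(tenFirst=a, tenSecond=b, intStride=1)
    print(cost.shape)                          # torch.Size([1, 49, 32, 24]): one channel per 7x7 offset
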
/PF-AFN_test/models/correlation/__pycache__/correlation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/correlation/__pycache__/correlation.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/models/correlation/__pycache__/correlation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/models/correlation/__pycache__/correlation.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/models/correlation/correlation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import torch
4 |
5 | import cupy
6 | import math
7 | import re
8 |
9 | kernel_Correlation_rearrange = '''
10 | extern "C" __global__ void kernel_Correlation_rearrange(
11 | const int n,
12 | const float* input,
13 | float* output
14 | ) {
15 | int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
16 |
17 | if (intIndex >= n) {
18 | return;
19 | }
20 |
21 | int intSample = blockIdx.z;
22 | int intChannel = blockIdx.y;
23 |
24 | float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
25 |
26 | __syncthreads();
27 |
28 | int intPaddedY = (intIndex / SIZE_3(input)) + 3*{{intStride}};
29 | int intPaddedX = (intIndex % SIZE_3(input)) + 3*{{intStride}};
30 | int intRearrange = ((SIZE_3(input) + 6*{{intStride}}) * intPaddedY) + intPaddedX;
31 |
32 | output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
33 | }
34 | '''
35 |
36 | kernel_Correlation_updateOutput = '''
37 | extern "C" __global__ void kernel_Correlation_updateOutput(
38 | const int n,
39 | const float* rbot0,
40 | const float* rbot1,
41 | float* top
42 | ) {
43 | extern __shared__ char patch_data_char[];
44 |
45 | float *patch_data = (float *)patch_data_char;
46 |
47 | // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
48 | int x1 = (blockIdx.x + 3) * {{intStride}};
49 | int y1 = (blockIdx.y + 3) * {{intStride}};
50 | int item = blockIdx.z;
51 | int ch_off = threadIdx.x;
52 |
53 |   // Load 3D patch into shared memory
54 | for (int j = 0; j < 1; j++) { // HEIGHT
55 | for (int i = 0; i < 1; i++) { // WIDTH
56 | int ji_off = (j + i) * SIZE_3(rbot0);
57 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
58 | int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
59 | int idxPatchData = ji_off + ch;
60 | patch_data[idxPatchData] = rbot0[idx1];
61 | }
62 | }
63 | }
64 |
65 | __syncthreads();
66 |
67 | __shared__ float sum[32];
68 |
69 | // Compute correlation
70 | for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
71 | sum[ch_off] = 0;
72 |
73 | int s2o = (top_channel % 7 - 3) * {{intStride}};
74 | int s2p = (top_channel / 7 - 3) * {{intStride}};
75 |
76 | for (int j = 0; j < 1; j++) { // HEIGHT
77 | for (int i = 0; i < 1; i++) { // WIDTH
78 | int ji_off = (j + i) * SIZE_3(rbot0);
79 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
80 | int x2 = x1 + s2o;
81 | int y2 = y1 + s2p;
82 |
83 | int idxPatchData = ji_off + ch;
84 | int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
85 |
86 | sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
87 | }
88 | }
89 | }
90 |
91 | __syncthreads();
92 |
93 | if (ch_off == 0) {
94 | float total_sum = 0;
95 | for (int idx = 0; idx < 32; idx++) {
96 | total_sum += sum[idx];
97 | }
98 | const int sumelems = SIZE_3(rbot0);
99 | const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
100 | top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
101 | }
102 | }
103 | }
104 | '''
105 |
106 | kernel_Correlation_updateGradFirst = '''
107 | #define ROUND_OFF 50000
108 |
109 | extern "C" __global__ void kernel_Correlation_updateGradFirst(
110 | const int n,
111 | const int intSample,
112 | const float* rbot0,
113 | const float* rbot1,
114 | const float* gradOutput,
115 | float* gradFirst,
116 | float* gradSecond
117 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
118 | int n = intIndex % SIZE_1(gradFirst); // channels
119 | int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 3*{{intStride}}; // w-pos
120 | int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 3*{{intStride}}; // h-pos
121 |
122 | // round_off is a trick to enable integer division with ceil, even for negative numbers
123 | // We use a large offset, for the inner part not to become negative.
124 | const int round_off = ROUND_OFF;
125 | const int round_off_s1 = {{intStride}} * round_off;
126 |
127 |   // We add round_off before the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
128 | int xmin = (l - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}}
129 | int ymin = (m - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}}
130 |
131 | // Same here:
132 | int xmax = (l - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}}) / {{intStride}}
133 | int ymax = (m - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}}) / {{intStride}}
134 |
135 | float sum = 0;
136 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
137 | xmin = max(0,xmin);
138 | xmax = min(SIZE_3(gradOutput)-1,xmax);
139 |
140 | ymin = max(0,ymin);
141 | ymax = min(SIZE_2(gradOutput)-1,ymax);
142 |
143 | for (int p = -3; p <= 3; p++) {
144 | for (int o = -3; o <= 3; o++) {
145 | // Get rbot1 data:
146 | int s2o = {{intStride}} * o;
147 | int s2p = {{intStride}} * p;
148 | int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
149 | float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
150 |
151 | // Index offset for gradOutput in following loops:
152 | int op = (p+3) * 7 + (o+3); // index[o,p]
153 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
154 |
155 | for (int y = ymin; y <= ymax; y++) {
156 | for (int x = xmin; x <= xmax; x++) {
157 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
158 | sum += gradOutput[idxgradOutput] * bot1tmp;
159 | }
160 | }
161 | }
162 | }
163 | }
164 | const int sumelems = SIZE_1(gradFirst);
165 | const int bot0index = ((n * SIZE_2(gradFirst)) + (m-3*{{intStride}})) * SIZE_3(gradFirst) + (l-3*{{intStride}});
166 | gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
167 | } }
168 | '''
169 |
170 | kernel_Correlation_updateGradSecond = '''
171 | #define ROUND_OFF 50000
172 |
173 | extern "C" __global__ void kernel_Correlation_updateGradSecond(
174 | const int n,
175 | const int intSample,
176 | const float* rbot0,
177 | const float* rbot1,
178 | const float* gradOutput,
179 | float* gradFirst,
180 | float* gradSecond
181 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
182 | int n = intIndex % SIZE_1(gradSecond); // channels
183 | int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 3*{{intStride}}; // w-pos
184 | int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 3*{{intStride}}; // h-pos
185 |
186 | // round_off is a trick to enable integer division with ceil, even for negative numbers
187 | // We use a large offset, for the inner part not to become negative.
188 | const int round_off = ROUND_OFF;
189 | const int round_off_s1 = {{intStride}} * round_off;
190 |
191 | float sum = 0;
192 | for (int p = -3; p <= 3; p++) {
193 | for (int o = -3; o <= 3; o++) {
194 | int s2o = {{intStride}} * o;
195 | int s2p = {{intStride}} * p;
196 |
197 | //Get X,Y ranges and clamp
198 |       // We add round_off before the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
199 | int xmin = (l - 3*{{intStride}} - s2o + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}}
200 | int ymin = (m - 3*{{intStride}} - s2p + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}}
201 |
202 | // Same here:
203 | int xmax = (l - 3*{{intStride}} - s2o + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}} - s2o) / {{intStride}}
204 | int ymax = (m - 3*{{intStride}} - s2p + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}} - s2p) / {{intStride}}
205 |
206 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
207 | xmin = max(0,xmin);
208 | xmax = min(SIZE_3(gradOutput)-1,xmax);
209 |
210 | ymin = max(0,ymin);
211 | ymax = min(SIZE_2(gradOutput)-1,ymax);
212 |
213 | // Get rbot0 data:
214 | int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
215 | float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n]
216 |
217 | // Index offset for gradOutput in following loops:
218 | int op = (p+3) * 7 + (o+3); // index[o,p]
219 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
220 |
221 | for (int y = ymin; y <= ymax; y++) {
222 | for (int x = xmin; x <= xmax; x++) {
223 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
224 | sum += gradOutput[idxgradOutput] * bot0tmp;
225 | }
226 | }
227 | }
228 | }
229 | }
230 | const int sumelems = SIZE_1(gradSecond);
231 | const int bot1index = ((n * SIZE_2(gradSecond)) + (m-3*{{intStride}})) * SIZE_3(gradSecond) + (l-3*{{intStride}});
232 | gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
233 | } }
234 | '''
235 |
236 | def cupy_kernel(strFunction, objVariables):
237 | strKernel = globals()[strFunction].replace('{{intStride}}', str(objVariables['intStride']))
238 |
239 | while True:
240 | objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
241 |
242 | if objMatch is None:
243 | break
244 | # end
245 |
246 | intArg = int(objMatch.group(2))
247 |
248 | strTensor = objMatch.group(4)
249 | intSizes = objVariables[strTensor].size()
250 |
251 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
252 | # end
253 |
254 | while True:
255 | objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
256 |
257 | if objMatch is None:
258 | break
259 | # end
260 |
261 | intArgs = int(objMatch.group(2))
262 | strArgs = objMatch.group(4).split(',')
263 |
264 | strTensor = strArgs[0]
265 | intStrides = objVariables[strTensor].stride()
266 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]
267 |
268 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
269 | # end
270 |
271 | return strKernel
272 | # end
273 |
274 | @cupy.util.memoize(for_each_device=True)
275 | def cupy_launch(strFunction, strKernel):
276 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
277 | # end
278 |
279 | class _FunctionCorrelation(torch.autograd.Function):
280 | @staticmethod
281 | def forward(self, first, second, intStride):
282 | rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ])
283 | rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ])
284 |
285 | self.save_for_backward(first, second, rbot0, rbot1)
286 |
287 | self.intStride = intStride
288 |
289 | assert(first.is_contiguous() == True)
290 | assert(second.is_contiguous() == True)
291 |
292 | output = first.new_zeros([ first.shape[0], 49, int(math.ceil(first.shape[2] / intStride)), int(math.ceil(first.shape[3] / intStride)) ])
293 |
294 | if first.is_cuda == True:
295 | n = first.shape[2] * first.shape[3]
296 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
297 | 'intStride': self.intStride,
298 | 'input': first,
299 | 'output': rbot0
300 | }))(
301 | grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]),
302 | block=tuple([ 16, 1, 1 ]),
303 | args=[ n, first.data_ptr(), rbot0.data_ptr() ]
304 | )
305 |
306 | n = second.shape[2] * second.shape[3]
307 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
308 | 'intStride': self.intStride,
309 | 'input': second,
310 | 'output': rbot1
311 | }))(
312 | grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]),
313 | block=tuple([ 16, 1, 1 ]),
314 | args=[ n, second.data_ptr(), rbot1.data_ptr() ]
315 | )
316 |
317 | n = output.shape[1] * output.shape[2] * output.shape[3]
318 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
319 | 'intStride': self.intStride,
320 | 'rbot0': rbot0,
321 | 'rbot1': rbot1,
322 | 'top': output
323 | }))(
324 | grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
325 | block=tuple([ 32, 1, 1 ]),
326 | shared_mem=first.shape[1] * 4,
327 | args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
328 | )
329 |
330 | elif first.is_cuda == False:
331 | raise NotImplementedError()
332 |
333 | # end
334 |
335 | return output
336 | # end
337 |
338 | @staticmethod
339 | def backward(self, gradOutput):
340 | first, second, rbot0, rbot1 = self.saved_tensors
341 |
342 | assert(gradOutput.is_contiguous() == True)
343 |
344 | gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None
345 | gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None
346 |
347 | if first.is_cuda == True:
348 | if gradFirst is not None:
349 | for intSample in range(first.shape[0]):
350 | n = first.shape[1] * first.shape[2] * first.shape[3]
351 | cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', {
352 | 'intStride': self.intStride,
353 | 'rbot0': rbot0,
354 | 'rbot1': rbot1,
355 | 'gradOutput': gradOutput,
356 | 'gradFirst': gradFirst,
357 | 'gradSecond': None
358 | }))(
359 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
360 | block=tuple([ 512, 1, 1 ]),
361 | args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ]
362 | )
363 | # end
364 | # end
365 |
366 | if gradSecond is not None:
367 | for intSample in range(first.shape[0]):
368 | n = first.shape[1] * first.shape[2] * first.shape[3]
369 | cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', {
370 | 'intStride': self.intStride,
371 | 'rbot0': rbot0,
372 | 'rbot1': rbot1,
373 | 'gradOutput': gradOutput,
374 | 'gradFirst': None,
375 | 'gradSecond': gradSecond
376 | }))(
377 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
378 | block=tuple([ 512, 1, 1 ]),
379 | args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ]
380 | )
381 | # end
382 | # end
383 |
384 | elif first.is_cuda == False:
385 | raise NotImplementedError()
386 |
387 | # end
388 |
389 | return gradFirst, gradSecond, None
390 | # end
391 | # end
392 |
393 | def FunctionCorrelation(tenFirst, tenSecond, intStride):
394 | return _FunctionCorrelation.apply(tenFirst, tenSecond, intStride)
395 | # end
396 |
397 | class ModuleCorrelation(torch.nn.Module):
398 | def __init__(self):
399 | super(ModuleCorrelation, self).__init__()
400 | # end
401 |
402 | def forward(self, tenFirst, tenSecond, intStride):
403 | return _FunctionCorrelation.apply(tenFirst, tenSecond, intStride)
404 | # end
405 | # end
--------------------------------------------------------------------------------
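
cupy_kernel above is plain string templating: {{intStride}} and the SIZE_n(tensor) macros inside the CUDA source strings are replaced with concrete integers taken from the tensors that will be passed to the kernel, and cupy_launch then compiles the expanded source. A minimal sketch of just the expansion step, run from inside PF-AFN_test/ (it needs cupy importable, since this module imports it at load time, but no kernel is launched; the tensor shapes are illustrative):

    import torch
    from models.correlation.correlation import cupy_kernel

    first = torch.zeros(1, 256, 32, 24)
    rbot0 = torch.zeros(1, 32 + 6, 24 + 6, 256)          # padded, channels-last layout
    src = cupy_kernel('kernel_Correlation_rearrange',
                      {'intStride': 1, 'input': first, 'output': rbot0})
    print('SIZE_' in src, '{{intStride}}' in src)        # False False: every macro was expanded
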
/PF-AFN_test/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | import os
5 |
6 | class UnetSkipConnectionBlock(nn.Module):
7 | def __init__(self, outer_nc, inner_nc, input_nc=None,
8 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
9 | super(UnetSkipConnectionBlock, self).__init__()
10 | self.outermost = outermost
11 | use_bias = norm_layer == nn.InstanceNorm2d
12 |
13 | if input_nc is None:
14 | input_nc = outer_nc
15 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
16 | stride=2, padding=1, bias=use_bias)
17 | downrelu = nn.LeakyReLU(0.2, True)
18 | uprelu = nn.ReLU(True)
19 | if norm_layer != None:
20 | downnorm = norm_layer(inner_nc)
21 | upnorm = norm_layer(outer_nc)
22 |
23 | if outermost:
24 | upsample = nn.Upsample(scale_factor=2, mode='bilinear')
25 | upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
26 | down = [downconv]
27 | up = [uprelu, upsample, upconv]
28 | model = down + [submodule] + up
29 | elif innermost:
30 | upsample = nn.Upsample(scale_factor=2, mode='bilinear')
31 | upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
32 | down = [downrelu, downconv]
33 | if norm_layer == None:
34 | up = [uprelu, upsample, upconv]
35 | else:
36 | up = [uprelu, upsample, upconv, upnorm]
37 | model = down + up
38 | else:
39 | upsample = nn.Upsample(scale_factor=2, mode='bilinear')
40 | upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
41 | if norm_layer == None:
42 | down = [downrelu, downconv]
43 | up = [uprelu, upsample, upconv]
44 | else:
45 | down = [downrelu, downconv, downnorm]
46 | up = [uprelu, upsample, upconv, upnorm]
47 |
48 | if use_dropout:
49 | model = down + [submodule] + up + [nn.Dropout(0.5)]
50 | else:
51 | model = down + [submodule] + up
52 |
53 | self.model = nn.Sequential(*model)
54 |
55 | def forward(self, x):
56 | if self.outermost:
57 | return self.model(x)
58 | else:
59 | return torch.cat([x, self.model(x)], 1)
60 |
61 | class ResidualBlock(nn.Module):
62 | def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d):
63 | super(ResidualBlock, self).__init__()
64 | self.relu = nn.ReLU(True)
65 | if norm_layer == None:
66 | self.block = nn.Sequential(
67 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
68 | nn.ReLU(inplace=True),
69 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
70 | )
71 | else:
72 | self.block = nn.Sequential(
73 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
74 | norm_layer(in_features),
75 | nn.ReLU(inplace=True),
76 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
77 | norm_layer(in_features)
78 | )
79 |
80 | def forward(self, x):
81 | residual = x
82 | out = self.block(x)
83 | out += residual
84 | out = self.relu(out)
85 | return out
86 |
87 | class ResUnetGenerator(nn.Module):
88 | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
89 | norm_layer=nn.BatchNorm2d, use_dropout=False):
90 | super(ResUnetGenerator, self).__init__()
91 | unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
92 |
93 | for i in range(num_downs - 5):
94 | unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
95 | unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
96 | unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
97 | unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
98 | unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
99 |
100 | self.model = unet_block
101 |
102 | def forward(self, input):
103 | return self.model(input)
104 |
105 |
106 | class ResUnetSkipConnectionBlock(nn.Module):
107 | def __init__(self, outer_nc, inner_nc, input_nc=None,
108 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
109 | super(ResUnetSkipConnectionBlock, self).__init__()
110 | self.outermost = outermost
111 | use_bias = norm_layer == nn.InstanceNorm2d
112 |
113 | if input_nc is None:
114 | input_nc = outer_nc
115 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3,
116 | stride=2, padding=1, bias=use_bias)
117 |
118 | res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)]
119 | res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)]
120 |
121 | downrelu = nn.ReLU(True)
122 | uprelu = nn.ReLU(True)
123 | if norm_layer != None:
124 | downnorm = norm_layer(inner_nc)
125 | upnorm = norm_layer(outer_nc)
126 |
127 | if outermost:
128 | upsample = nn.Upsample(scale_factor=2, mode='nearest')
129 | upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
130 | down = [downconv, downrelu] + res_downconv
131 | up = [upsample, upconv]
132 | model = down + [submodule] + up
133 | elif innermost:
134 | upsample = nn.Upsample(scale_factor=2, mode='nearest')
135 | upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
136 | down = [downconv, downrelu] + res_downconv
137 | if norm_layer == None:
138 | up = [upsample, upconv, uprelu] + res_upconv
139 | else:
140 | up = [upsample, upconv, upnorm, uprelu] + res_upconv
141 | model = down + up
142 | else:
143 | upsample = nn.Upsample(scale_factor=2, mode='nearest')
144 | upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
145 | if norm_layer == None:
146 | down = [downconv, downrelu] + res_downconv
147 | up = [upsample, upconv, uprelu] + res_upconv
148 | else:
149 | down = [downconv, downnorm, downrelu] + res_downconv
150 | up = [upsample, upconv, upnorm, uprelu] + res_upconv
151 |
152 | if use_dropout:
153 | model = down + [submodule] + up + [nn.Dropout(0.5)]
154 | else:
155 | model = down + [submodule] + up
156 |
157 | self.model = nn.Sequential(*model)
158 |
159 | def forward(self, x):
160 | if self.outermost:
161 | return self.model(x)
162 | else:
163 | return torch.cat([x, self.model(x)], 1)
164 |
165 |
166 | def save_checkpoint(model, save_path):
167 | if not os.path.exists(os.path.dirname(save_path)):
168 | os.makedirs(os.path.dirname(save_path))
169 | torch.save(model.state_dict(), save_path)
170 |
171 |
172 | def load_checkpoint(model, checkpoint_path):
173 |
174 | if not os.path.exists(checkpoint_path):
175 | print('No checkpoint!')
176 | return
177 |
178 | checkpoint = torch.load(checkpoint_path)
179 | checkpoint_new = model.state_dict()
180 | for param in checkpoint_new:
181 | checkpoint_new[param] = checkpoint[param]
182 |
183 | model.load_state_dict(checkpoint_new)
184 |
185 |
186 |
187 |
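
For reference, a minimal usage sketch (not part of the repository) showing how test.py drives these helpers: the generator is built with the same channel arguments as in test.py (7 input channels = person image + warped cloth + warped edge; 4 outputs split into a 3-channel rendering and a 1-channel mask), and the weights are restored with load_checkpoint. The checkpoint path is the default from test_options.py and is assumed to have been downloaded; a CUDA-capable PyTorch install is assumed, as in test.py.

    import torch.nn as nn
    from models.networks import ResUnetGenerator, load_checkpoint

    gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
    load_checkpoint(gen_model, 'checkpoints/PFAFN/gen_model_final.pth')  # prints 'No checkpoint!' if the file is missing
    gen_model.eval()
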
--------------------------------------------------------------------------------
/PF-AFN_test/options/__init__.py:
--------------------------------------------------------------------------------
1 | # options_init
--------------------------------------------------------------------------------
/PF-AFN_test/options/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/options/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/options/__pycache__/base_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/base_options.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/options/__pycache__/base_options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/base_options.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/options/__pycache__/test_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/test_options.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/options/__pycache__/test_options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/options/__pycache__/test_options.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | class BaseOptions():
5 | def __init__(self):
6 | self.parser = argparse.ArgumentParser()
7 | self.initialized = False
8 |
9 | def initialize(self):
10 | self.parser.add_argument('--name', type=str, default='demo', help='name of the experiment. It decides where to store samples and models')
11 |         self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
12 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
13 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
14 | self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
15 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')
16 |
17 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
18 | self.parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size')
19 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
20 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
21 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
22 |
23 | self.parser.add_argument('--dataroot', type=str,
24 | default='dataset/')
25 | self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
26 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
27 |         self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
28 | self.parser.add_argument('--nThreads', default=1, type=int, help='# threads for loading data')
29 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
30 |
31 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
32 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
33 |
34 | self.initialized = True
35 |
36 | def parse(self, save=True):
37 | if not self.initialized:
38 | self.initialize()
39 | self.opt = self.parser.parse_args()
40 | self.opt.isTrain = self.isTrain # train or test
41 |
42 | str_ids = self.opt.gpu_ids.split(',')
43 | self.opt.gpu_ids = []
44 | for str_id in str_ids:
45 | id = int(str_id)
46 | if id >= 0:
47 | self.opt.gpu_ids.append(id)
48 |
49 | if len(self.opt.gpu_ids) > 0:
50 | torch.cuda.set_device(self.opt.gpu_ids[0])
51 |
52 | args = vars(self.opt)
53 |
54 | print('------------ Options -------------')
55 | for k, v in sorted(args.items()):
56 | print('%s: %s' % (str(k), str(v)))
57 | print('-------------- End ----------------')
58 |
59 | return self.opt
60 |
--------------------------------------------------------------------------------
/PF-AFN_test/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TestOptions(BaseOptions):
4 | def initialize(self):
5 | BaseOptions.initialize(self)
6 |
7 | self.parser.add_argument('--warp_checkpoint', type=str, default='checkpoints/PFAFN/warp_model_final.pth', help='load the pretrained model from the specified location')
8 | self.parser.add_argument('--gen_checkpoint', type=str, default='checkpoints/PFAFN/gen_model_final.pth', help='load the pretrained model from the specified location')
9 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
10 |
11 | self.isTrain = False
12 |
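
The options above are read from the command line; below is a small sketch (not from the repository) of driving them programmatically, mirroring the flags in test.sh. Since BaseOptions.parse() calls torch.cuda.set_device whenever gpu_ids is non-empty, a CUDA build of PyTorch is assumed here.

    import sys
    from options.test_options import TestOptions

    # mirror test.sh: python test.py --name demo --resize_or_crop None --batchSize 1 --gpu_ids 0
    sys.argv = ['test.py', '--name', 'demo', '--resize_or_crop', 'None', '--batchSize', '1', '--gpu_ids', '0']
    opt = TestOptions().parse()                      # prints the option table; opt.isTrain is False
    print(opt.warp_checkpoint, opt.gen_checkpoint)   # default checkpoint locations
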
--------------------------------------------------------------------------------
/PF-AFN_test/results/demo/PFAFN/0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/0.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/results/demo/PFAFN/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/1.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/results/demo/PFAFN/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/2.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/results/demo/PFAFN/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/3.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/results/demo/PFAFN/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/4.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/results/demo/PFAFN/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/results/demo/PFAFN/5.jpg
--------------------------------------------------------------------------------
/PF-AFN_test/test.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.test_options import TestOptions
3 | from data.data_loader_test import CreateDataLoader
4 | from models.networks import ResUnetGenerator, load_checkpoint
5 | from models.afwm import AFWM
6 | import torch.nn as nn
7 | import os
8 | import numpy as np
9 | import torch
10 | import cv2
11 | import torch.nn.functional as F
12 |
13 | opt = TestOptions().parse()
14 |
15 | start_epoch, epoch_iter = 1, 0
16 |
17 | data_loader = CreateDataLoader(opt)
18 | dataset = data_loader.load_data()
19 | dataset_size = len(data_loader)
20 | print(dataset_size)
21 |
22 | warp_model = AFWM(opt, 3)
23 | print(warp_model)
24 | warp_model.eval()
25 | warp_model.cuda()
26 | load_checkpoint(warp_model, opt.warp_checkpoint)
27 |
28 | gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
29 | print(gen_model)
30 | gen_model.eval()
31 | gen_model.cuda()
32 | load_checkpoint(gen_model, opt.gen_checkpoint)
33 |
34 | total_steps = (start_epoch-1) * dataset_size + epoch_iter
35 | step = 0
36 | step_per_batch = dataset_size / opt.batchSize
37 |
38 | for epoch in range(1,2):
39 |
40 | for i, data in enumerate(dataset, start=epoch_iter):
41 | iter_start_time = time.time()
42 | total_steps += opt.batchSize
43 | epoch_iter += opt.batchSize
44 |
45 | real_image = data['image']
46 | clothes = data['clothes']
47 |         # the edge map is a binary cloth mask pre-extracted from the clothes image
48 |         edge = data['edge']
49 |         edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(int))  # np.int was removed in NumPy 1.24; plain int is equivalent
50 | clothes = clothes * edge
51 |
52 | flow_out = warp_model(real_image.cuda(), clothes.cuda())
53 |         warped_cloth, last_flow = flow_out
54 | warped_edge = F.grid_sample(edge.cuda(), last_flow.permute(0, 2, 3, 1),
55 | mode='bilinear', padding_mode='zeros')
56 |
57 | gen_inputs = torch.cat([real_image.cuda(), warped_cloth, warped_edge], 1)
58 | gen_outputs = gen_model(gen_inputs)
59 | p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1)
60 | p_rendered = torch.tanh(p_rendered)
61 | m_composite = torch.sigmoid(m_composite)
62 | m_composite = m_composite * warped_edge
63 | p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
64 |
65 | path = 'results/' + opt.name
66 | os.makedirs(path, exist_ok=True)
67 | sub_path = path + '/PFAFN'
68 |         os.makedirs(sub_path, exist_ok=True)
69 |
70 | if step % 1 == 0:
71 |             a = real_image.float().cuda()
72 |             b = clothes.cuda()
73 |             c = p_tryon
74 |             combine = torch.cat([a[0], b[0], c[0]], 2).squeeze()
75 |             cv_img = (combine.permute(1, 2, 0).detach().cpu().numpy() + 1) / 2
76 |             rgb = (cv_img * 255).astype(np.uint8)
77 |             bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
78 |             cv2.imwrite(sub_path + '/' + str(step) + '.jpg', bgr)
79 |
80 | step += 1
81 | if epoch_iter >= dataset_size:
82 | break
83 |
84 |
85 |
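
The key step above is the mask-guided blend: m_composite selects warped-cloth pixels and the rendered person fills the rest. A self-contained illustration with dummy tensors (CPU only, values purely illustrative):

    import torch

    warped_cloth = torch.rand(1, 3, 256, 192) * 2 - 1   # tanh-range image, as in the pipeline
    p_rendered = torch.rand(1, 3, 256, 192) * 2 - 1
    m_composite = torch.rand(1, 1, 256, 192)             # sigmoid-range mask

    p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
    print(p_tryon.shape)                                  # torch.Size([1, 3, 256, 192])
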
--------------------------------------------------------------------------------
/PF-AFN_test/test.sh:
--------------------------------------------------------------------------------
1 | python test.py --name demo --resize_or_crop None --batchSize 1 --gpu_ids 0
2 |
--------------------------------------------------------------------------------
/PF-AFN_test/util/__init__.py:
--------------------------------------------------------------------------------
1 | # util_init
--------------------------------------------------------------------------------
/PF-AFN_test/util/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/util/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/util/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/util/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/util/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/util/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/util/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_test/util/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_test/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.autograd import Variable
4 | class ImagePool():
5 | def __init__(self, pool_size):
6 | self.pool_size = pool_size
7 | if self.pool_size > 0:
8 | self.num_imgs = 0
9 | self.images = []
10 |
11 | def query(self, images):
12 | if self.pool_size == 0:
13 | return images
14 | return_images = []
15 | for image in images.data:
16 | image = torch.unsqueeze(image, 0)
17 | if self.num_imgs < self.pool_size:
18 | self.num_imgs = self.num_imgs + 1
19 | self.images.append(image)
20 | return_images.append(image)
21 | else:
22 | p = random.uniform(0, 1)
23 | if p > 0.5:
24 | random_id = random.randint(0, self.pool_size-1)
25 | tmp = self.images[random_id].clone()
26 | self.images[random_id] = image
27 | return_images.append(tmp)
28 | else:
29 | return_images.append(image)
30 | return_images = Variable(torch.cat(return_images, 0))
31 | return return_images
32 |
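
ImagePool is the standard pix2pix-style history buffer: query() stores generated images and, once the buffer is full, returns an older sample instead of the current one with probability 0.5, which stabilizes discriminator training. It is not used by test.py; presumably the training code consumes it. A minimal sketch (not from the repository):

    import torch
    from util.image_pool import ImagePool

    pool = ImagePool(pool_size=50)
    fake = torch.randn(4, 3, 256, 192)   # a batch of generated images
    mixed = pool.query(fake)             # same shape; some samples may be replayed from the buffer
    print(mixed.shape)                   # torch.Size([4, 3, 256, 192])
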
--------------------------------------------------------------------------------
/PF-AFN_test/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | from PIL import Image
5 | import numpy as np
6 | import os
7 |
8 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
9 | if isinstance(image_tensor, list):
10 | image_numpy = []
11 | for i in range(len(image_tensor)):
12 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
13 | return image_numpy
14 | image_numpy = image_tensor.cpu().float().numpy()
15 |
16 | image_numpy = (image_numpy + 1) / 2.0
17 | image_numpy = np.clip(image_numpy, 0, 1)
18 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
19 | image_numpy = image_numpy[:,:,0]
20 |
21 | return image_numpy
22 |
23 | def tensor2label(label_tensor, n_label, imtype=np.uint8):
24 | if n_label == 0:
25 | return tensor2im(label_tensor, imtype)
26 | label_tensor = label_tensor.cpu().float()
27 | if label_tensor.size()[0] > 1:
28 | label_tensor = label_tensor.max(0, keepdim=True)[1]
29 | label_tensor = Colorize(n_label)(label_tensor)
30 | label_numpy = label_tensor.numpy()
31 | label_numpy = label_numpy / 255.0
32 |
33 | return label_numpy
34 |
35 | def save_image(image_numpy, image_path):
36 | image_pil = Image.fromarray(image_numpy)
37 | image_pil.save(image_path)
38 |
39 | def mkdirs(paths):
40 | if isinstance(paths, list) and not isinstance(paths, str):
41 | for path in paths:
42 | mkdir(path)
43 | else:
44 | mkdir(paths)
45 |
46 | def mkdir(path):
47 | if not os.path.exists(path):
48 | os.makedirs(path)
49 |
50 |
51 | def uint82bin(n, count=8):
52 | """returns the binary of integer n, count refers to amount of bits"""
53 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
54 |
55 | def labelcolormap(N):
56 | if N == 35: # cityscape
57 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
58 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
59 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
60 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
61 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
62 | dtype=np.uint8)
63 | else:
64 | cmap = np.zeros((N, 3), dtype=np.uint8)
65 | for i in range(N):
66 | r, g, b = 0, 0, 0
67 | id = i
68 | for j in range(7):
69 | str_id = uint82bin(id)
70 | r = r ^ (np.uint8(str_id[-1]) << (7-j))
71 | g = g ^ (np.uint8(str_id[-2]) << (7-j))
72 | b = b ^ (np.uint8(str_id[-3]) << (7-j))
73 | id = id >> 3
74 | cmap[i, 0] = r
75 | cmap[i, 1] = g
76 | cmap[i, 2] = b
77 | return cmap
78 |
79 | class Colorize(object):
80 | def __init__(self, n=35):
81 | self.cmap = labelcolormap(n)
82 | self.cmap = torch.from_numpy(self.cmap[:n])
83 |
84 | def __call__(self, gray_image):
85 | size = gray_image.size()
86 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
87 |
88 | for label in range(0, len(self.cmap)):
89 | mask = (label == gray_image[0]).cpu()
90 | color_image[0][mask] = self.cmap[label][0]
91 | color_image[1][mask] = self.cmap[label][1]
92 | color_image[2][mask] = self.cmap[label][2]
93 |
94 | return color_image
95 |
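
A small sketch (not from the repository) of the label-colorization path: tensor2label expects a (1, H, W) or (C, H, W) label tensor and returns a float RGB array in [0, 1] with shape (3, H, W). The 20-class map below is a dummy value for illustration.

    import numpy as np
    import torch
    from util.util import tensor2label, save_image

    label = torch.randint(0, 20, (1, 256, 192))    # dummy segmentation map
    color = tensor2label(label, n_label=20)        # (3, 256, 192), float in [0, 1]
    save_image((np.transpose(color, (1, 2, 0)) * 255).astype(np.uint8), 'label_vis.png')
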
--------------------------------------------------------------------------------
/PF-AFN_train/checkpoints/readme.txt:
--------------------------------------------------------------------------------
1 | The checkpoints will be saved here.
2 |
--------------------------------------------------------------------------------
/PF-AFN_train/data/__init__.py:
--------------------------------------------------------------------------------
1 | # data_init
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/aligned_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/aligned_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/aligned_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/aligned_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/aligned_dataset_fake.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/aligned_dataset_fake.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/aligned_dataset_test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/aligned_dataset_test.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/base_data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/base_data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/base_dataset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/base_dataset.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/base_dataset.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/base_dataset.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/custom_dataset_data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/custom_dataset_data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/custom_dataset_data_loader_test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/custom_dataset_data_loader_test.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/data_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/data_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/data_loader_test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/data_loader_test.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/image_folder.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/image_folder.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/__pycache__/image_folder.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/data/__pycache__/image_folder.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_params, get_transform
3 | from data.image_folder import make_dataset
4 | from PIL import Image
5 | import torch
6 | import json
7 | import numpy as np
8 | import os.path as osp
9 | from PIL import ImageDraw
10 |
11 |
12 | class AlignedDataset(BaseDataset):
13 | def initialize(self, opt):
14 | self.opt = opt
15 | self.root = opt.dataroot
16 | self.diction={}
17 |
18 | if opt.isTrain or opt.use_encoded_image:
19 | dir_A = '_A' if self.opt.label_nc == 0 else '_label'
20 | self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
21 | self.A_paths = sorted(make_dataset(self.dir_A))
22 |
23 | self.fine_height=256
24 | self.fine_width=192
25 | self.radius=5
26 |
27 | dir_B = '_B' if self.opt.label_nc == 0 else '_img'
28 | self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
29 | self.B_paths = sorted(make_dataset(self.dir_B))
30 |
31 | self.dataset_size = len(self.A_paths)
32 |
33 | if opt.isTrain or opt.use_encoded_image:
34 | dir_E = '_edge'
35 | self.dir_E = os.path.join(opt.dataroot, opt.phase + dir_E)
36 | self.E_paths = sorted(make_dataset(self.dir_E))
37 |
38 | if opt.isTrain or opt.use_encoded_image:
39 | dir_C = '_color'
40 | self.dir_C = os.path.join(opt.dataroot, opt.phase + dir_C)
41 | self.C_paths = sorted(make_dataset(self.dir_C))
42 |
43 |
44 | def __getitem__(self, index):
45 |
46 | A_path = self.A_paths[index]
47 | A = Image.open(A_path).convert('L')
48 |
49 | params = get_params(self.opt, A.size)
50 | if self.opt.label_nc == 0:
51 | transform_A = get_transform(self.opt, params)
52 | A_tensor = transform_A(A.convert('RGB'))
53 | else:
54 | transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
55 | A_tensor = transform_A(A) * 255.0
56 |
57 | B_path = self.B_paths[index]
58 | B = Image.open(B_path).convert('RGB')
59 | transform_B = get_transform(self.opt, params)
60 | B_tensor = transform_B(B)
61 |
62 | C_path = self.C_paths[index]
63 | C = Image.open(C_path).convert('RGB')
64 | C_tensor = transform_B(C)
65 |
66 | E_path = self.E_paths[index]
67 | E = Image.open(E_path).convert('L')
68 | E_tensor = transform_A(E)
69 |
70 | index_un = np.random.randint(14221)
71 | C_un_path = self.C_paths[index_un]
72 | C_un = Image.open(C_un_path).convert('RGB')
73 | C_un_tensor = transform_B(C_un)
74 |
75 | E_un_path = self.E_paths[index_un]
76 | E_un = Image.open(E_un_path).convert('L')
77 | E_un_tensor = transform_A(E_un)
78 |
79 | pose_name =B_path.replace('.png', '_keypoints.json').replace('.jpg','_keypoints.json').replace('train_img','train_pose')
80 | with open(osp.join(pose_name), 'r') as f:
81 | pose_label = json.load(f)
82 | try:
83 | pose_data = pose_label['people'][0]['pose_keypoints']
84 | except IndexError:
85 | pose_data = [0 for i in range(54)]
86 | pose_data = np.array(pose_data)
87 | pose_data = pose_data.reshape((-1,3))
88 |
89 | point_num = pose_data.shape[0]
90 | pose_map = torch.zeros(point_num, self.fine_height, self.fine_width)
91 | r = self.radius
92 | im_pose = Image.new('L', (self.fine_width, self.fine_height))
93 | pose_draw = ImageDraw.Draw(im_pose)
94 | for i in range(point_num):
95 | one_map = Image.new('L', (self.fine_width, self.fine_height))
96 | draw = ImageDraw.Draw(one_map)
97 | pointx = pose_data[i,0]
98 | pointy = pose_data[i,1]
99 | if pointx > 1 and pointy > 1:
100 | draw.rectangle((pointx-r, pointy-r, pointx+r, pointy+r), 'white', 'white')
101 | pose_draw.rectangle((pointx-r, pointy-r, pointx+r, pointy+r), 'white', 'white')
102 | one_map = transform_B(one_map.convert('RGB'))
103 | pose_map[i] = one_map[0]
104 | P_tensor=pose_map
105 |
106 | densepose_name = B_path.replace('.png', '.npy').replace('.jpg','.npy').replace('train_img','train_densepose')
107 | dense_mask = np.load(densepose_name).astype(np.float32)
108 | dense_mask = transform_A(dense_mask)
109 |
110 | if self.opt.isTrain:
111 | input_dict = { 'label': A_tensor, 'image': B_tensor, 'path': A_path, 'img_path': B_path ,'color_path': C_path,'color_un_path': C_un_path,
112 | 'edge': E_tensor, 'color': C_tensor, 'edge_un': E_un_tensor, 'color_un': C_un_tensor, 'pose':P_tensor, 'densepose':dense_mask
113 | }
114 |
115 | return input_dict
116 |
117 | def __len__(self):
118 | return len(self.A_paths) // (self.opt.batchSize * self.opt.num_gpus) * (self.opt.batchSize * self.opt.num_gpus)
119 |
120 | def name(self):
121 | return 'AlignedDataset'
122 |
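
From the path handling above, the loader expects the following layout under opt.dataroot for phase 'train' (directory names are inferred from the code; the prepared VITON data referenced in the project README is assumed to match it):

    dataset/
    ├── train_label/       # person parsing maps (A_paths; train_A when label_nc == 0)
    ├── train_img/         # person images (B_paths)
    ├── train_color/       # in-shop clothes images (C_paths)
    ├── train_edge/        # clothes masks (E_paths)
    ├── train_pose/        # <image>_keypoints.json OpenPose files
    └── train_densepose/   # <image>.npy DensePose arrays
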
--------------------------------------------------------------------------------
/PF-AFN_train/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 | class BaseDataLoader():
2 | def __init__(self):
3 | pass
4 |
5 | def initialize(self, opt):
6 | self.opt = opt
7 | pass
8 |
9 |     def load_data(self):
10 | return None
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/PF-AFN_train/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import random
6 |
7 | class BaseDataset(data.Dataset):
8 | def __init__(self):
9 | super(BaseDataset, self).__init__()
10 |
11 | def name(self):
12 | return 'BaseDataset'
13 |
14 | def initialize(self, opt):
15 | pass
16 |
17 | def get_params(opt, size):
18 | w, h = size
19 | new_h = h
20 | new_w = w
21 | if opt.resize_or_crop == 'resize_and_crop':
22 | new_h = new_w = opt.loadSize
23 | elif opt.resize_or_crop == 'scale_width_and_crop':
24 | new_w = opt.loadSize
25 | new_h = opt.loadSize * h // w
26 |
27 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
28 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
29 |
30 | #flip = random.random() > 0.5
31 | flip = 0
32 | return {'crop_pos': (x, y), 'flip': flip}
33 |
34 | def get_transform_resize(opt, params, method=Image.BICUBIC, normalize=True):
35 | transform_list = []
36 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
37 | osize = [256,192]
38 | transform_list.append(transforms.Scale(osize, method))
39 | if 'crop' in opt.resize_or_crop:
40 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
41 |
42 | if opt.resize_or_crop == 'none':
43 | base = float(2 ** opt.n_downsample_global)
44 | if opt.netG == 'local':
45 | base *= (2 ** opt.n_local_enhancers)
46 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
47 |
48 | if opt.isTrain and not opt.no_flip:
49 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
50 |
51 | transform_list += [transforms.ToTensor()]
52 |
53 | if normalize:
54 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
55 | (0.5, 0.5, 0.5))]
56 | return transforms.Compose(transform_list)
57 |
58 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
59 | transform_list = []
60 | if 'resize' in opt.resize_or_crop:
61 | osize = [opt.loadSize, opt.loadSize]
62 | transform_list.append(transforms.Scale(osize, method))
63 | elif 'scale_width' in opt.resize_or_crop:
64 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
65 | osize = [256,192]
66 | transform_list.append(transforms.Scale(osize, method))
67 | if 'crop' in opt.resize_or_crop:
68 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
69 |
70 | if opt.resize_or_crop == 'none':
71 | base = float(2 ** opt.n_downsample_global)
72 | if opt.netG == 'local':
73 | base *= (2 ** opt.n_local_enhancers)
74 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
75 |
76 | if opt.isTrain and not opt.no_flip:
77 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
78 |
79 | transform_list += [transforms.ToTensor()]
80 |
81 | if normalize:
82 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
83 | (0.5, 0.5, 0.5))]
84 | return transforms.Compose(transform_list)
85 |
86 | def normalize():
87 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
88 |
89 | def __make_power_2(img, base, method=Image.BICUBIC):
90 | ow, oh = img.size
91 | h = int(round(oh / base) * base)
92 | w = int(round(ow / base) * base)
93 | if (h == oh) and (w == ow):
94 | return img
95 | return img.resize((w, h), method)
96 |
97 | def __scale_width(img, target_width, method=Image.BICUBIC):
98 | ow, oh = img.size
99 | if (ow == target_width):
100 | return img
101 | w = target_width
102 | h = int(target_width * oh / ow)
103 | return img.resize((w, h), method)
104 |
105 | def __crop(img, pos, size):
106 | ow, oh = img.size
107 | x1, y1 = pos
108 | tw = th = size
109 | if (ow > tw or oh > th):
110 | return img.crop((x1, y1, x1 + tw, y1 + th))
111 | return img
112 |
113 | def __flip(img, flip):
114 | if flip:
115 | return img.transpose(Image.FLIP_LEFT_RIGHT)
116 | return img
117 |
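
A minimal sketch (not from the repository) of how these transform helpers are used. Note that transforms.Scale, used above, was long deprecated in favour of transforms.Resize and has been removed from recent torchvision releases, so an older torchvision is assumed; the option values below are illustrative only (real runs take them from the options modules).

    from types import SimpleNamespace
    from PIL import Image
    from data.base_dataset import get_params, get_transform

    opt = SimpleNamespace(resize_or_crop='resize_and_crop', loadSize=512, fineSize=512,
                          isTrain=True, no_flip=True)
    img = Image.new('RGB', (192, 256))
    params = get_params(opt, img.size)        # random crop position and (disabled) flip flag
    tensor = get_transform(opt, params)(img)  # 3 x 512 x 512, normalized to [-1, 1]
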
--------------------------------------------------------------------------------
/PF-AFN_train/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 |
4 |
5 | def CreateDataset(opt):
6 | dataset = None
7 | from data.aligned_dataset import AlignedDataset
8 | dataset = AlignedDataset()
9 |
10 | print("dataset [%s] was created" % (dataset.name()))
11 | dataset.initialize(opt)
12 | return dataset
13 |
14 | class CustomDatasetDataLoader(BaseDataLoader):
15 | def name(self):
16 | return 'CustomDatasetDataLoader'
17 |
18 | def initialize(self, opt):
19 | BaseDataLoader.initialize(self, opt)
20 | self.dataset = CreateDataset(opt)
21 | self.dataloader = torch.utils.data.DataLoader(
22 | self.dataset,
23 | batch_size=opt.batchSize,
24 | shuffle=not opt.serial_batches,
25 | num_workers=int(opt.nThreads))
26 |
27 | def load_data(self):
28 | return self.dataloader
29 |
30 | def __len__(self):
31 | return min(len(self.dataset), self.opt.max_dataset_size)
32 |
--------------------------------------------------------------------------------
/PF-AFN_train/data/data_loader.py:
--------------------------------------------------------------------------------
1 | def CreateDataLoader(opt):
2 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
3 | data_loader = CustomDatasetDataLoader()
4 | print(data_loader.name())
5 | data_loader.initialize(opt)
6 | return data_loader
7 |
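
CreateDataLoader is the single entry point the training script goes through: it builds AlignedDataset and wraps it in a torch DataLoader. A minimal sketch (not from the repository), assuming the training options module and a prepared VITON dataset under dataset/:

    from options.train_options import TrainOptions
    from data.data_loader import CreateDataLoader

    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)   # prints 'CustomDatasetDataLoader' and 'dataset [AlignedDataset] was created'
    dataset = data_loader.load_data()
    print(len(data_loader), 'training pairs')
    for data in dataset:
        person, clothes, edge = data['image'], data['color'], data['edge']
        break
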
--------------------------------------------------------------------------------
/PF-AFN_train/data/image_folder.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import os
4 |
5 | IMG_EXTENSIONS = [
6 | '.jpg', '.JPG', '.jpeg', '.JPEG',
7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
8 | ]
9 |
10 |
11 | def is_image_file(filename):
12 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
13 |
14 | def make_dataset(dir):
15 | images = []
16 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
17 |
18 | f = dir.split('/')[-1].split('_')[-1]
19 | print (dir, f)
20 | dirs= os.listdir(dir)
21 | for img in dirs:
22 |
23 | path = os.path.join(dir, img)
24 | #print(path)
25 | images.append(path)
26 | return images
27 |
28 | def make_dataset_test(dir):
29 | images = []
30 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
31 |
32 | f = dir.split('/')[-1].split('_')[-1]
33 | for i in range(len([name for name in os.listdir(dir) if os.path.isfile(os.path.join(dir, name))])):
34 | if f == 'label' or f == 'labelref':
35 | img = str(i) + '.png'
36 | else:
37 | img = str(i) + '.jpg'
38 | path = os.path.join(dir, img)
39 | #print(path)
40 | images.append(path)
41 | return images
42 |
43 | def default_loader(path):
44 | return Image.open(path).convert('RGB')
45 |
46 |
47 | class ImageFolder(data.Dataset):
48 |
49 | def __init__(self, root, transform=None, return_paths=False,
50 | loader=default_loader):
51 | imgs = make_dataset(root)
52 | if len(imgs) == 0:
53 | raise(RuntimeError("Found 0 images in: " + root + "\n"
54 | "Supported image extensions are: " +
55 | ",".join(IMG_EXTENSIONS)))
56 |
57 | self.root = root
58 | self.imgs = imgs
59 | self.transform = transform
60 | self.return_paths = return_paths
61 | self.loader = loader
62 |
63 | def __getitem__(self, index):
64 | path = self.imgs[index]
65 | img = self.loader(path)
66 | if self.transform is not None:
67 | img = self.transform(img)
68 | if self.return_paths:
69 | return img, path
70 | else:
71 | return img
72 |
73 | def __len__(self):
74 | return len(self.imgs)
75 |
--------------------------------------------------------------------------------
/PF-AFN_train/dataset/readme.txt:
--------------------------------------------------------------------------------
1 | The downloaded dataset should be put here.
2 |
--------------------------------------------------------------------------------
/PF-AFN_train/models/__init__.py:
--------------------------------------------------------------------------------
1 | # model_init
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/afwm.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/afwm.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/base_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/base_model.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_add.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_add.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_add.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_add.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_add.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_add.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_feat.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_feat.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_grid_offset_sep.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_grid_offset_sep.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_grid_sep.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_grid_sep.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_new.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_new.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_offset.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_offset.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_heatmap.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_heatmap.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_no_refine.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_no_refine.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_trans.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_revise_new_sep_all_more_trans.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_sep.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_more_sep.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_cor_sep.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_cor_sep.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_smooth.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_smooth.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/flow_gmm_vis.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/flow_gmm_vis.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/models.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/models.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/networks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/networks.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/networks.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/networks.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/networks_flow.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/networks_flow.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/pix2pixHD_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/pix2pixHD_model.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/__pycache__/predict_mask.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/__pycache__/predict_mask.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/afwm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from options.train_options import TrainOptions
6 | from .correlation import correlation # the custom cost volume layer
7 | opt = TrainOptions().parse()
8 |
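# Note: apply_offset() converts a dense flow field of shape (N, 2, H, W) into the
# (N, H, W, 2) sampling grid expected by F.grid_sample: a pixel-coordinate meshgrid is
# shifted by the predicted offsets (channel 0 = x displacement, channel 1 = y displacement,
# hence the reversed() calls) and then rescaled to grid_sample's normalized [-1, 1] range.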
9 | def apply_offset(offset):
10 | sizes = list(offset.size()[2:])
11 | grid_list = torch.meshgrid([torch.arange(size, device=offset.device) for size in sizes])
12 | grid_list = reversed(grid_list)
13 | # apply offset
14 | grid_list = [grid.float().unsqueeze(0) + offset[:, dim, ...]
15 | for dim, grid in enumerate(grid_list)]
16 | # normalize
17 | grid_list = [grid / ((size - 1.0) / 2.0) - 1.0
18 | for grid, size in zip(grid_list, reversed(sizes))]
19 |
20 | return torch.stack(grid_list, dim=-1)
21 |
22 |
23 | def TVLoss(x):
24 | tv_h = x[:, :, 1:, :] - x[:, :, :-1, :]
25 | tv_w = x[:, :, :, 1:] - x[:, :, :, :-1]
26 |
27 | return torch.mean(torch.abs(tv_h)) + torch.mean(torch.abs(tv_w))
28 |
29 |
30 | # backbone
31 | class ResBlock(nn.Module):
32 | def __init__(self, in_channels):
33 | super(ResBlock, self).__init__()
34 | self.block = nn.Sequential(
35 | nn.BatchNorm2d(in_channels),
36 | nn.ReLU(inplace=True),
37 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False),
38 | nn.BatchNorm2d(in_channels),
39 | nn.ReLU(inplace=True),
40 | nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
41 | )
42 |
43 | def forward(self, x):
44 | return self.block(x) + x
45 |
46 |
47 | class DownSample(nn.Module):
48 | def __init__(self, in_channels, out_channels):
49 | super(DownSample, self).__init__()
50 | self.block= nn.Sequential(
51 | nn.BatchNorm2d(in_channels),
52 | nn.ReLU(inplace=True),
53 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)
54 | )
55 |
56 | def forward(self, x):
57 | return self.block(x)
58 |
59 |
60 |
61 | class FeatureEncoder(nn.Module):
62 | def __init__(self, in_channels, chns=[64,128,256,256,256]):
63 |         # in_channels = 3 for images, and is larger (e.g., 17+1+1) for the agnostic representation
64 | super(FeatureEncoder, self).__init__()
65 | self.encoders = []
66 | for i, out_chns in enumerate(chns):
67 | if i == 0:
68 | encoder = nn.Sequential(DownSample(in_channels, out_chns),
69 | ResBlock(out_chns),
70 | ResBlock(out_chns))
71 | else:
72 | encoder = nn.Sequential(DownSample(chns[i-1], out_chns),
73 | ResBlock(out_chns),
74 | ResBlock(out_chns))
75 |
76 | self.encoders.append(encoder)
77 |
78 | self.encoders = nn.ModuleList(self.encoders)
79 |
80 |
81 | def forward(self, x):
82 | encoder_features = []
83 | for encoder in self.encoders:
84 | x = encoder(x)
85 | encoder_features.append(x)
86 | return encoder_features
87 |
88 | class RefinePyramid(nn.Module):
89 | def __init__(self, chns=[64,128,256,256,256], fpn_dim=256):
90 | super(RefinePyramid, self).__init__()
91 | self.chns = chns
92 |
93 | # adaptive
94 | self.adaptive = []
95 | for in_chns in list(reversed(chns)):
96 | adaptive_layer = nn.Conv2d(in_chns, fpn_dim, kernel_size=1)
97 | self.adaptive.append(adaptive_layer)
98 | self.adaptive = nn.ModuleList(self.adaptive)
99 | # output conv
100 | self.smooth = []
101 | for i in range(len(chns)):
102 | smooth_layer = nn.Conv2d(fpn_dim, fpn_dim, kernel_size=3, padding=1)
103 | self.smooth.append(smooth_layer)
104 | self.smooth = nn.ModuleList(self.smooth)
105 |
106 | def forward(self, x):
107 | conv_ftr_list = x
108 |
109 | feature_list = []
110 | last_feature = None
111 | for i, conv_ftr in enumerate(list(reversed(conv_ftr_list))):
112 | # adaptive
113 | feature = self.adaptive[i](conv_ftr)
114 | # fuse
115 | if last_feature is not None:
116 | feature = feature + F.interpolate(last_feature, scale_factor=2, mode='nearest')
117 | # smooth
118 | feature = self.smooth[i](feature)
119 | last_feature = feature
120 | feature_list.append(feature)
121 |
122 | return tuple(reversed(feature_list))
123 |
124 |
125 | class AFlowNet(nn.Module):
126 | def __init__(self, num_pyramid, fpn_dim=256):
127 | super(AFlowNet, self).__init__()
128 | self.netMain = []
129 | self.netRefine = []
130 | for i in range(num_pyramid):
131 | netMain_layer = torch.nn.Sequential(
132 |                 torch.nn.Conv2d(in_channels=49, out_channels=128, kernel_size=3, stride=1, padding=1),  # 49 = 7x7 displacement channels from the correlation cost volume
133 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
134 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
135 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
136 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
137 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
138 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
139 | )
140 |
141 | netRefine_layer = torch.nn.Sequential(
142 | torch.nn.Conv2d(2 * fpn_dim, out_channels=128, kernel_size=3, stride=1, padding=1),
143 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
144 | torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
145 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
146 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),
147 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
148 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1)
149 | )
150 | self.netMain.append(netMain_layer)
151 | self.netRefine.append(netRefine_layer)
152 |
153 | self.netMain = nn.ModuleList(self.netMain)
154 | self.netRefine = nn.ModuleList(self.netRefine)
155 |
156 |
157 | def forward(self, x, x_edge, x_warps, x_conds, warp_feature=True):
158 | last_flow = None
159 | last_flow_all = []
160 | delta_list = []
161 | x_all = []
162 | x_edge_all = []
163 | cond_fea_all = []
164 | delta_x_all = []
165 | delta_y_all = []
166 |         filter_x = [[0, 0, 0],   # fixed second-order finite-difference filters (x, y and the two diagonals)
167 | [1, -2, 1],
168 | [0, 0, 0]]
169 | filter_y = [[0, 1, 0],
170 | [0, -2, 0],
171 | [0, 1, 0]]
172 | filter_diag1 = [[1, 0, 0],
173 | [0, -2, 0],
174 | [0, 0, 1]]
175 | filter_diag2 = [[0, 0, 1],
176 | [0, -2, 0],
177 | [1, 0, 0]]
178 | weight_array = np.ones([3, 3, 1, 4])
179 | weight_array[:, :, 0, 0] = filter_x
180 | weight_array[:, :, 0, 1] = filter_y
181 | weight_array[:, :, 0, 2] = filter_diag1
182 | weight_array[:, :, 0, 3] = filter_diag2
183 |
184 | weight_array = torch.cuda.FloatTensor(weight_array).permute(3,2,0,1)
185 |         self.weight = nn.Parameter(data=weight_array, requires_grad=False)  # fixed, non-trainable difference kernels used for the flow smoothness terms
186 |
187 | for i in range(len(x_warps)):
188 | x_warp = x_warps[len(x_warps) - 1 - i]
189 | x_cond = x_conds[len(x_warps) - 1 - i]
190 | cond_fea_all.append(x_cond)
191 |
192 | if last_flow is not None and warp_feature:
193 | x_warp_after = F.grid_sample(x_warp, last_flow.detach().permute(0, 2, 3, 1),
194 | mode='bilinear', padding_mode='border')
195 | else:
196 | x_warp_after = x_warp
197 |
198 | tenCorrelation = F.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=x_warp_after, tenSecond=x_cond, intStride=1), negative_slope=0.1, inplace=False)
199 | flow = self.netMain[i](tenCorrelation)
200 | delta_list.append(flow)
201 | flow = apply_offset(flow)
202 | if last_flow is not None:
203 | flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border')
204 | else:
205 | flow = flow.permute(0, 3, 1, 2)
206 |
207 | last_flow = flow
208 | x_warp = F.grid_sample(x_warp, flow.permute(0, 2, 3, 1),mode='bilinear', padding_mode='border')
209 | concat = torch.cat([x_warp,x_cond],1)
210 | flow = self.netRefine[i](concat)
211 | delta_list.append(flow)
212 | flow = apply_offset(flow)
213 | flow = F.grid_sample(last_flow, flow, mode='bilinear', padding_mode='border')
214 |
215 | last_flow = F.interpolate(flow, scale_factor=2, mode='bilinear')
216 | last_flow_all.append(last_flow)
217 | cur_x = F.interpolate(x, scale_factor=0.5**(len(x_warps)-1-i), mode='bilinear')
218 | cur_x_warp = F.grid_sample(cur_x, last_flow.permute(0, 2, 3, 1),mode='bilinear', padding_mode='border')
219 | x_all.append(cur_x_warp)
220 | cur_x_edge = F.interpolate(x_edge, scale_factor=0.5**(len(x_warps)-1-i), mode='bilinear')
221 | cur_x_warp_edge = F.grid_sample(cur_x_edge, last_flow.permute(0, 2, 3, 1),mode='bilinear', padding_mode='zeros')
222 | x_edge_all.append(cur_x_warp_edge)
223 | flow_x,flow_y = torch.split(last_flow,1,dim=1)
224 | delta_x = F.conv2d(flow_x, self.weight)
225 | delta_y = F.conv2d(flow_y,self.weight)
226 | delta_x_all.append(delta_x)
227 | delta_y_all.append(delta_y)
228 |
229 | x_warp = F.grid_sample(x, last_flow.permute(0, 2, 3, 1),
230 | mode='bilinear', padding_mode='border')
231 | return x_warp, last_flow, cond_fea_all, last_flow_all, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all
232 |
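  | # Summary of the cascade above, run once per pyramid level from coarsest to finest:
  | #   1. warp the clothing features with the (detached) flow composed at the previous level;
  | #   2. build a 49-channel cost volume between the warped clothing features and the person features;
  | #   3. netMain predicts a residual flow from the cost volume, which is composed with the previous flow;
  | #   4. the clothing features are warped again, concatenated with the person features, and netRefine
  | #      predicts a second residual flow that is composed on top;
  | #   5. the composed flow is upsampled by 2x, and the per-level warped clothes, warped edges and
  | #      flow gradients are collected for the multi-scale training losses.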
233 |
234 | class AFWM(nn.Module):
235 |
236 | def __init__(self, opt, input_nc):
237 | super(AFWM, self).__init__()
238 | num_filters = [64,128,256,256,256]
239 | self.image_features = FeatureEncoder(3, num_filters)
240 | self.cond_features = FeatureEncoder(input_nc, num_filters)
241 | self.image_FPN = RefinePyramid(num_filters)
242 | self.cond_FPN = RefinePyramid(num_filters)
243 | self.aflow_net = AFlowNet(len(num_filters))
244 | self.old_lr = opt.lr
245 | self.old_lr_warp = opt.lr*0.2
246 |
247 | def forward(self, cond_input, image_input, image_edge):
248 | cond_pyramids = self.cond_FPN(self.cond_features(cond_input)) # maybe use nn.Sequential
249 | image_pyramids = self.image_FPN(self.image_features(image_input))
250 |
251 |         x_warp, last_flow, cond_fea_all, last_flow_all, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all = self.aflow_net(image_input, image_edge, image_pyramids, cond_pyramids)
252 | 
253 |         return x_warp, last_flow, cond_fea_all, last_flow_all, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all
254 |
255 |
256 | def update_learning_rate(self,optimizer):
257 | lrd = opt.lr / opt.niter_decay
258 | lr = self.old_lr - lrd
259 | for param_group in optimizer.param_groups:
260 | param_group['lr'] = lr
261 | if opt.verbose:
262 | print('update learning rate: %f -> %f' % (self.old_lr, lr))
263 | self.old_lr = lr
264 |
265 | def update_learning_rate_warp(self,optimizer):
266 | lrd = 0.2 * opt.lr / opt.niter_decay
267 | lr = self.old_lr_warp - lrd
268 | for param_group in optimizer.param_groups:
269 | param_group['lr'] = lr
270 | if opt.verbose:
271 | print('update learning rate: %f -> %f' % (self.old_lr_warp, lr))
272 | self.old_lr_warp = lr
273 |
274 |
--------------------------------------------------------------------------------
/PF-AFN_train/models/correlation/README.md:
--------------------------------------------------------------------------------
1 | This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you make use of or modify this particular implementation, please acknowledge it appropriately.
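  | 
  | A minimal usage sketch, as the layer is invoked from models/afwm.py; `feat_a` and `feat_b` stand for two contiguous CUDA feature maps of identical shape (N, C, H, W):
  | 
  |     from models.correlation import correlation
  |     cost = correlation.FunctionCorrelation(tenFirst=feat_a, tenSecond=feat_b, intStride=1)
  |     # cost: (N, 49, H, W) with intStride=1 -- one channel per displacement in the 7x7 search window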
--------------------------------------------------------------------------------
/PF-AFN_train/models/correlation/__pycache__/correlation.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/correlation/__pycache__/correlation.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/correlation/__pycache__/correlation.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/models/correlation/__pycache__/correlation.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/models/correlation/correlation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import torch
4 |
5 | import cupy
6 | import math
7 | import re
8 |
9 | kernel_Correlation_rearrange = '''
10 | extern "C" __global__ void kernel_Correlation_rearrange(
11 | const int n,
12 | const float* input,
13 | float* output
14 | ) {
15 | int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
16 |
17 | if (intIndex >= n) {
18 | return;
19 | }
20 |
21 | int intSample = blockIdx.z;
22 | int intChannel = blockIdx.y;
23 |
24 | float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
25 |
26 | __syncthreads();
27 |
28 | int intPaddedY = (intIndex / SIZE_3(input)) + 3*{{intStride}};
29 | int intPaddedX = (intIndex % SIZE_3(input)) + 3*{{intStride}};
30 | int intRearrange = ((SIZE_3(input) + 6*{{intStride}}) * intPaddedY) + intPaddedX;
31 |
32 | output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
33 | }
34 | '''
35 |
36 | kernel_Correlation_updateOutput = '''
37 | extern "C" __global__ void kernel_Correlation_updateOutput(
38 | const int n,
39 | const float* rbot0,
40 | const float* rbot1,
41 | float* top
42 | ) {
43 | extern __shared__ char patch_data_char[];
44 |
45 | float *patch_data = (float *)patch_data_char;
46 |
47 | // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
48 | int x1 = (blockIdx.x + 3) * {{intStride}};
49 | int y1 = (blockIdx.y + 3) * {{intStride}};
50 | int item = blockIdx.z;
51 | int ch_off = threadIdx.x;
52 |
53 |   // Load 3D patch into shared memory
54 | for (int j = 0; j < 1; j++) { // HEIGHT
55 | for (int i = 0; i < 1; i++) { // WIDTH
56 | int ji_off = (j + i) * SIZE_3(rbot0);
57 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
58 | int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
59 | int idxPatchData = ji_off + ch;
60 | patch_data[idxPatchData] = rbot0[idx1];
61 | }
62 | }
63 | }
64 |
65 | __syncthreads();
66 |
67 | __shared__ float sum[32];
68 |
69 | // Compute correlation
70 | for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
71 | sum[ch_off] = 0;
72 |
73 | int s2o = (top_channel % 7 - 3) * {{intStride}};
74 | int s2p = (top_channel / 7 - 3) * {{intStride}};
75 |
76 | for (int j = 0; j < 1; j++) { // HEIGHT
77 | for (int i = 0; i < 1; i++) { // WIDTH
78 | int ji_off = (j + i) * SIZE_3(rbot0);
79 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
80 | int x2 = x1 + s2o;
81 | int y2 = y1 + s2p;
82 |
83 | int idxPatchData = ji_off + ch;
84 | int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
85 |
86 | sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
87 | }
88 | }
89 | }
90 |
91 | __syncthreads();
92 |
93 | if (ch_off == 0) {
94 | float total_sum = 0;
95 | for (int idx = 0; idx < 32; idx++) {
96 | total_sum += sum[idx];
97 | }
98 | const int sumelems = SIZE_3(rbot0);
99 | const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
100 | top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
101 | }
102 | }
103 | }
104 | '''
105 |
106 | kernel_Correlation_updateGradFirst = '''
107 | #define ROUND_OFF 50000
108 |
109 | extern "C" __global__ void kernel_Correlation_updateGradFirst(
110 | const int n,
111 | const int intSample,
112 | const float* rbot0,
113 | const float* rbot1,
114 | const float* gradOutput,
115 | float* gradFirst,
116 | float* gradSecond
117 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
118 | int n = intIndex % SIZE_1(gradFirst); // channels
119 | int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 3*{{intStride}}; // w-pos
120 | int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 3*{{intStride}}; // h-pos
121 |
122 | // round_off is a trick to enable integer division with ceil, even for negative numbers
123 | // We use a large offset, for the inner part not to become negative.
124 | const int round_off = ROUND_OFF;
125 | const int round_off_s1 = {{intStride}} * round_off;
126 |
127 |   // We add round_off before the int division and subtract it afterwards, to ensure the formula matches ceil behavior:
128 | int xmin = (l - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}}
129 |   int ymin = (m - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (m - 3*{{intStride}}) / {{intStride}}
130 |
131 | // Same here:
132 | int xmax = (l - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}}) / {{intStride}}
133 | int ymax = (m - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}}) / {{intStride}}
134 |
135 | float sum = 0;
136 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
137 | xmin = max(0,xmin);
138 | xmax = min(SIZE_3(gradOutput)-1,xmax);
139 |
140 | ymin = max(0,ymin);
141 | ymax = min(SIZE_2(gradOutput)-1,ymax);
142 |
143 | for (int p = -3; p <= 3; p++) {
144 | for (int o = -3; o <= 3; o++) {
145 | // Get rbot1 data:
146 | int s2o = {{intStride}} * o;
147 | int s2p = {{intStride}} * p;
148 | int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
149 | float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
150 |
151 | // Index offset for gradOutput in following loops:
152 | int op = (p+3) * 7 + (o+3); // index[o,p]
153 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
154 |
155 | for (int y = ymin; y <= ymax; y++) {
156 | for (int x = xmin; x <= xmax; x++) {
157 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
158 | sum += gradOutput[idxgradOutput] * bot1tmp;
159 | }
160 | }
161 | }
162 | }
163 | }
164 | const int sumelems = SIZE_1(gradFirst);
165 | const int bot0index = ((n * SIZE_2(gradFirst)) + (m-3*{{intStride}})) * SIZE_3(gradFirst) + (l-3*{{intStride}});
166 | gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
167 | } }
168 | '''
169 |
170 | kernel_Correlation_updateGradSecond = '''
171 | #define ROUND_OFF 50000
172 |
173 | extern "C" __global__ void kernel_Correlation_updateGradSecond(
174 | const int n,
175 | const int intSample,
176 | const float* rbot0,
177 | const float* rbot1,
178 | const float* gradOutput,
179 | float* gradFirst,
180 | float* gradSecond
181 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
182 | int n = intIndex % SIZE_1(gradSecond); // channels
183 | int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 3*{{intStride}}; // w-pos
184 | int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 3*{{intStride}}; // h-pos
185 |
186 | // round_off is a trick to enable integer division with ceil, even for negative numbers
187 | // We use a large offset, for the inner part not to become negative.
188 | const int round_off = ROUND_OFF;
189 | const int round_off_s1 = {{intStride}} * round_off;
190 |
191 | float sum = 0;
192 | for (int p = -3; p <= 3; p++) {
193 | for (int o = -3; o <= 3; o++) {
194 | int s2o = {{intStride}} * o;
195 | int s2p = {{intStride}} * p;
196 |
197 | //Get X,Y ranges and clamp
198 |       // We add round_off before the int division and subtract it afterwards, to ensure the formula matches ceil behavior:
199 | int xmin = (l - 3*{{intStride}} - s2o + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}}
200 |       int ymin = (m - 3*{{intStride}} - s2p + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (m - 3*{{intStride}} - s2p) / {{intStride}}
201 |
202 | // Same here:
203 | int xmax = (l - 3*{{intStride}} - s2o + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}} - s2o) / {{intStride}}
204 | int ymax = (m - 3*{{intStride}} - s2p + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}} - s2p) / {{intStride}}
205 |
206 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
207 | xmin = max(0,xmin);
208 | xmax = min(SIZE_3(gradOutput)-1,xmax);
209 |
210 | ymin = max(0,ymin);
211 | ymax = min(SIZE_2(gradOutput)-1,ymax);
212 |
213 | // Get rbot0 data:
214 | int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
215 |         float bot0tmp = rbot0[idxbot0]; // rbot0[l-s2o,m-s2p,n]
216 |
217 | // Index offset for gradOutput in following loops:
218 | int op = (p+3) * 7 + (o+3); // index[o,p]
219 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
220 |
221 | for (int y = ymin; y <= ymax; y++) {
222 | for (int x = xmin; x <= xmax; x++) {
223 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
224 | sum += gradOutput[idxgradOutput] * bot0tmp;
225 | }
226 | }
227 | }
228 | }
229 | }
230 | const int sumelems = SIZE_1(gradSecond);
231 | const int bot1index = ((n * SIZE_2(gradSecond)) + (m-3*{{intStride}})) * SIZE_3(gradSecond) + (l-3*{{intStride}});
232 | gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
233 | } }
234 | '''
235 |
236 | def cupy_kernel(strFunction, objVariables):
237 | strKernel = globals()[strFunction].replace('{{intStride}}', str(objVariables['intStride']))
238 |
239 | while True:
240 | objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
241 |
242 | if objMatch is None:
243 | break
244 | # end
245 |
246 | intArg = int(objMatch.group(2))
247 |
248 | strTensor = objMatch.group(4)
249 | intSizes = objVariables[strTensor].size()
250 |
251 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
252 | # end
253 |
254 | while True:
255 | objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
256 |
257 | if objMatch is None:
258 | break
259 | # end
260 |
261 | intArgs = int(objMatch.group(2))
262 | strArgs = objMatch.group(4).split(',')
263 |
264 | strTensor = strArgs[0]
265 | intStrides = objVariables[strTensor].stride()
266 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]
267 |
268 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
269 | # end
270 |
271 | return strKernel
272 | # end
273 |
274 | @cupy.util.memoize(for_each_device=True)
275 | def cupy_launch(strFunction, strKernel):
276 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction)
277 | # end
278 |
279 | class _FunctionCorrelation(torch.autograd.Function):
280 | @staticmethod
281 | def forward(self, first, second, intStride):
282 | rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ])
283 | rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + (6 * intStride), first.shape[3] + (6 * intStride), first.shape[1] ])
284 |
285 | self.save_for_backward(first, second, rbot0, rbot1)
286 |
287 | self.intStride = intStride
288 |
289 | assert(first.is_contiguous() == True)
290 | assert(second.is_contiguous() == True)
291 |
292 | output = first.new_zeros([ first.shape[0], 49, int(math.ceil(first.shape[2] / intStride)), int(math.ceil(first.shape[3] / intStride)) ])
293 |
294 | if first.is_cuda == True:
295 | n = first.shape[2] * first.shape[3]
296 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
297 | 'intStride': self.intStride,
298 | 'input': first,
299 | 'output': rbot0
300 | }))(
301 | grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]),
302 | block=tuple([ 16, 1, 1 ]),
303 | args=[ n, first.data_ptr(), rbot0.data_ptr() ]
304 | )
305 |
306 | n = second.shape[2] * second.shape[3]
307 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
308 | 'intStride': self.intStride,
309 | 'input': second,
310 | 'output': rbot1
311 | }))(
312 | grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]),
313 | block=tuple([ 16, 1, 1 ]),
314 | args=[ n, second.data_ptr(), rbot1.data_ptr() ]
315 | )
316 |
317 | n = output.shape[1] * output.shape[2] * output.shape[3]
318 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
319 | 'intStride': self.intStride,
320 | 'rbot0': rbot0,
321 | 'rbot1': rbot1,
322 | 'top': output
323 | }))(
324 | grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
325 | block=tuple([ 32, 1, 1 ]),
326 | shared_mem=first.shape[1] * 4,
327 | args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
328 | )
329 |
330 | elif first.is_cuda == False:
331 | raise NotImplementedError()
332 |
333 | # end
334 |
335 | return output
336 | # end
337 |
338 | @staticmethod
339 | def backward(self, gradOutput):
340 | first, second, rbot0, rbot1 = self.saved_tensors
341 |
342 | assert(gradOutput.is_contiguous() == True)
343 |
344 | gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None
345 | gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None
346 |
347 | if first.is_cuda == True:
348 | if gradFirst is not None:
349 | for intSample in range(first.shape[0]):
350 | n = first.shape[1] * first.shape[2] * first.shape[3]
351 | cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', {
352 | 'intStride': self.intStride,
353 | 'rbot0': rbot0,
354 | 'rbot1': rbot1,
355 | 'gradOutput': gradOutput,
356 | 'gradFirst': gradFirst,
357 | 'gradSecond': None
358 | }))(
359 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
360 | block=tuple([ 512, 1, 1 ]),
361 | args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ]
362 | )
363 | # end
364 | # end
365 |
366 | if gradSecond is not None:
367 | for intSample in range(first.shape[0]):
368 | n = first.shape[1] * first.shape[2] * first.shape[3]
369 | cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', {
370 | 'intStride': self.intStride,
371 | 'rbot0': rbot0,
372 | 'rbot1': rbot1,
373 | 'gradOutput': gradOutput,
374 | 'gradFirst': None,
375 | 'gradSecond': gradSecond
376 | }))(
377 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
378 | block=tuple([ 512, 1, 1 ]),
379 | args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ]
380 | )
381 | # end
382 | # end
383 |
384 | elif first.is_cuda == False:
385 | raise NotImplementedError()
386 |
387 | # end
388 |
389 | return gradFirst, gradSecond, None
390 | # end
391 | # end
392 |
393 | def FunctionCorrelation(tenFirst, tenSecond, intStride):
394 | return _FunctionCorrelation.apply(tenFirst, tenSecond, intStride)
395 | # end
396 |
397 | class ModuleCorrelation(torch.nn.Module):
398 | def __init__(self):
399 | super(ModuleCorrelation, self).__init__()
400 | # end
401 |
402 | def forward(self, tenFirst, tenSecond, intStride):
403 | return _FunctionCorrelation.apply(tenFirst, tenSecond, intStride)
404 | # end
405 | # end
--------------------------------------------------------------------------------
/PF-AFN_train/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 | from torchvision import models
5 | from options.train_options import TrainOptions
6 | import os
7 |
8 | opt = TrainOptions().parse()
9 |
10 | class ResidualBlock(nn.Module):
11 | def __init__(self, in_features=64, norm_layer=nn.BatchNorm2d):
12 | super(ResidualBlock, self).__init__()
13 | self.relu = nn.ReLU(True)
14 | if norm_layer == None:
15 | self.block = nn.Sequential(
16 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
17 | nn.ReLU(inplace=True),
18 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
19 | )
20 | else:
21 | self.block = nn.Sequential(
22 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
23 | norm_layer(in_features),
24 | nn.ReLU(inplace=True),
25 | nn.Conv2d(in_features, in_features, 3, 1, 1, bias=False),
26 | norm_layer(in_features)
27 | )
28 |
29 | def forward(self, x):
30 | residual = x
31 | out = self.block(x)
32 | out += residual
33 | out = self.relu(out)
34 | return out
35 |
36 |
37 | class ResUnetGenerator(nn.Module):
38 | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
39 | norm_layer=nn.BatchNorm2d, use_dropout=False):
40 | super(ResUnetGenerator, self).__init__()
41 | # construct unet structure
42 | unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
43 |
44 | for i in range(num_downs - 5):
45 | unet_block = ResUnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
46 | unet_block = ResUnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
47 | unet_block = ResUnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
48 | unet_block = ResUnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
49 | unet_block = ResUnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
50 |
51 | self.model = unet_block
52 | self.old_lr = opt.lr
53 | self.old_lr_gmm = 0.1*opt.lr
54 |
55 | def forward(self, input):
56 | return self.model(input)
57 |
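  | # As used in this repository (see train_PBAFN_e2e.py), the generator is ResUnetGenerator(8, 4, 5):
  | # the 8 input channels are the preserved person region (3), the warped clothes (3), the warped
  | # clothes mask (1) and the DensePose preserve mask (1); the 4 output channels are split into a
  | # 3-channel rendered person and a 1-channel composition mask.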
58 |
59 | # Defines the submodule with skip connection.
60 | # X -------------------identity---------------------- X
61 | # |-- downsampling -- |submodule| -- upsampling --|
62 | class ResUnetSkipConnectionBlock(nn.Module):
63 | def __init__(self, outer_nc, inner_nc, input_nc=None,
64 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
65 | super(ResUnetSkipConnectionBlock, self).__init__()
66 | self.outermost = outermost
67 | use_bias = norm_layer == nn.InstanceNorm2d
68 |
69 | if input_nc is None:
70 | input_nc = outer_nc
71 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=3,
72 | stride=2, padding=1, bias=use_bias)
73 | # add two resblock
74 | res_downconv = [ResidualBlock(inner_nc, norm_layer), ResidualBlock(inner_nc, norm_layer)]
75 | res_upconv = [ResidualBlock(outer_nc, norm_layer), ResidualBlock(outer_nc, norm_layer)]
76 |
77 | downrelu = nn.ReLU(True)
78 | uprelu = nn.ReLU(True)
79 | if norm_layer != None:
80 | downnorm = norm_layer(inner_nc)
81 | upnorm = norm_layer(outer_nc)
82 |
83 | if outermost:
84 | upsample = nn.Upsample(scale_factor=2, mode='nearest')
85 | upconv = nn.Conv2d(inner_nc * 2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
86 | down = [downconv, downrelu] + res_downconv
87 | up = [upsample, upconv]
88 | model = down + [submodule] + up
89 | elif innermost:
90 | upsample = nn.Upsample(scale_factor=2, mode='nearest')
91 | upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
92 | down = [downconv, downrelu] + res_downconv
93 | if norm_layer == None:
94 | up = [upsample, upconv, uprelu] + res_upconv
95 | else:
96 | up = [upsample, upconv, upnorm, uprelu] + res_upconv
97 | model = down + up
98 | else:
99 | upsample = nn.Upsample(scale_factor=2, mode='nearest')
100 | upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3, stride=1, padding=1, bias=use_bias)
101 | if norm_layer == None:
102 | down = [downconv, downrelu] + res_downconv
103 | up = [upsample, upconv, uprelu] + res_upconv
104 | else:
105 | down = [downconv, downnorm, downrelu] + res_downconv
106 | up = [upsample, upconv, upnorm, uprelu] + res_upconv
107 |
108 | if use_dropout:
109 | model = down + [submodule] + up + [nn.Dropout(0.5)]
110 | else:
111 | model = down + [submodule] + up
112 |
113 | self.model = nn.Sequential(*model)
114 |
115 | def forward(self, x):
116 | if self.outermost:
117 | return self.model(x)
118 | else:
119 | return torch.cat([x, self.model(x)], 1)
120 |
121 |
122 | class Vgg19(nn.Module):
123 | def __init__(self, requires_grad=False):
124 | super(Vgg19, self).__init__()
125 | vgg_pretrained_features = models.vgg19(pretrained=True).features
126 |         self.slice1 = nn.Sequential()  # the five slices end at relu1_1 ... relu5_1 of torchvision's vgg19.features
127 | self.slice2 = nn.Sequential()
128 | self.slice3 = nn.Sequential()
129 | self.slice4 = nn.Sequential()
130 | self.slice5 = nn.Sequential()
131 | for x in range(2):
132 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
133 | for x in range(2, 7):
134 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
135 | for x in range(7, 12):
136 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
137 | for x in range(12, 21):
138 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
139 | for x in range(21, 30):
140 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
141 | if not requires_grad:
142 | for param in self.parameters():
143 | param.requires_grad = False
144 |
145 | def forward(self, X):
146 | h_relu1 = self.slice1(X)
147 | h_relu2 = self.slice2(h_relu1)
148 | h_relu3 = self.slice3(h_relu2)
149 | h_relu4 = self.slice4(h_relu3)
150 | h_relu5 = self.slice5(h_relu4)
151 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
152 | return out
153 |
154 | class VGGLoss(nn.Module):
155 | def __init__(self, layids = None):
156 | super(VGGLoss, self).__init__()
157 | self.vgg = Vgg19()
158 | self.vgg.cuda()
159 | self.criterion = nn.L1Loss()
160 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
161 | self.layids = layids
162 |
163 | def forward(self, x, y):
164 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
165 | loss = 0
166 | if self.layids is None:
167 | self.layids = list(range(len(x_vgg)))
168 | for i in self.layids:
169 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
170 | return loss
171 |
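  | # VGGLoss above computes sum_i w_i * ||phi_i(x) - phi_i(y)||_1, where phi_i are the five VGG-19
  | # feature slices returned by Vgg19 and w = [1/32, 1/16, 1/8, 1/4, 1]; the target features y are
  | # detached so only x receives gradients.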
172 | def save_checkpoint(model, save_path):
173 | if not os.path.exists(os.path.dirname(save_path)):
174 | os.makedirs(os.path.dirname(save_path))
175 | torch.save(model.state_dict(), save_path)
176 |
177 |
178 | def load_checkpoint_parallel(model, checkpoint_path):
179 |
180 | if not os.path.exists(checkpoint_path):
181 | print('No checkpoint!')
182 | return
183 |
184 | checkpoint = torch.load(checkpoint_path, map_location='cuda:{}'.format(opt.local_rank))
185 | checkpoint_new = model.state_dict()
186 | for param in checkpoint_new:
187 | checkpoint_new[param] = checkpoint[param]
188 | model.load_state_dict(checkpoint_new)
189 |
190 | def load_checkpoint_part_parallel(model, checkpoint_path):
191 |
192 | if not os.path.exists(checkpoint_path):
193 | print('No checkpoint!')
194 | return
195 | checkpoint = torch.load(checkpoint_path,map_location='cuda:{}'.format(opt.local_rank))
196 | checkpoint_new = model.state_dict()
197 | for param in checkpoint_new:
198 | if 'cond_' not in param and 'aflow_net.netRefine' not in param:
199 | checkpoint_new[param] = checkpoint[param]
200 | model.load_state_dict(checkpoint_new)
201 |
202 |
203 |
--------------------------------------------------------------------------------
/PF-AFN_train/options/__init__.py:
--------------------------------------------------------------------------------
1 | # options_init
--------------------------------------------------------------------------------
/PF-AFN_train/options/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/options/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/options/__pycache__/base_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/base_options.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/options/__pycache__/base_options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/base_options.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/options/__pycache__/test_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/test_options.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/options/__pycache__/train_options.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/train_options.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/options/__pycache__/train_options.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/options/__pycache__/train_options.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 |
6 | class BaseOptions():
7 | def __init__(self):
8 | self.parser = argparse.ArgumentParser()
9 | self.initialized = False
10 |
11 | def initialize(self):
12 | # experiment specifics
13 | self.parser.add_argument('--name', type=str, default='flow', help='name of the experiment. It decides where to store samples and models')
14 |         self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. "0", "0,1,2", "0,2"; use -1 for CPU')
15 | self.parser.add_argument('--num_gpus', type=int, default=1, help='the number of gpus')
16 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
17 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
18 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
19 | self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
20 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')
21 |
22 | # input/output sizes
23 | self.parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
24 | self.parser.add_argument('--loadSize', type=int, default=512, help='scale images to this size')
25 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
26 | self.parser.add_argument('--label_nc', type=int, default=20, help='# of input label channels')
27 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
28 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
29 |
30 | # for setting inputs
31 | self.parser.add_argument('--dataroot', type=str,default='dataset/VITON_traindata/')
32 | self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
33 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
34 |         self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
35 | self.parser.add_argument('--nThreads', default=1, type=int, help='# threads for loading data')
36 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
37 |
38 | # for displays
39 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
40 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
41 |
42 | # for generator
43 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG')
44 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
45 | self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG')
46 | self.parser.add_argument('--n_blocks_global', type=int, default=4, help='number of residual blocks in the global generator network')
47 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network')
48 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use')
49 |         self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs during which only the outermost local enhancer is trained')
50 | self.parser.add_argument('--tv_weight', type=float, default=0.1, help='weight for TV loss')
51 |
52 | self.initialized = True
53 |
54 | def parse(self, save=True):
55 | if not self.initialized:
56 | self.initialize()
57 | self.opt = self.parser.parse_args()
58 | self.opt.isTrain = self.isTrain # train or test
59 |
60 | str_ids = self.opt.gpu_ids.split(',')
61 | self.opt.gpu_ids = []
62 | for str_id in str_ids:
63 | id = int(str_id)
64 | if id >= 0:
65 | self.opt.gpu_ids.append(id)
66 |
67 | # set gpu ids
68 | if len(self.opt.gpu_ids) > 0:
69 | torch.cuda.set_device(self.opt.gpu_ids[0])
70 |
71 | args = vars(self.opt)
72 |
73 | print('------------ Options -------------')
74 | for k, v in sorted(args.items()):
75 | print('%s: %s' % (str(k), str(v)))
76 | print('-------------- End ----------------')
77 |
78 | # save to the disk
79 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
80 | util.mkdirs(expr_dir)
81 | if save and not self.opt.continue_train:
82 | file_name = os.path.join(expr_dir, 'opt.txt')
83 | with open(file_name, 'wt') as opt_file:
84 | opt_file.write('------------ Options -------------\n')
85 | for k, v in sorted(args.items()):
86 | opt_file.write('%s: %s\n' % (str(k), str(v)))
87 | opt_file.write('-------------- End ----------------\n')
88 | return self.opt
89 |
--------------------------------------------------------------------------------
/PF-AFN_train/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TrainOptions(BaseOptions):
4 | def initialize(self):
5 | BaseOptions.initialize(self)
6 | # for displays
7 | self.parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',help='job launcher')
8 | self.parser.add_argument('--local_rank', type=int, default=0)
9 |
10 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
11 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
12 | self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
13 | self.parser.add_argument('--save_epoch_freq', type=int, default=20, help='frequency of saving checkpoints at the end of epochs')
14 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
15 | self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
16 |
17 | # for training
18 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
19 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
20 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
21 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
22 | self.parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate')
23 | self.parser.add_argument('--niter_decay', type=int, default=50, help='# of iter to linearly decay learning rate to zero')
24 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
25 | self.parser.add_argument('--lr', type=float, default=0.00005, help='initial learning rate for adam')
26 | self.parser.add_argument('--PFAFN_warp_checkpoint', type=str, help='load the pretrained model from the specified location')
27 | self.parser.add_argument('--PFAFN_gen_checkpoint', type=str, help='load the pretrained model from the specified location')
28 | self.parser.add_argument('--PBAFN_warp_checkpoint', type=str, help='load the pretrained model from the specified location')
29 | self.parser.add_argument('--PBAFN_gen_checkpoint', type=str, help='load the pretrained model from the specified location')
30 | # for discriminators
31 | self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use')
32 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
33 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
34 | self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
35 | self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
36 | self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
37 |         self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least-squares GAN; if specified, a vanilla GAN loss is used instead')
38 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
39 |
40 | self.isTrain = True
41 |
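  | # Typical use, as in the training scripts (a minimal sketch):
  | #   opt = TrainOptions().parse()   # parses base + train arguments, prints them and, for a fresh run,
  | #                                  # saves opt.txt under <checkpoints_dir>/<name>
  | #   total_epochs = opt.niter + opt.niter_decay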
--------------------------------------------------------------------------------
/PF-AFN_train/runs/readme.txt:
--------------------------------------------------------------------------------
1 | The tensorboard logs will be saved here.
2 |
--------------------------------------------------------------------------------
/PF-AFN_train/sample/readme.txt:
--------------------------------------------------------------------------------
1 | The images during training will be saved here.
2 |
--------------------------------------------------------------------------------
/PF-AFN_train/scripts/train_PBAFN_e2e.sh:
--------------------------------------------------------------------------------
1 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4736 train_PBAFN_e2e.py --name PBAFN_e2e \
2 | --PBAFN_warp_checkpoint 'checkpoints/PBAFN_stage1/PBAFN_warp_epoch_101.pth' --resize_or_crop None --verbose --tf_log --batchSize 4 --num_gpus 8 --label_nc 14 --launcher pytorch
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/PF-AFN_train/scripts/train_PBAFN_stage1.sh:
--------------------------------------------------------------------------------
1 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=7129 train_PBAFN_stage1.py --name PBAFN_stage1 \
2 | --resize_or_crop None --verbose --tf_log --batchSize 4 --num_gpus 8 --label_nc 14 --launcher pytorch
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
--------------------------------------------------------------------------------
/PF-AFN_train/scripts/train_PFAFN_e2e.sh:
--------------------------------------------------------------------------------
1 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=7129 train_PFAFN_e2e.py --name PFAFN_e2e \
2 | --PFAFN_warp_checkpoint 'checkpoints/PFAFN_stage1/PFAFN_warp_epoch_201.pth' \
3 | --PBAFN_warp_checkpoint 'checkpoints/PBAFN_e2e/PBAFN_warp_epoch_101.pth' --PBAFN_gen_checkpoint 'checkpoints/PBAFN_e2e/PBAFN_gen_epoch_101.pth' \
4 | --resize_or_crop None --verbose --tf_log --batchSize 4 --num_gpus 8 --label_nc 14 --launcher pytorch
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/PF-AFN_train/scripts/train_PFAFN_stage1.sh:
--------------------------------------------------------------------------------
1 | python -m torch.distributed.launch --nproc_per_node=8 --master_port=4703 train_PFAFN_stage1.py --name PFAFN_stage1 \
2 | --PBAFN_warp_checkpoint 'checkpoints/PBAFN_e2e/PBAFN_warp_epoch_101.pth' --PBAFN_gen_checkpoint 'checkpoints/PBAFN_e2e/PBAFN_gen_epoch_101.pth' \
3 | --lr 0.00003 --niter 100 --niter_decay 100 --resize_or_crop None --verbose --tf_log --batchSize 4 --num_gpus 8 --label_nc 14 --launcher pytorch
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/PF-AFN_train/train_PBAFN_e2e.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from models.networks import ResUnetGenerator, VGGLoss, save_checkpoint, load_checkpoint_parallel
4 | from models.afwm import TVLoss, AFWM
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import os
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import DataLoader
11 | from torch.utils.data.distributed import DistributedSampler
12 | from tensorboardX import SummaryWriter
13 | import cv2
14 | import datetime
15 |
16 | opt = TrainOptions().parse()
17 | path = 'runs/'+opt.name
18 | os.makedirs(path,exist_ok=True)
19 |
20 | def CreateDataset(opt):
21 | from data.aligned_dataset import AlignedDataset
22 | dataset = AlignedDataset()
23 | print("dataset [%s] was created" % (dataset.name()))
24 | dataset.initialize(opt)
25 | return dataset
26 |
27 | os.makedirs('sample',exist_ok=True)
28 | opt = TrainOptions().parse()
29 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
30 |
31 | torch.cuda.set_device(opt.local_rank)
32 | torch.distributed.init_process_group(
33 | 'nccl',
34 | init_method='env://'
35 | )
36 | device = torch.device(f'cuda:{opt.local_rank}')
37 |
38 | start_epoch, epoch_iter = 1, 0
39 |
40 | train_data = CreateDataset(opt)
41 | train_sampler = DistributedSampler(train_data)
42 | train_loader = DataLoader(train_data, batch_size=opt.batchSize, shuffle=False,
43 | num_workers=4, pin_memory=True, sampler=train_sampler)
44 | dataset_size = len(train_loader)
45 |
46 | warp_model = AFWM(opt, 45)  # 45-channel person representation (preserve masks + one-hot DensePose + pose maps)
47 | print(warp_model)
48 | warp_model.train()
49 | warp_model.cuda()
50 | load_checkpoint_parallel(warp_model, opt.PBAFN_warp_checkpoint)
51 |
52 | gen_model = ResUnetGenerator(8, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
53 | print(gen_model)
54 | gen_model.train()
55 | gen_model.cuda()
56 |
57 | warp_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(warp_model).to(device)
58 | gen_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(gen_model).to(device)
59 |
60 | if opt.isTrain and len(opt.gpu_ids):
61 | model = torch.nn.parallel.DistributedDataParallel(warp_model, device_ids=[opt.local_rank])
62 | model_gen = torch.nn.parallel.DistributedDataParallel(gen_model, device_ids=[opt.local_rank])
63 |
64 | criterionL1 = nn.L1Loss()
65 | criterionVGG = VGGLoss()
66 | # optimizer
67 | params_warp = [p for p in model.parameters()]
68 | params_gen = [p for p in model_gen.parameters()]
69 | optimizer_warp = torch.optim.Adam(params_warp, lr=0.2*opt.lr, betas=(opt.beta1, 0.999))
70 | optimizer_gen = torch.optim.Adam(params_gen, lr=opt.lr, betas=(opt.beta1, 0.999))
71 |
72 | total_steps = (start_epoch-1) * dataset_size + epoch_iter
73 |
74 | step = 0
75 | step_per_batch = dataset_size
76 |
77 | if opt.local_rank == 0:
78 | writer = SummaryWriter(path)
79 |
80 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
81 | epoch_start_time = time.time()
82 | if epoch != start_epoch:
83 | epoch_iter = epoch_iter % dataset_size
84 |
85 | train_sampler.set_epoch(epoch)
86 |
87 | for i, data in enumerate(train_loader):
88 |
89 | iter_start_time = time.time()
90 |
91 | total_steps += 1
92 | epoch_iter += 1
93 | save_fake = True
94 |
95 | t_mask = torch.FloatTensor((data['label'].cpu().numpy()==7).astype(np.float))
96 | data['label'] = data['label']*(1-t_mask)+t_mask*4
97 | edge = data['edge']
98 | pre_clothes_edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int))
99 | clothes = data['color']
100 | clothes = clothes * pre_clothes_edge
101 | person_clothes_edge = torch.FloatTensor((data['label'].cpu().numpy()==4).astype(np.int))
102 | real_image = data['image']
103 | person_clothes = real_image*person_clothes_edge
104 | pose = data['pose']
105 | size = data['label'].size()
106 | oneHot_size1 = (size[0], 25, size[2], size[3])
107 | densepose = torch.cuda.FloatTensor(torch.Size(oneHot_size1)).zero_()
108 | densepose = densepose.scatter_(1,data['densepose'].data.long().cuda(),1.0)
109 | densepose_fore = data['densepose']/24.0
110 | face_mask = torch.FloatTensor((data['label'].cpu().numpy()==1).astype(np.int))+torch.FloatTensor((data['label'].cpu().numpy()==12).astype(np.int))
111 | other_clothes_mask = torch.FloatTensor((data['label'].cpu().numpy()==5).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy()==6).astype(np.int))\
112 | + torch.FloatTensor((data['label'].cpu().numpy()==8).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy()==9).astype(np.int))\
113 | + torch.FloatTensor((data['label'].cpu().numpy()==10).astype(np.int))
114 | face_img = face_mask * real_image
115 | other_clothes_img = other_clothes_mask * real_image
116 | preserve_region = face_img + other_clothes_img
117 | preserve_mask = torch.cat([face_mask, other_clothes_mask],1)
118 | concat = torch.cat([preserve_mask.cuda(), densepose, pose.cuda()],1)
119 | arm_mask = torch.FloatTensor((data['label'].cpu().numpy()==11).astype(np.float)) + torch.FloatTensor((data['label'].cpu().numpy()==13).astype(np.float))
120 | hand_mask = torch.FloatTensor((data['densepose'].cpu().numpy()==3).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy()==4).astype(np.int))
121 | hand_mask = arm_mask*hand_mask
122 | hand_img = hand_mask*real_image
123 | dense_preserve_mask = torch.FloatTensor((data['densepose'].cpu().numpy()==15).astype(np.int))+torch.FloatTensor((data['densepose'].cpu().numpy()==16).astype(np.int))\
124 | +torch.FloatTensor((data['densepose'].cpu().numpy()==17).astype(np.int))+torch.FloatTensor((data['densepose'].cpu().numpy()==18).astype(np.int))\
125 | +torch.FloatTensor((data['densepose'].cpu().numpy()==19).astype(np.int))+torch.FloatTensor((data['densepose'].cpu().numpy()==20).astype(np.int))\
126 |                             +torch.FloatTensor((data['densepose'].cpu().numpy()==21).astype(np.int))+torch.FloatTensor((data['densepose'].cpu().numpy()==22).astype(np.int))
127 | dense_preserve_mask = dense_preserve_mask.cuda()*(1-person_clothes_edge.cuda())
128 | preserve_region = face_img + other_clothes_img +hand_img
129 |
130 | flow_out = model(concat.cuda(), clothes.cuda(), pre_clothes_edge.cuda())
131 | warped_cloth, last_flow, _1, _2, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all = flow_out
132 |
133 | epsilon = 0.001
134 | loss_smooth = sum([TVLoss(x) for x in delta_list])
135 | warp_loss = 0
136 |
137 | for num in range(5):
138 | cur_person_clothes = F.interpolate(person_clothes, scale_factor=0.5**(4-num), mode='bilinear')
139 | cur_person_clothes_edge = F.interpolate(person_clothes_edge, scale_factor=0.5**(4-num), mode='bilinear')
140 | loss_l1 = criterionL1(x_all[num], cur_person_clothes.cuda())
141 | loss_vgg = criterionVGG(x_all[num], cur_person_clothes.cuda())
142 | loss_edge = criterionL1(x_edge_all[num], cur_person_clothes_edge.cuda())
143 | b,c,h,w = delta_x_all[num].shape
144 | loss_flow_x = (delta_x_all[num].pow(2) + epsilon*epsilon).pow(0.45)
145 | loss_flow_x = torch.sum(loss_flow_x) / (b*c*h*w)
146 | loss_flow_y = (delta_y_all[num].pow(2) + epsilon*epsilon).pow(0.45)
147 | loss_flow_y = torch.sum(loss_flow_y) / (b*c*h*w)
148 | loss_second_smooth = loss_flow_x + loss_flow_y
149 | warp_loss = warp_loss + (num+1) * loss_l1 + (num+1) * 0.2 * loss_vgg + (num+1) * 2 * loss_edge + (num+1) * 6 * loss_second_smooth
150 |
151 | warp_loss = 0.01 * loss_smooth + warp_loss
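  | # i.e. warp_loss = 0.01 * sum_k TV(delta_k)
  | #        + sum_{num=0..4} (num+1) * [ L1 + 0.2 * VGG + 2 * edge_L1
  | #                                     + 6 * (generalized Charbonnier penalty on the second-order flow differences) ]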
152 |
153 | if opt.local_rank == 0:
154 | writer.add_scalar('warp_loss', warp_loss, step)
155 |
156 | warped_prod_edge = x_edge_all[4]
157 | gen_inputs = torch.cat([preserve_region.cuda(), warped_cloth, warped_prod_edge, dense_preserve_mask], 1)
158 |
159 | gen_outputs = model_gen(gen_inputs)
160 | p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1)
161 | p_rendered = torch.tanh(p_rendered)
162 | m_composite = torch.sigmoid(m_composite)
163 | m_composite1 = m_composite * warped_prod_edge
164 | m_composite = person_clothes_edge.cuda()*m_composite1
165 | p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
166 |
167 | loss_mask_l1 = torch.mean(torch.abs(1 - m_composite))
168 | loss_l1 = criterionL1(p_tryon, real_image.cuda())
169 | loss_vgg = criterionVGG(p_tryon,real_image.cuda())
170 | bg_loss_l1 = criterionL1(p_rendered, real_image.cuda())
171 | bg_loss_vgg = criterionVGG(p_rendered, real_image.cuda())
172 | gen_loss = (loss_l1 * 5 + loss_vgg + bg_loss_l1 * 5 + bg_loss_vgg + loss_mask_l1)
173 |
174 |
175 | if opt.local_rank == 0:
176 | writer.add_scalar('gen_loss', gen_loss, step)
177 |
178 | loss_all = 0.5 * warp_loss + 1.0 * gen_loss
179 |
180 | if opt.local_rank == 0:
181 | writer.add_scalar('loss_all', loss_all, step)
182 |
183 | optimizer_warp.zero_grad()
184 | optimizer_gen.zero_grad()
185 | loss_all.backward()
186 | optimizer_warp.step()
187 | optimizer_gen.step()
188 |
189 | ############## Display results and errors ##########
190 | path = 'sample/'+opt.name
191 | os.makedirs(path,exist_ok=True)
192 | if step % 1000 == 0:
193 | if opt.local_rank == 0:
194 | a = real_image.float().cuda()
195 | b = person_clothes.cuda()
196 | c = clothes.cuda()
197 | d = torch.cat([densepose_fore.cuda(),densepose_fore.cuda(),densepose_fore.cuda()],1)
198 | e = warped_cloth
199 | f = torch.cat([warped_prod_edge,warped_prod_edge,warped_prod_edge],1)
200 | g = preserve_region.cuda()
201 | h = torch.cat([dense_preserve_mask,dense_preserve_mask,dense_preserve_mask],1)
202 | i = p_rendered
203 | j = torch.cat([m_composite1,m_composite1,m_composite1],1)
204 | k = p_tryon
205 | combine = torch.cat([a[0],b[0],c[0],d[0],e[0],f[0],g[0],h[0],i[0],j[0],k[0]], 2).squeeze()
206 | cv_img = (combine.permute(1,2,0).detach().cpu().numpy()+1)/2
207 | writer.add_image('combine', (combine.data + 1) / 2.0, step)
208 | rgb = (cv_img*255).astype(np.uint8)
209 | bgr = cv2.cvtColor(rgb,cv2.COLOR_RGB2BGR)
210 | cv2.imwrite('sample/'+opt.name+'/'+str(step)+'.jpg',bgr)
211 |
212 | step += 1
213 | iter_end_time = time.time()
214 | iter_delta_time = iter_end_time - iter_start_time
215 | step_delta = (step_per_batch-step%step_per_batch) + step_per_batch*(opt.niter + opt.niter_decay-epoch)
216 | eta = iter_delta_time*step_delta
217 | eta = str(datetime.timedelta(seconds=int(eta)))
218 | time_stamp = datetime.datetime.now()
219 | now = time_stamp.strftime('%Y.%m.%d-%H:%M:%S')
220 |
221 | if step % 100 == 0:
222 | if opt.local_rank == 0:
223 | print('{}:{}:[step-{}]--[loss-{:.6f}]--[loss-{:.6f}]--[ETA-{}]'.format(now, epoch_iter, step, warp_loss, gen_loss, eta))
224 |
225 | if epoch_iter >= dataset_size:
226 | break
227 |
228 | iter_end_time = time.time()
229 | if opt.local_rank == 0:
230 | print('End of epoch %d / %d \t Time Taken: %d sec' %
231 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
232 |
233 | ### save model for this epoch
234 | if epoch % opt.save_epoch_freq == 0:
235 | if opt.local_rank == 0:
236 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
237 | save_checkpoint(model.module, os.path.join(opt.checkpoints_dir, opt.name, 'PBAFN_warp_epoch_%03d.pth' % (epoch+1)))
238 | save_checkpoint(model_gen.module, os.path.join(opt.checkpoints_dir, opt.name, 'PBAFN_gen_epoch_%03d.pth' % (epoch+1)))
239 |
240 | if epoch > opt.niter:
241 | model.module.update_learning_rate_warp(optimizer_warp)
242 | model.module.update_learning_rate(optimizer_gen)
243 |
--------------------------------------------------------------------------------
/PF-AFN_train/train_PBAFN_stage1.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from models.networks import VGGLoss,save_checkpoint
4 | from models.afwm import TVLoss,AFWM
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import os
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import DataLoader
11 | from torch.utils.data.distributed import DistributedSampler
12 | from tensorboardX import SummaryWriter
13 | import cv2
14 | import datetime
15 |
16 | opt = TrainOptions().parse()
17 | path = 'runs/'+opt.name
18 | os.makedirs(path,exist_ok=True)
19 |
20 | def CreateDataset(opt):
21 | from data.aligned_dataset import AlignedDataset
22 | dataset = AlignedDataset()
23 | print("dataset [%s] was created" % (dataset.name()))
24 | dataset.initialize(opt)
25 | return dataset
26 |
27 | os.makedirs('sample',exist_ok=True)
28 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
29 |
30 | torch.cuda.set_device(opt.local_rank)
31 | torch.distributed.init_process_group(
32 | 'nccl',
33 | init_method='env://'
34 | )
35 | device = torch.device(f'cuda:{opt.local_rank}')
36 |
37 | start_epoch, epoch_iter = 1, 0
38 |
39 | train_data = CreateDataset(opt)
40 | train_sampler = DistributedSampler(train_data)
41 | train_loader = DataLoader(train_data, batch_size=opt.batchSize, shuffle=False,
42 | num_workers=4, pin_memory=True, sampler=train_sampler)
43 | dataset_size = len(train_loader)
44 | print('#training images = %d' % dataset_size)
45 |
46 | warp_model = AFWM(opt, 45)
47 | print(warp_model)
48 | warp_model.train()
49 | warp_model.cuda()
50 | warp_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(warp_model).to(device)
51 |
52 | if opt.isTrain and len(opt.gpu_ids):
53 | model = torch.nn.parallel.DistributedDataParallel(warp_model, device_ids=[opt.local_rank])
54 |
55 | criterionL1 = nn.L1Loss()
56 | criterionVGG = VGGLoss()
57 |
58 | params_warp = [p for p in model.parameters()]
59 | optimizer_warp = torch.optim.Adam(params_warp, lr=opt.lr, betas=(opt.beta1, 0.999))
60 |
61 | total_steps = (start_epoch-1) * dataset_size + epoch_iter
62 | step = 0
63 | step_per_batch = dataset_size
64 |
65 | if opt.local_rank == 0:
66 | writer = SummaryWriter(path)
67 |
68 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
69 | epoch_start_time = time.time()
70 | if epoch != start_epoch:
71 | epoch_iter = epoch_iter % dataset_size
72 |
73 | train_sampler.set_epoch(epoch)
74 |
75 | for i, data in enumerate(train_loader):
76 | iter_start_time = time.time()
77 |
78 | total_steps += 1
79 | epoch_iter += 1
80 | save_fake = True
81 |
82 | t_mask = torch.FloatTensor((data['label'].cpu().numpy()==7).astype(np.float))
83 | data['label'] = data['label']*(1-t_mask)+t_mask*4
84 | edge = data['edge']
85 | pre_clothes_edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int))
86 | clothes = data['color']
87 | clothes = clothes * pre_clothes_edge
88 | person_clothes_edge = torch.FloatTensor((data['label'].cpu().numpy()==4).astype(np.int))
89 | real_image = data['image']
90 | person_clothes = real_image * person_clothes_edge
91 | pose = data['pose']
92 | size = data['label'].size()
93 | oneHot_size1 = (size[0], 25, size[2], size[3])
94 | densepose = torch.cuda.FloatTensor(torch.Size(oneHot_size1)).zero_()
95 | densepose = densepose.scatter_(1,data['densepose'].data.long().cuda(),1.0)
96 | densepose_fore = data['densepose']/24.0
97 | face_mask = torch.FloatTensor((data['label'].cpu().numpy()==1).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy()==12).astype(np.int))
98 | other_clothes_mask = torch.FloatTensor((data['label'].cpu().numpy()==5).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy()==6).astype(np.int)) + \
99 | torch.FloatTensor((data['label'].cpu().numpy()==8).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy()==9).astype(np.int)) + \
100 | torch.FloatTensor((data['label'].cpu().numpy()==10).astype(np.int))
101 | preserve_mask = torch.cat([face_mask,other_clothes_mask],1)
102 | concat = torch.cat([preserve_mask.cuda(),densepose,pose.cuda()],1)
103 |
104 | flow_out = model(concat.cuda(), clothes.cuda(), pre_clothes_edge.cuda())
105 | warped_cloth, last_flow, _1, _2, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all = flow_out
106 | warped_prod_edge = x_edge_all[4]
107 |
108 | epsilon = 0.001
109 | loss_smooth = sum([TVLoss(x) for x in delta_list])
110 | loss_all = 0
111 |
112 | for num in range(5):
113 | cur_person_clothes = F.interpolate(person_clothes, scale_factor=0.5**(4-num), mode='bilinear')
114 | cur_person_clothes_edge = F.interpolate(person_clothes_edge, scale_factor=0.5**(4-num), mode='bilinear')
115 | loss_l1 = criterionL1(x_all[num], cur_person_clothes.cuda())
116 | loss_vgg = criterionVGG(x_all[num], cur_person_clothes.cuda())
117 | loss_edge = criterionL1(x_edge_all[num], cur_person_clothes_edge.cuda())
118 | b,c,h,w = delta_x_all[num].shape
119 | loss_flow_x = (delta_x_all[num].pow(2)+ epsilon*epsilon).pow(0.45)
120 | loss_flow_x = torch.sum(loss_flow_x)/(b*c*h*w)
121 | loss_flow_y = (delta_y_all[num].pow(2)+ epsilon*epsilon).pow(0.45)
122 | loss_flow_y = torch.sum(loss_flow_y)/(b*c*h*w)
123 | loss_second_smooth = loss_flow_x + loss_flow_y
124 | loss_all = loss_all + (num+1) * loss_l1 + (num + 1) * 0.2 * loss_vgg + (num+1) * 2 * loss_edge + (num + 1) * 6 * loss_second_smooth
125 |
126 | loss_all = 0.01 * loss_smooth + loss_all
127 |
128 | if opt.local_rank == 0:
129 | writer.add_scalar('loss_all', loss_all, step)
130 |
131 | optimizer_warp.zero_grad()
132 | loss_all.backward()
133 | optimizer_warp.step()
134 | ############## Display results and errors ##########
135 |
136 | path = 'sample/'+opt.name
137 | os.makedirs(path,exist_ok=True)
138 | if step % 1000 == 0:
139 | if opt.local_rank == 0:
140 | a = real_image.float().cuda()
141 | b = person_clothes.cuda()
142 | c = clothes.cuda()
143 | d = torch.cat([densepose_fore.cuda(),densepose_fore.cuda(),densepose_fore.cuda()],1)
144 | e = warped_cloth
145 | f = torch.cat([warped_prod_edge,warped_prod_edge,warped_prod_edge],1)
146 | combine = torch.cat([a[0],b[0],c[0],d[0],e[0],f[0]], 2).squeeze()
147 | cv_img=(combine.permute(1,2,0).detach().cpu().numpy()+1)/2
148 | writer.add_image('combine', (combine.data + 1) / 2.0, step)
149 | rgb=(cv_img*255).astype(np.uint8)
150 | bgr=cv2.cvtColor(rgb,cv2.COLOR_RGB2BGR)
151 | cv2.imwrite('sample/'+opt.name+'/'+str(step)+'.jpg',bgr)
152 |
153 | step += 1
154 | iter_end_time = time.time()
155 | iter_delta_time = iter_end_time - iter_start_time
156 | step_delta = (step_per_batch-step%step_per_batch) + step_per_batch*(opt.niter + opt.niter_decay-epoch)
157 | eta = iter_delta_time*step_delta
158 | eta = str(datetime.timedelta(seconds=int(eta)))
159 | time_stamp = datetime.datetime.now()
160 | now = time_stamp.strftime('%Y.%m.%d-%H:%M:%S')
161 | if step % 100 == 0:
162 | if opt.local_rank == 0:
163 | print('{}:{}:[step-{}]--[loss-{:.6f}]--[ETA-{}]'.format(now, epoch_iter,step, loss_all,eta))
164 |
165 | if epoch_iter >= dataset_size:
166 | break
167 |
168 | # end of epoch
169 | iter_end_time = time.time()
170 | if opt.local_rank == 0:
171 | print('End of epoch %d / %d \t Time Taken: %d sec' %
172 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
173 |
174 | ### save model for this epoch
175 | if epoch % opt.save_epoch_freq == 0:
176 | if opt.local_rank == 0:
177 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
178 | save_checkpoint(model.module, os.path.join(opt.checkpoints_dir, opt.name, 'PBAFN_warp_epoch_%03d.pth' % (epoch+1)))
179 |
180 | if epoch > opt.niter:
181 | model.module.update_learning_rate(optimizer_warp)
182 |
--------------------------------------------------------------------------------
/PF-AFN_train/train_PFAFN_e2e.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from models.networks import ResUnetGenerator, VGGLoss, save_checkpoint, load_checkpoint_parallel
4 | from models.afwm import TVLoss, AFWM
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import os
8 | import numpy as np
9 | import torch
10 | from torch.utils.data import DataLoader
11 | from torch.utils.data.distributed import DistributedSampler
12 | from tensorboardX import SummaryWriter
13 | import datetime
14 | import cv2
15 |
16 | opt = TrainOptions().parse()
17 | path = 'runs/' + opt.name
18 | os.makedirs(path, exist_ok=True)
19 |
20 |
21 | def CreateDataset(opt):
22 | from data.aligned_dataset import AlignedDataset
23 | dataset = AlignedDataset()
24 | print("dataset [%s] was created" % (dataset.name()))
25 | dataset.initialize(opt)
26 | return dataset
27 |
28 |
29 | os.makedirs('sample', exist_ok=True)
30 | opt = TrainOptions().parse()
31 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
32 |
33 | torch.cuda.set_device(opt.local_rank)
34 | torch.distributed.init_process_group(
35 | 'nccl',
36 | init_method='env://'
37 | )
38 | device = torch.device(f'cuda:{opt.local_rank}')
39 |
40 | start_epoch, epoch_iter = 1, 0
41 |
42 | train_data = CreateDataset(opt)
43 | train_sampler = DistributedSampler(train_data)
44 | train_loader = DataLoader(train_data, batch_size=opt.batchSize, shuffle=False,
45 | num_workers=4, pin_memory=True, sampler=train_sampler)
46 | dataset_size = len(train_loader)
47 |
48 | PF_warp_model = AFWM(opt, 3)
49 | print(PF_warp_model)
50 | PF_warp_model.train()
51 | PF_warp_model.cuda()
52 | load_checkpoint_parallel(PF_warp_model, opt.PFAFN_warp_checkpoint)
53 |
54 | PF_gen_model = ResUnetGenerator(7, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
55 | print(PF_gen_model)
56 | PF_gen_model.train()
57 | PF_gen_model.cuda()
58 |
59 | PB_warp_model = AFWM(opt, 45)
60 | print(PB_warp_model)
61 | PB_warp_model.eval()
62 | PB_warp_model.cuda()
63 | load_checkpoint_parallel(PB_warp_model, opt.PBAFN_warp_checkpoint)
64 |
65 | PB_gen_model = ResUnetGenerator(8, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
66 | print(PB_gen_model)
67 | PB_gen_model.eval()
68 | PB_gen_model.cuda()
69 | load_checkpoint_parallel(PB_gen_model, opt.PBAFN_gen_checkpoint)
70 |
71 | PF_warp_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(PF_warp_model).to(device)
72 | PF_gen_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(PF_gen_model).to(device)
73 |
74 | if opt.isTrain and len(opt.gpu_ids):
75 | PF_warp_model = torch.nn.parallel.DistributedDataParallel(PF_warp_model, device_ids=[opt.local_rank])
76 | PF_gen_model = torch.nn.parallel.DistributedDataParallel(PF_gen_model, device_ids=[opt.local_rank])
77 | PB_warp_model = torch.nn.parallel.DistributedDataParallel(PB_warp_model, device_ids=[opt.local_rank])
78 | PB_gen_model = torch.nn.parallel.DistributedDataParallel(PB_gen_model, device_ids=[opt.local_rank])
79 |
80 | criterionL1 = nn.L1Loss()
81 | criterionVGG = VGGLoss()
 82 | criterionL2 = nn.MSELoss(reduction='sum')
83 |
84 | params_warp = [p for p in PF_warp_model.parameters()]
85 | params_gen = [p for p in PF_gen_model.parameters()]
86 | optimizer_warp = torch.optim.Adam(params_warp, lr=0.2 * opt.lr, betas=(opt.beta1, 0.999))
87 | optimizer_gen = torch.optim.Adam(params_gen, lr=opt.lr, betas=(opt.beta1, 0.999))
88 |
89 | total_steps = (start_epoch - 1) * dataset_size + epoch_iter
90 |
91 | if opt.local_rank == 0:
92 | writer = SummaryWriter(path)
93 |
94 | step = 0
95 | step_per_batch = dataset_size
96 |
97 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
98 | epoch_start_time = time.time()
99 | if epoch != start_epoch:
100 | epoch_iter = epoch_iter % dataset_size
101 |
102 | train_sampler.set_epoch(epoch)
103 |
104 | for i, data in enumerate(train_loader):
105 |
106 | iter_start_time = time.time()
107 |
108 | total_steps += 1
109 | epoch_iter += 1
110 | save_fake = True
111 |
112 | t_mask = torch.FloatTensor((data['label'].cpu().numpy() == 7).astype(np.float))
113 | data['label'] = data['label'] * (1 - t_mask) + t_mask * 4
114 | edge = data['edge']
115 | pre_clothes_edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int))
116 | clothes = data['color']
117 | clothes = clothes * pre_clothes_edge
118 | edge_un = data['edge_un']
119 | pre_clothes_edge_un = torch.FloatTensor((edge_un.detach().numpy() > 0.5).astype(np.int))
120 | clothes_un = data['color_un']
121 | clothes_un = clothes_un * pre_clothes_edge_un
122 | person_clothes_edge = torch.FloatTensor((data['label'].cpu().numpy() == 4).astype(np.int))
123 | real_image = data['image']
124 | person_clothes = real_image * person_clothes_edge
125 | pose = data['pose']
126 | size = data['label'].size()
127 | oneHot_size1 = (size[0], 25, size[2], size[3])
128 | densepose = torch.cuda.FloatTensor(torch.Size(oneHot_size1)).zero_()
129 | densepose = densepose.scatter_(1, data['densepose'].data.long().cuda(), 1.0)
130 | densepose_fore = data['densepose'] / 24
131 | face_mask = torch.FloatTensor((data['label'].cpu().numpy() == 1).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() == 12).astype(np.int))
132 | other_clothes_mask = torch.FloatTensor((data['label'].cpu().numpy() == 5).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() == 6).astype(np.int)) \
133 | + torch.FloatTensor((data['label'].cpu().numpy() == 8).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() == 9).astype(np.int)) \
134 | + torch.FloatTensor((data['label'].cpu().numpy() == 10).astype(np.int))
135 | face_img = face_mask * real_image
136 | other_clothes_img = other_clothes_mask * real_image
137 | preserve_mask = torch.cat([face_mask, other_clothes_mask], 1)
138 |
139 | concat_un = torch.cat([preserve_mask.cuda(), densepose, pose.cuda()], 1)
140 | flow_out_un = PB_warp_model(concat_un.cuda(), clothes_un.cuda(), pre_clothes_edge_un.cuda())
141 | warped_cloth_un, last_flow_un, cond_un_all, flow_un_all, delta_list_un, x_all_un, x_edge_all_un, delta_x_all_un, delta_y_all_un = flow_out_un
142 | warped_prod_edge_un = F.grid_sample(pre_clothes_edge_un.cuda(), last_flow_un.permute(0, 2, 3, 1),
143 | mode='bilinear', padding_mode='zeros')
144 |
145 | flow_out_sup = PB_warp_model(concat_un.cuda(), clothes.cuda(), pre_clothes_edge.cuda())
146 | warped_cloth_sup, last_flow_sup, cond_sup_all, flow_sup_all, delta_list_sup, x_all_sup, x_edge_all_sup, delta_x_all_sup, delta_y_all_sup = flow_out_sup
147 |
148 | arm_mask = torch.FloatTensor((data['label'].cpu().numpy() == 11).astype(np.float)) + torch.FloatTensor((data['label'].cpu().numpy() == 13).astype(np.float))
149 | hand_mask = torch.FloatTensor((data['densepose'].cpu().numpy() == 3).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 4).astype(np.int))
150 | dense_preserve_mask = torch.FloatTensor((data['densepose'].cpu().numpy() == 15).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 16).astype(np.int)) \
151 | + torch.FloatTensor((data['densepose'].cpu().numpy() == 17).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 18).astype(np.int)) \
152 | + torch.FloatTensor((data['densepose'].cpu().numpy() == 19).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 20).astype(np.int)) \
153 |                                   + torch.FloatTensor((data['densepose'].cpu().numpy() == 21).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 22).astype(np.int))
154 | hand_img = (arm_mask * hand_mask) * real_image
155 | dense_preserve_mask = dense_preserve_mask.cuda() * (1 - warped_prod_edge_un)
156 | preserve_region = face_img + other_clothes_img + hand_img
157 |
158 | gen_inputs_un = torch.cat([preserve_region.cuda(), warped_cloth_un, warped_prod_edge_un, dense_preserve_mask], 1)
159 | gen_outputs_un = PB_gen_model(gen_inputs_un)
160 | p_rendered_un, m_composite_un = torch.split(gen_outputs_un, [3, 1], 1)
161 | p_rendered_un = torch.tanh(p_rendered_un)
162 | m_composite_un = torch.sigmoid(m_composite_un)
163 | m_composite_un = m_composite_un * warped_prod_edge_un
164 | p_tryon_un = warped_cloth_un * m_composite_un + p_rendered_un * (1 - m_composite_un)
165 |
166 | flow_out = PF_warp_model(p_tryon_un.detach(), clothes.cuda(), pre_clothes_edge.cuda())
167 | warped_cloth, last_flow, cond_all, flow_all, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all = flow_out
168 | warped_prod_edge = x_edge_all[4]
169 |
170 | epsilon = 0.001
171 | loss_smooth = sum([TVLoss(x) for x in delta_list])
172 | loss_warp = 0
173 | loss_fea_sup_all = 0
174 | loss_flow_sup_all = 0
175 |
176 | l1_loss_batch = torch.abs(warped_cloth_sup.detach() - person_clothes.cuda())
177 | l1_loss_batch = l1_loss_batch.reshape(opt.batchSize, 3 * 256 * 192)
178 | l1_loss_batch = l1_loss_batch.sum(dim=1) / (3 * 256 * 192)
179 | l1_loss_batch_pred = torch.abs(warped_cloth.detach() - person_clothes.cuda())
180 | l1_loss_batch_pred = l1_loss_batch_pred.reshape(opt.batchSize, 3 * 256 * 192)
181 | l1_loss_batch_pred = l1_loss_batch_pred.sum(dim=1) / (3 * 256 * 192)
182 | weight = (l1_loss_batch < l1_loss_batch_pred).float()
183 | num_all = len(np.where(weight.cpu().numpy() > 0)[0])
184 | if num_all == 0:
185 | num_all = 1
186 |
187 | for num in range(5):
188 | cur_person_clothes = F.interpolate(person_clothes, scale_factor=0.5 ** (4 - num), mode='bilinear')
189 | cur_person_clothes_edge = F.interpolate(person_clothes_edge, scale_factor=0.5 ** (4 - num), mode='bilinear')
190 | loss_l1 = criterionL1(x_all[num], cur_person_clothes.cuda())
191 | loss_vgg = criterionVGG(x_all[num], cur_person_clothes.cuda())
192 | loss_edge = criterionL1(x_edge_all[num], cur_person_clothes_edge.cuda())
193 | b, c, h, w = delta_x_all[num].shape
194 | loss_flow_x = (delta_x_all[num].pow(2) + epsilon * epsilon).pow(0.45)
195 | loss_flow_x = torch.sum(loss_flow_x) / (b * c * h * w)
196 | loss_flow_y = (delta_y_all[num].pow(2) + epsilon * epsilon).pow(0.45)
197 | loss_flow_y = torch.sum(loss_flow_y) / (b * c * h * w)
198 | loss_second_smooth = loss_flow_x + loss_flow_y
199 | b1, c1, h1, w1 = cond_all[num].shape
200 | weight_all = weight.reshape(-1, 1, 1, 1).repeat(1, 256, h1, w1)
201 | cond_sup_loss = ((cond_sup_all[num].detach() - cond_all[num]) ** 2 * weight_all).sum() / (256 * h1 * w1 * num_all)
202 | loss_fea_sup_all = loss_fea_sup_all + (5 - num) * 0.04 * cond_sup_loss
203 | loss_warp = loss_warp + (num + 1) * loss_l1 + (num + 1) * 0.2 * loss_vgg + (num + 1) * 2 * loss_edge + (num + 1) * 6 * loss_second_smooth + (5 - num) * 0.04 * cond_sup_loss
204 | if num >= 2:
205 | b1, c1, h1, w1 = flow_all[num].shape
206 | weight_all = weight.reshape(-1, 1, 1).repeat(1, h1, w1)
207 | flow_sup_loss = (torch.norm(flow_sup_all[num].detach() - flow_all[num], p=2, dim=1) * weight_all).sum() / (h1 * w1 * num_all)
208 | loss_flow_sup_all = loss_flow_sup_all + (num + 1) * 1 * flow_sup_loss
209 | loss_warp = loss_warp + (num + 1) * 1 * flow_sup_loss
210 |
211 | loss_warp = 0.01 * loss_smooth + loss_warp
212 |
213 | if opt.local_rank == 0:
214 | writer.add_scalar('loss_warp', loss_warp, step)
215 | writer.add_scalar('loss_fea_sup_all', loss_fea_sup_all, step)
216 | writer.add_scalar('loss_flow_sup_all', loss_flow_sup_all, step)
217 |
218 | skin_mask = warped_prod_edge_un.detach() * (1 - person_clothes_edge.cuda())
219 | gen_inputs = torch.cat([p_tryon_un.detach(), warped_cloth, warped_prod_edge], 1)
220 | gen_outputs = PF_gen_model(gen_inputs)
221 | p_rendered, m_composite = torch.split(gen_outputs, [3, 1], 1)
222 | p_rendered = torch.tanh(p_rendered)
223 | m_composite = torch.sigmoid(m_composite)
224 | m_composite1 = m_composite * warped_prod_edge
225 | m_composite = person_clothes_edge.cuda() * m_composite1
226 | p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)
227 |
228 | loss_mask_l1 = torch.mean(torch.abs(1 - m_composite))
229 | loss_l1_skin = criterionL1(p_rendered * skin_mask, skin_mask * real_image.cuda())
230 | loss_vgg_skin = criterionVGG(p_rendered * skin_mask, skin_mask * real_image.cuda())
231 | loss_l1 = criterionL1(p_tryon, real_image.cuda())
232 | loss_vgg = criterionVGG(p_tryon, real_image.cuda())
233 | bg_loss_l1 = criterionL1(p_rendered, real_image.cuda())
234 | bg_loss_vgg = criterionVGG(p_rendered, real_image.cuda())
235 |
236 | if epoch < opt.niter:
237 | loss_gen = (loss_l1 * 5 + loss_l1_skin * 30 + loss_vgg + loss_vgg_skin * 2 + bg_loss_l1 * 5 + bg_loss_vgg + 1 * loss_mask_l1)
238 | else:
239 | loss_gen = (loss_l1 * 5 + loss_l1_skin * 60 + loss_vgg + loss_vgg_skin * 4 + bg_loss_l1 * 5 + bg_loss_vgg + 1 * loss_mask_l1)
240 |
241 | loss_all = 0.25 * loss_warp + loss_gen
242 |
243 | if opt.local_rank == 0:
244 | writer.add_scalar('loss_gen', loss_gen, step)
245 |
246 | optimizer_warp.zero_grad()
247 | optimizer_gen.zero_grad()
248 | loss_all.backward()
249 | optimizer_warp.step()
250 | optimizer_gen.step()
251 |
252 | ############## Display results and errors ##########
253 | path = 'sample/' + opt.name
254 | os.makedirs(path, exist_ok=True)
255 | ### display output images
256 | if step % 1000 == 0:
257 | if opt.local_rank == 0:
258 | a = real_image.float().cuda()
259 | b = p_tryon_un.detach()
260 | c = clothes.cuda()
261 | d = person_clothes.cuda()
262 | e = torch.cat([skin_mask.cuda(), skin_mask.cuda(), skin_mask.cuda()], 1)
263 | f = warped_cloth
264 | g = p_rendered
265 | h = torch.cat([m_composite1, m_composite1, m_composite1], 1)
266 | i = p_tryon
267 | combine = torch.cat([a[0], b[0], c[0], d[0], e[0], f[0], g[0], h[0], i[0]], 2).squeeze()
268 | cv_img = (combine.permute(1, 2, 0).detach().cpu().numpy() + 1) / 2
269 | writer.add_image('combine', (combine.data + 1) / 2.0, step)
270 | rgb = (cv_img * 255).astype(np.uint8)
271 | bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
272 | cv2.imwrite('sample/' + opt.name + '/' + str(step) + '.jpg', bgr)
273 |
274 | step += 1
275 | iter_end_time = time.time()
276 | iter_delta_time = iter_end_time - iter_start_time
277 | step_delta = (step_per_batch - step % step_per_batch) + step_per_batch * (opt.niter + opt.niter_decay - epoch)
278 | eta = iter_delta_time * step_delta
279 | eta = str(datetime.timedelta(seconds=int(eta)))
280 | time_stamp = datetime.datetime.now()
281 | now = time_stamp.strftime('%Y.%m.%d-%H:%M:%S')
282 |
283 | if step % 100 == 0:
284 | if opt.local_rank == 0:
285 | print('{}:{}:[step-{}]--[loss-{:.6f}]--[loss-{:.6f}]--[ETA-{}]'.format(now, epoch_iter, step, loss_gen, loss_warp, eta))
286 |
287 | if epoch_iter >= dataset_size:
288 | break
289 |
290 | # end of epoch
291 | iter_end_time = time.time()
292 | if opt.local_rank == 0:
293 | print('End of epoch %d / %d \t Time Taken: %d sec' %
294 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
295 |
296 | if epoch % opt.save_epoch_freq == 0:
297 | if opt.local_rank == 0:
298 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
299 | save_checkpoint(PF_warp_model.module,
300 | os.path.join(opt.checkpoints_dir, opt.name, 'PFAFN_warp_epoch_%03d.pth' % (epoch + 1)))
301 | save_checkpoint(PF_gen_model.module,
302 | os.path.join(opt.checkpoints_dir, opt.name, 'PFAFN_gen_epoch_%03d.pth' % (epoch + 1)))
303 |
304 | if epoch > opt.niter:
305 | PF_warp_model.module.update_learning_rate_warp(optimizer_warp)
306 | PF_warp_model.module.update_learning_rate(optimizer_gen)
307 |
--------------------------------------------------------------------------------
/PF-AFN_train/train_PFAFN_stage1.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from models.networks import ResUnetGenerator, VGGLoss, save_checkpoint, load_checkpoint_part_parallel, \
4 | load_checkpoint_parallel
5 | from models.afwm import TVLoss, AFWM
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import os
9 | import numpy as np
10 | import torch
11 | from torch.utils.data import DataLoader
12 | from torch.utils.data.distributed import DistributedSampler
13 | from tensorboardX import SummaryWriter
14 | import cv2
15 | import datetime
16 |
17 | opt = TrainOptions().parse()
18 | path = 'runs/' + opt.name
19 | os.makedirs(path, exist_ok=True)
20 |
21 |
22 | def CreateDataset(opt):
23 | from data.aligned_dataset import AlignedDataset
24 | dataset = AlignedDataset()
25 | print("dataset [%s] was created" % (dataset.name()))
26 | dataset.initialize(opt)
27 | return dataset
28 |
29 |
30 | os.makedirs('sample', exist_ok=True)
31 | opt = TrainOptions().parse()
32 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
33 |
34 | torch.cuda.set_device(opt.local_rank)
35 | torch.distributed.init_process_group(
36 | 'nccl',
37 | init_method='env://'
38 | )
39 | device = torch.device(f'cuda:{opt.local_rank}')
40 |
41 | start_epoch, epoch_iter = 1, 0
42 |
43 | train_data = CreateDataset(opt)
44 | train_sampler = DistributedSampler(train_data)
45 | train_loader = DataLoader(train_data, batch_size=opt.batchSize, shuffle=False,
46 | num_workers=4, pin_memory=True, sampler=train_sampler)
47 | dataset_size = len(train_loader)
48 | print('#training images = %d' % dataset_size)
49 |
50 | PF_warp_model = AFWM(opt, 3)
51 | print(PF_warp_model)
52 | PF_warp_model.train()
53 | PF_warp_model.cuda()
54 | load_checkpoint_part_parallel(PF_warp_model, opt.PBAFN_warp_checkpoint)
55 |
56 | PB_warp_model = AFWM(opt, 45)
57 | print(PB_warp_model)
58 | PB_warp_model.eval()
59 | PB_warp_model.cuda()
60 | load_checkpoint_parallel(PB_warp_model, opt.PBAFN_warp_checkpoint)
61 |
62 | PB_gen_model = ResUnetGenerator(8, 4, 5, ngf=64, norm_layer=nn.BatchNorm2d)
63 | print(PB_gen_model)
64 | PB_gen_model.eval()
65 | PB_gen_model.cuda()
66 | load_checkpoint_parallel(PB_gen_model, opt.PBAFN_gen_checkpoint)
67 |
68 | PF_warp_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(PF_warp_model).to(device)
69 |
70 | if opt.isTrain and len(opt.gpu_ids):
71 | PF_warp_model = torch.nn.parallel.DistributedDataParallel(PF_warp_model, device_ids=[opt.local_rank])
72 | PB_warp_model = torch.nn.parallel.DistributedDataParallel(PB_warp_model, device_ids=[opt.local_rank])
73 | PB_gen_model = torch.nn.parallel.DistributedDataParallel(PB_gen_model, device_ids=[opt.local_rank])
74 |
75 | criterionL1 = nn.L1Loss()
76 | criterionVGG = VGGLoss()
 77 | criterionL2 = nn.MSELoss(reduction='sum')
78 |
79 | # optimizer
80 | params = [p for p in PF_warp_model.parameters()]
81 | optimizer = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
82 |
83 | params_part = []
84 | for name, param in PF_warp_model.named_parameters():
85 | if 'cond_' in name or 'aflow_net.netRefine' in name:
86 | params_part.append(param)
87 | optimizer_part = torch.optim.Adam(params_part, lr=opt.lr, betas=(opt.beta1, 0.999))
88 |
89 | total_steps = (start_epoch - 1) * dataset_size + epoch_iter
90 |
91 | if opt.local_rank == 0:
92 | writer = SummaryWriter(path)
93 |
94 | step = 0
95 | step_per_batch = dataset_size
96 |
97 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
98 | epoch_start_time = time.time()
99 | if epoch != start_epoch:
100 | epoch_iter = epoch_iter % dataset_size
101 |
102 | train_sampler.set_epoch(epoch)
103 |
104 | for i, data in enumerate(train_loader):
105 |
106 | iter_start_time = time.time()
107 |
108 | total_steps += 1
109 | epoch_iter += 1
110 | save_fake = True
111 |
112 | t_mask = torch.FloatTensor((data['label'].cpu().numpy() == 7).astype(np.float))
113 | data['label'] = data['label'] * (1 - t_mask) + t_mask * 4
114 | edge = data['edge']
115 | pre_clothes_edge = torch.FloatTensor((edge.detach().numpy() > 0.5).astype(np.int))
116 | clothes = data['color']
117 | clothes = clothes * pre_clothes_edge
118 | edge_un = data['edge_un']
119 | pre_clothes_edge_un = torch.FloatTensor((edge_un.detach().numpy() > 0.5).astype(np.int))
120 | clothes_un = data['color_un']
121 | clothes_un = clothes_un * pre_clothes_edge_un
122 | person_clothes_edge = torch.FloatTensor((data['label'].cpu().numpy() == 4).astype(np.int))
123 | real_image = data['image']
124 | person_clothes = real_image * person_clothes_edge
125 | pose = data['pose']
126 | size = data['label'].size()
127 | oneHot_size1 = (size[0], 25, size[2], size[3])
128 | densepose = torch.cuda.FloatTensor(torch.Size(oneHot_size1)).zero_()
129 | densepose = densepose.scatter_(1, data['densepose'].data.long().cuda(), 1.0)
130 | densepose_fore = data['densepose'] / 24
131 | face_mask = torch.FloatTensor((data['label'].cpu().numpy() == 1).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() == 12).astype(np.int))
132 | other_clothes_mask = torch.FloatTensor((data['label'].cpu().numpy() == 5).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() == 6).astype(np.int)) \
133 | + torch.FloatTensor((data['label'].cpu().numpy() == 8).astype(np.int)) + torch.FloatTensor((data['label'].cpu().numpy() == 9).astype(np.int)) \
134 | + torch.FloatTensor((data['label'].cpu().numpy() == 10).astype(np.int))
135 | face_img = face_mask * real_image
136 | other_clothes_img = other_clothes_mask * real_image
137 | preserve_mask = torch.cat([face_mask, other_clothes_mask], 1)
138 |
139 | concat_un = torch.cat([preserve_mask.cuda(), densepose, pose.cuda()], 1)
140 | flow_out_un = PB_warp_model(concat_un.cuda(), clothes_un.cuda(), pre_clothes_edge_un.cuda())
141 | warped_cloth_un, last_flow_un, cond_un_all, flow_un_all, delta_list_un, x_all_un, x_edge_all_un, delta_x_all_un, delta_y_all_un = flow_out_un
142 | warped_prod_edge_un = F.grid_sample(pre_clothes_edge_un.cuda(), last_flow_un.permute(0, 2, 3, 1),
143 | mode='bilinear', padding_mode='zeros')
144 |
145 | flow_out_sup = PB_warp_model(concat_un.cuda(), clothes.cuda(), pre_clothes_edge.cuda())
146 | warped_cloth_sup, last_flow_sup, cond_sup_all, flow_sup_all, delta_list_sup, x_all_sup, x_edge_all_sup, delta_x_all_sup, delta_y_all_sup = flow_out_sup
147 |
148 | arm_mask = torch.FloatTensor((data['label'].cpu().numpy() == 11).astype(np.float)) + torch.FloatTensor((data['label'].cpu().numpy() == 13).astype(np.float))
149 | hand_mask = torch.FloatTensor((data['densepose'].cpu().numpy() == 3).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 4).astype(np.int))
150 | dense_preserve_mask = torch.FloatTensor((data['densepose'].cpu().numpy() == 15).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 16).astype(np.int)) \
151 | + torch.FloatTensor((data['densepose'].cpu().numpy() == 17).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 18).astype(np.int)) \
152 | + torch.FloatTensor((data['densepose'].cpu().numpy() == 19).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 20).astype(np.int)) \
153 |                                   + torch.FloatTensor((data['densepose'].cpu().numpy() == 21).astype(np.int)) + torch.FloatTensor((data['densepose'].cpu().numpy() == 22).astype(np.int))
154 | hand_img = (arm_mask * hand_mask) * real_image
155 | dense_preserve_mask = dense_preserve_mask.cuda() * (1 - warped_prod_edge_un)
156 | preserve_region = face_img + other_clothes_img + hand_img
157 |
158 | gen_inputs_un = torch.cat([preserve_region.cuda(), warped_cloth_un, warped_prod_edge_un, dense_preserve_mask], 1)
159 | gen_outputs_un = PB_gen_model(gen_inputs_un)
160 | p_rendered_un, m_composite_un = torch.split(gen_outputs_un, [3, 1], 1)
161 | p_rendered_un = torch.tanh(p_rendered_un)
162 | m_composite_un = torch.sigmoid(m_composite_un)
163 | m_composite_un = m_composite_un * warped_prod_edge_un
164 | p_tryon_un = warped_cloth_un * m_composite_un + p_rendered_un * (1 - m_composite_un)
165 |
166 | flow_out = PF_warp_model(p_tryon_un.detach(), clothes.cuda(), pre_clothes_edge.cuda())
167 | warped_cloth, last_flow, cond_all, flow_all, delta_list, x_all, x_edge_all, delta_x_all, delta_y_all = flow_out
168 | warped_prod_edge = x_edge_all[4]
169 |
170 | epsilon = 0.001
171 | loss_smooth = sum([TVLoss(x) for x in delta_list])
172 | loss_all = 0
173 | loss_fea_sup_all = 0
174 | loss_flow_sup_all = 0
175 |
176 | l1_loss_batch = torch.abs(warped_cloth_sup.detach() - person_clothes.cuda())
177 | l1_loss_batch = l1_loss_batch.reshape(opt.batchSize, 3 * 256 * 192)
178 | l1_loss_batch = l1_loss_batch.sum(dim=1) / (3 * 256 * 192)
179 | l1_loss_batch_pred = torch.abs(warped_cloth.detach() - person_clothes.cuda())
180 | l1_loss_batch_pred = l1_loss_batch_pred.reshape(opt.batchSize, 3 * 256 * 192)
181 | l1_loss_batch_pred = l1_loss_batch_pred.sum(dim=1) / (3 * 256 * 192)
182 | weight = (l1_loss_batch < l1_loss_batch_pred).float()
183 | num_all = len(np.where(weight.cpu().numpy() > 0)[0])
184 | if num_all == 0:
185 | num_all = 1
186 |
187 | for num in range(5):
188 | cur_person_clothes = F.interpolate(person_clothes, scale_factor=0.5 ** (4 - num), mode='bilinear')
189 | cur_person_clothes_edge = F.interpolate(person_clothes_edge, scale_factor=0.5 ** (4 - num), mode='bilinear')
190 | loss_l1 = criterionL1(x_all[num], cur_person_clothes.cuda())
191 | loss_vgg = criterionVGG(x_all[num], cur_person_clothes.cuda())
192 | loss_edge = criterionL1(x_edge_all[num], cur_person_clothes_edge.cuda())
193 | b, c, h, w = delta_x_all[num].shape
194 | loss_flow_x = (delta_x_all[num].pow(2) + epsilon * epsilon).pow(0.45)
195 | loss_flow_x = torch.sum(loss_flow_x) / (b * c * h * w)
196 | loss_flow_y = (delta_y_all[num].pow(2) + epsilon * epsilon).pow(0.45)
197 | loss_flow_y = torch.sum(loss_flow_y) / (b * c * h * w)
198 | loss_second_smooth = loss_flow_x + loss_flow_y
199 | b1, c1, h1, w1 = cond_all[num].shape
200 | weight_all = weight.reshape(-1, 1, 1, 1).repeat(1, 256, h1, w1)
201 | cond_sup_loss = ((cond_sup_all[num].detach() - cond_all[num]) ** 2 * weight_all).sum() / (256 * h1 * w1 * num_all)
202 | loss_fea_sup_all = loss_fea_sup_all + (5 - num) * 0.04 * cond_sup_loss
203 | loss_all = loss_all + (num + 1) * loss_l1 + (num + 1) * 0.2 * loss_vgg + (num + 1) * 2 * loss_edge + (num + 1) * 6 * loss_second_smooth + (5 - num) * 0.04 * cond_sup_loss
204 | if num >= 2:
205 | b1, c1, h1, w1 = flow_all[num].shape
206 | weight_all = weight.reshape(-1, 1, 1).repeat(1, h1, w1)
207 | flow_sup_loss = (torch.norm(flow_sup_all[num].detach() - flow_all[num], p=2, dim=1) * weight_all).sum() / (h1 * w1 * num_all)
208 | loss_flow_sup_all = loss_flow_sup_all + (num + 1) * 1 * flow_sup_loss
209 | loss_all = loss_all + (num + 1) * 1 * flow_sup_loss
210 |
211 | loss_all = 0.01 * loss_smooth + loss_all
212 |
213 | # sum per device losses
214 | if opt.local_rank == 0:
215 | writer.add_scalar('loss_all', loss_all, step)
216 | writer.add_scalar('loss_fea_sup_all', loss_fea_sup_all, step)
217 | writer.add_scalar('loss_flow_sup_all', loss_flow_sup_all, step)
218 |
219 | if epoch < opt.niter:
220 | optimizer_part.zero_grad()
221 | loss_all.backward()
222 | optimizer_part.step()
223 | else:
224 | optimizer.zero_grad()
225 | loss_all.backward()
226 | optimizer.step()
227 |
228 | ############## Display results and errors ##########
229 | path = 'sample/' + opt.name
230 | os.makedirs(path, exist_ok=True)
231 | ### display output images
232 | if step % 1000 == 0:
233 | if opt.local_rank == 0:
234 | a = real_image.float().cuda()
235 | b = p_tryon_un.detach()
236 | c = clothes.cuda()
237 | d = person_clothes.cuda()
238 | e = torch.cat([person_clothes_edge.cuda(), person_clothes_edge.cuda(), person_clothes_edge.cuda()], 1)
239 | f = torch.cat([densepose_fore.cuda(), densepose_fore.cuda(), densepose_fore.cuda()], 1)
240 | g = warped_cloth
241 | h = torch.cat([warped_prod_edge, warped_prod_edge, warped_prod_edge], 1)
242 | combine = torch.cat([a[0], b[0], c[0], d[0], e[0], f[0], g[0], h[0]], 2).squeeze()
243 | cv_img = (combine.permute(1, 2, 0).detach().cpu().numpy() + 1) / 2
244 | writer.add_image('combine', (combine.data + 1) / 2.0, step)
245 | rgb = (cv_img * 255).astype(np.uint8)
246 | bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
247 | cv2.imwrite('sample/' + opt.name + '/' + str(step) + '.jpg', bgr)
248 |
249 | step += 1
250 | iter_end_time = time.time()
251 | iter_delta_time = iter_end_time - iter_start_time
252 | step_delta = (step_per_batch - step % step_per_batch) + step_per_batch * (opt.niter + opt.niter_decay - epoch)
253 | eta = iter_delta_time * step_delta
254 | eta = str(datetime.timedelta(seconds=int(eta)))
255 | time_stamp = datetime.datetime.now()
256 | now = time_stamp.strftime('%Y.%m.%d-%H:%M:%S')
257 | if step % 100 == 0:
258 | if opt.local_rank == 0:
259 | print('{}:{}:[step-{}]--[loss-{:.6f}]--[loss-{:.6f}]--[loss-{:.6f}]--[ETA-{}]'.format(now, epoch_iter, step, loss_all, loss_fea_sup_all, loss_flow_sup_all, eta))
260 |
261 | if epoch_iter >= dataset_size:
262 | break
263 |
264 | # end of epoch
265 | iter_end_time = time.time()
266 | if opt.local_rank == 0:
267 | print('End of epoch %d / %d \t Time Taken: %d sec' %
268 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
269 |
270 | ### save model for this epoch
271 | if epoch % opt.save_epoch_freq == 0:
272 | if opt.local_rank == 0:
273 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
274 | save_checkpoint(PF_warp_model.module,
275 | os.path.join(opt.checkpoints_dir, opt.name, 'PFAFN_warp_epoch_%03d.pth' % (epoch + 1)))
276 |
277 | if epoch > opt.niter:
278 | PF_warp_model.module.update_learning_rate(optimizer)
279 |
--------------------------------------------------------------------------------
/PF-AFN_train/util/__init__.py:
--------------------------------------------------------------------------------
1 | # util_init
--------------------------------------------------------------------------------
/PF-AFN_train/util/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/util/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/util/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/util/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/util/__pycache__/image_pool.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/util/__pycache__/image_pool.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/util/__pycache__/util.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/util/__pycache__/util.cpython-36.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/util/__pycache__/util.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PF-AFN_train/util/__pycache__/util.cpython-37.pyc
--------------------------------------------------------------------------------
/PF-AFN_train/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.autograd import Variable
4 | class ImagePool():
5 | def __init__(self, pool_size):
6 | self.pool_size = pool_size
7 | if self.pool_size > 0:
8 | self.num_imgs = 0
9 | self.images = []
10 |
11 | def query(self, images):
12 | if self.pool_size == 0:
13 | return images
14 | return_images = []
15 | for image in images.data:
16 | image = torch.unsqueeze(image, 0)
17 | if self.num_imgs < self.pool_size:
18 | self.num_imgs = self.num_imgs + 1
19 | self.images.append(image)
20 | return_images.append(image)
21 | else:
22 | p = random.uniform(0, 1)
23 | if p > 0.5:
24 | random_id = random.randint(0, self.pool_size-1)
25 | tmp = self.images[random_id].clone()
26 | self.images[random_id] = image
27 | return_images.append(tmp)
28 | else:
29 | return_images.append(image)
30 | return_images = Variable(torch.cat(return_images, 0))
31 | return return_images
32 |
--------------------------------------------------------------------------------
/PF-AFN_train/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import numpy as np
7 | import os
8 |
9 | # Converts a Tensor into a Numpy array
10 | # |imtype|: the desired type of the converted numpy array
11 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
12 | if isinstance(image_tensor, list):
13 | image_numpy = []
14 | for i in range(len(image_tensor)):
15 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
16 | return image_numpy
17 | image_numpy = image_tensor.cpu().float().numpy()
18 | #if normalize:
19 | # image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
20 | #else:
21 | # image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
22 | image_numpy = (image_numpy + 1) / 2.0
23 | image_numpy = np.clip(image_numpy, 0, 1)
24 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
25 | image_numpy = image_numpy[:,:,0]
26 |
27 | return image_numpy
28 |
29 | # Converts a one-hot tensor into a colorful label map
30 | def tensor2label(label_tensor, n_label, imtype=np.uint8):
31 | if n_label == 0:
32 | return tensor2im(label_tensor, imtype)
33 | label_tensor = label_tensor.cpu().float()
34 | if label_tensor.size()[0] > 1:
35 | label_tensor = label_tensor.max(0, keepdim=True)[1]
36 | label_tensor = Colorize(n_label)(label_tensor)
37 | #label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
38 | label_numpy = label_tensor.numpy()
39 | label_numpy = label_numpy / 255.0
40 |
41 | return label_numpy
42 |
43 | def save_image(image_numpy, image_path):
44 | image_pil = Image.fromarray(image_numpy)
45 | image_pil.save(image_path)
46 |
47 | def mkdirs(paths):
48 | if isinstance(paths, list) and not isinstance(paths, str):
49 | for path in paths:
50 | mkdir(path)
51 | else:
52 | mkdir(paths)
53 |
54 | def mkdir(path):
55 | if not os.path.exists(path):
56 | os.makedirs(path)
57 |
58 | ###############################################################################
59 | # Code from
60 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py
 61 | # Modified so it complies with the Cityscapes label map colors
62 | ###############################################################################
63 | def uint82bin(n, count=8):
64 | """returns the binary of integer n, count refers to amount of bits"""
65 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
66 |
67 | def labelcolormap(N):
68 | if N == 35: # cityscape
69 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
70 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
71 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
72 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
73 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
74 | dtype=np.uint8)
75 | else:
76 | cmap = np.zeros((N, 3), dtype=np.uint8)
77 | for i in range(N):
78 | r, g, b = 0, 0, 0
79 | id = i
80 | for j in range(7):
81 | str_id = uint82bin(id)
82 | r = r ^ (np.uint8(str_id[-1]) << (7-j))
83 | g = g ^ (np.uint8(str_id[-2]) << (7-j))
84 | b = b ^ (np.uint8(str_id[-3]) << (7-j))
85 | id = id >> 3
86 | cmap[i, 0] = r
87 | cmap[i, 1] = g
88 | cmap[i, 2] = b
89 | return cmap
90 |
91 | class Colorize(object):
92 | def __init__(self, n=35):
93 | self.cmap = labelcolormap(n)
94 | self.cmap = torch.from_numpy(self.cmap[:n])
95 |
96 | def __call__(self, gray_image):
97 | size = gray_image.size()
98 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
99 |
100 | for label in range(0, len(self.cmap)):
101 | mask = (label == gray_image[0]).cpu()
102 | color_image[0][mask] = self.cmap[label][0]
103 | color_image[1][mask] = self.cmap[label][1]
104 | color_image[2][mask] = self.cmap[label][2]
105 |
106 | return color_image
107 |
--------------------------------------------------------------------------------
/PFAFN_supp.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/PFAFN_supp.pdf
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Parser-Free Virtual Try-on via Distilling Appearance Flows, CVPR 2021
2 | Official code for CVPR 2021 paper 'Parser-Free Virtual Try-on via Distilling Appearance Flows'
3 |
4 |
5 | **The training code has been released.**
6 |
7 | 
8 |
9 | [[Paper]](https://openaccess.thecvf.com/content/CVPR2021/papers/Ge_Parser-Free_Virtual_Try-On_via_Distilling_Appearance_Flows_CVPR_2021_paper.pdf) [[Supplementary Material]](https://github.com/geyuying/PF-AFN/blob/main/PFAFN_supp.pdf)
10 |
11 | [[Checkpoints for Test]](https://drive.google.com/file/d/1_a0AiN8Y_d_9TNDhHIcRlERz3zptyYWV/view?usp=sharing)
12 |
13 | [[Training_Data]](https://drive.google.com/file/d/1Uc0DTTkSfCPXDhd4CMx2TQlzlC6bDolK/view?usp=sharing)
14 | [[Test_Data]](https://drive.google.com/file/d/1Y7uV0gomwWyxCvvH8TIbY7D9cTAUy6om/view?usp=sharing)
15 |
16 | [[VGG_Model]](https://drive.google.com/file/d/1Mw24L52FfOT9xXm3I1GL8btn7vttsHd9/view?usp=sharing)
17 |
18 | ## Our Environment
19 | anaconda3
20 |
21 | pytorch 1.1.0
22 |
23 | torchvision 0.3.0
24 |
25 | cuda 9.0
26 |
27 | cupy 6.0.0
28 |
29 | opencv-python 4.5.1
30 |
31 | 8 GTX1080 GPU for training; 1 GTX1080 GPU for test
32 |
33 | python 3.6
34 |
35 | ## Installation
36 | conda create -n tryon python=3.6
37 |
38 | source activate tryon or conda activate tryon
39 |
40 | conda install pytorch=1.1.0 torchvision=0.3.0 cudatoolkit=9.0 -c pytorch
41 |
42 | conda install cupy or pip install cupy==6.0.0
43 |
44 | pip install opencv-python
45 |
46 | git clone https://github.com/geyuying/PF-AFN.git
47 |
48 | cd PF-AFN
49 |
50 | ## Training on VITON dataset
51 | 1. cd PF-AFN_train
52 | 2. Download the VITON training set from [VITON_train](https://drive.google.com/file/d/1Uc0DTTkSfCPXDhd4CMx2TQlzlC6bDolK/view?usp=sharing) and put the folder "VITON_traindata" under the folder "dataset".
 53 | 3. Download the VGG_19 model from [VGG_Model](https://drive.google.com/file/d/1Mw24L52FfOT9xXm3I1GL8btn7vttsHd9/view?usp=sharing) and put "vgg19-dcbb9e9d.pth" under the folder "models".
 54 | 4. First train the parser-based network PBAFN. Run **scripts/train_PBAFN_stage1.sh**. After the parser-based warping module is trained, run **scripts/train_PBAFN_e2e.sh** (an illustrative launch command is sketched after this list).
55 | 5. After training the parser-based network PBAFN, train the parser-free network PFAFN. Run **scripts/train_PFAFN_stage1.sh**. After the parser-free warping module is trained, run **scripts/train_PFAFN_e2e.sh**.
 56 | 6. Following the above instructions with the provided training code, the [[trained PF-AFN]](https://drive.google.com/file/d/1Pz2kA65N4Ih9w6NFYBDmdtVdB-nrrdc3/view?usp=sharing) achieves FID 9.92 on the VITON test set with test_pairs.txt (available at https://github.com/minar09/cp-vton-plus/blob/master/data/test_pairs.txt).
57 |
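The training entry points initialize torch.distributed with the NCCL backend and read a local rank from the launcher, so each stage is meant to be started through the distributed launcher; the shell scripts under "scripts/" wrap commands of this kind. The exact option names are defined in options/train_options.py, which is not reproduced here, so the command below is only an illustrative sketch of launching the first parser-based stage on 8 GPUs:

python -m torch.distributed.launch --nproc_per_node=8 train_PBAFN_stage1.py --name PBAFN_stage1

The later stages (train_PBAFN_e2e.py, train_PFAFN_stage1.py, train_PFAFN_e2e.py) are launched the same way, with the warping and generator checkpoints produced by the previous stage passed in through their corresponding options (e.g. PBAFN_warp_checkpoint, PBAFN_gen_checkpoint).
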
58 | ## Run the demo
59 | 1. cd PF-AFN_test
 60 | 2. First, you need to download the checkpoints from [checkpoints](https://drive.google.com/file/d/1_a0AiN8Y_d_9TNDhHIcRlERz3zptyYWV/view?usp=sharing) and put the folder "PFAFN" under the folder "checkpoints". The folder "checkpoints/PFAFN" should contain "warp_model_final.pth" and "gen_model_final.pth".
 61 | 3. The "dataset" folder contains the demo images for testing: the "test_img" folder holds the person images, the "test_clothes" folder holds the clothes images, and the "test_edge" folder holds the edge maps extracted from the clothes images (we saved the extracted edges for convenience; an illustrative extraction sketch is shown after this list). 'demo.txt' records the test pairs.
62 | 4. During test, a person image, a clothes image and its extracted edge are fed into the network to generate the try-on image. **No human parsing results or human pose estimation results are needed for test.**
63 | 5. To test with the saved model, run **test.sh** and the results will be saved in the folder "results".
 64 | 6. **To reproduce our results from the saved model, your test environment should be the same as our test environment, especially the version of cupy.**
65 |
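The extraction script for these edge maps is not included in the repository; only the pre-extracted edges are shipped with the demo data. In the data-loading code the edge map is thresholded at 0.5 and multiplied with the clothes image, so it acts as a binary foreground mask of the garment. Below is a minimal sketch of one plausible way to regenerate such a mask with OpenCV; the helper name and file names are illustrative only, and this is not the authors' original extraction code.

```
import cv2

def extract_clothes_mask(src_path, dst_path):
    # Hypothetical helper (not part of this repository): builds a binary garment
    # mask from a clothes image photographed on a plain light background.
    img = cv2.imread(src_path)                     # BGR clothes image
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)   # single channel for Otsu
    # Inverted Otsu threshold: the garment (darker than the background)
    # becomes 255 and the background becomes 0.
    _, mask = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    cv2.imwrite(dst_path, mask)                    # save the mask as the "edge" image

extract_clothes_mask('dataset/test_clothes/clothes.jpg',
                     'dataset/test_edge/clothes.jpg')
```
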
66 | 
67 | ## Dataset
68 | 1. [VITON](https://github.com/xthan/VITON) contains a training set of 14,221 image pairs and a test set of 2,032 image pairs, each of which has a front-view woman photo and a top clothing image with the resolution 256 x 192. Our saved model is trained on the VITON training set and tested on the VITON test set.
69 | 2. To train from scratch on VITON training set, you can download [VITON_train](https://drive.google.com/file/d/1Uc0DTTkSfCPXDhd4CMx2TQlzlC6bDolK/view?usp=sharing).
70 | 3. To test our saved model on the complete VITON test set, you can download [VITON_test](https://drive.google.com/file/d/1Y7uV0gomwWyxCvvH8TIbY7D9cTAUy6om/view?usp=sharing).
71 |
72 | ## License
73 | The use of this code is RESTRICTED to non-commercial research and educational purposes.
74 |
75 | ## Acknowledgement
76 | Our code is based on the implementation of "Clothflow: A flow-based model for clothed person generation" (See the citation below), including the implementation of the feature pyramid networks (FPN) and the ResUnetGenerator, and the adaptation of the cascaded structure to predict the appearance flows. If you use our code, please also cite their work as below.
77 |
78 |
79 | ## Citation
80 | If our code is helpful to your work, please cite:
81 | ```
82 | @article{ge2021parser,
83 | title={Parser-Free Virtual Try-on via Distilling Appearance Flows},
84 | author={Ge, Yuying and Song, Yibing and Zhang, Ruimao and Ge, Chongjian and Liu, Wei and Luo, Ping},
85 | journal={arXiv preprint arXiv:2103.04559},
86 | year={2021}
87 | }
88 | ```
89 | ```
90 | @inproceedings{han2019clothflow,
91 | title={Clothflow: A flow-based model for clothed person generation},
92 | author={Han, Xintong and Hu, Xiaojun and Huang, Weilin and Scott, Matthew R},
93 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
94 | pages={10471--10480},
95 | year={2019}
96 | }
97 | ```
98 |
--------------------------------------------------------------------------------
/show/compare.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/show/compare.jpg
--------------------------------------------------------------------------------
/show/compare_both.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geyuying/PF-AFN/e2bb71e2f0b472479f386fa1b9c146ec3340f2e8/show/compare_both.jpg
--------------------------------------------------------------------------------