├── LICENSE.md
├── README.md
├── deepfashion
│   ├── data
│   │   ├── __init__.py
│   │   ├── base_data_loader.py
│   │   ├── base_dataset.py
│   │   ├── custom_dataset_data_loader.py
│   │   ├── data_loader.py
│   │   ├── image_folder.py
│   │   └── keypoint.py
│   ├── losses
│   │   ├── L1_plus_perceptualLoss.py
│   │   └── __init__.py
│   ├── models
│   │   ├── BiGraphGAN.py
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── model_variants.py
│   │   ├── models.py
│   │   ├── networks.py
│   │   └── test_model.py
│   ├── options
│   │   ├── __init__.py
│   │   ├── base_options.py
│   │   ├── test_options.py
│   │   └── train_options.py
│   ├── test.py
│   ├── test_deepfashion.sh
│   ├── test_deepfashion_pretrained.sh
│   ├── tool
│   │   ├── calPCKH_fashion.py
│   │   ├── calPCKH_market.py
│   │   ├── cmd.py
│   │   ├── compute_coordinates.py
│   │   ├── create_pairs_dataset.py
│   │   ├── crop_fashion.py
│   │   ├── crop_market.py
│   │   ├── generate_fashion_datasets.py
│   │   ├── generate_pose_map_fashion.py
│   │   ├── generate_pose_map_market.py
│   │   ├── getMetrics_fashion.py
│   │   ├── getMetrics_market.py
│   │   ├── inception_score.py
│   │   ├── pose_utils.py
│   │   ├── resize_fashion.py
│   │   └── rm_insnorm_running_vars.py
│   ├── train.py
│   ├── train_deepfashion.sh
│   └── util
│       ├── __init__.py
│       ├── get_data.py
│       ├── html.py
│       ├── image_pool.py
│       ├── png.py
│       ├── util.py
│       └── visualizer.py
├── facial
│   ├── README.md
│   ├── data  (same contents as deepfashion/data, plus aligned.py)
│   ├── losses  (same contents as deepfashion/losses)
│   ├── models  (same contents as deepfashion/models)
│   ├── options  (same contents as deepfashion/options)
│   ├── scripts
│   │   ├── download_bigraphgan_model.sh
│   │   └── download_bigraphgan_result.sh
│   ├── test.py
│   ├── test_facial.sh
│   ├── tool  (same contents as deepfashion/tool)
│   ├── train.py
│   ├── train_facial.sh
│   └── util  (same contents as deepfashion/util)
├── imgs
│   ├── face_results.jpeg
│   ├── fashion_results.jpg
│   ├── market_results.jpg
│   ├── method.jpg
│   └── motivation.jpg
├── market_1501
│   ├── data  (same contents as deepfashion/data)
│   ├── losses  (same contents as deepfashion/losses)
│   ├── models  (same contents as deepfashion/models)
│   ├── options  (same contents as deepfashion/options)
│   ├── test.py
│   ├── test_market.sh
│   ├── test_market_pretrained.sh
│   ├── tool  (same contents as deepfashion/tool)
│   ├── train.py
│   ├── train_market.sh
│   └── util  (same contents as deepfashion/util)
└── scripts
    ├── download_bigraphgan_model.sh
    └── download_bigraphgan_result.sh

--------------------------------------------------------------------------------
/deepfashion/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/deepfashion/data/__init__.py

--------------------------------------------------------------------------------
/deepfashion/data/base_data_loader.py:
--------------------------------------------------------------------------------

class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt

    def load_data(self):  # original was missing `self`, which raised a TypeError when called on an instance
        return None

--------------------------------------------------------------------------------
/deepfashion/data/base_dataset.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass

def get_transform(opt):
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        # transforms.Scale is deprecated; Resize is the drop-in replacement
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def __scale_width(img, target_width):
    ow, oh = img.size
    if ow == target_width:
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), Image.BICUBIC)

--------------------------------------------------------------------------------
/deepfashion/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
    dataset = None

    if opt.dataset_mode == 'keypoint':
        from data.keypoint import KeyDataset
        dataset = KeyDataset()
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset


class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            if i >= self.opt.max_dataset_size:
                break
            yield data

--------------------------------------------------------------------------------
/deepfashion/data/data_loader.py:
--------------------------------------------------------------------------------

def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader

--------------------------------------------------------------------------------
/deepfashion/data/image_folder.py:
--------------------------------------------------------------------------------
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)

--------------------------------------------------------------------------------
/deepfashion/losses/L1_plus_perceptualLoss.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models

class L1_plus_perceptualLoss(nn.Module):
    def __init__(self, lambda_L1, lambda_perceptual, perceptual_layers, gpu_ids, percep_is_l1):
        super(L1_plus_perceptualLoss, self).__init__()

        self.lambda_L1 = lambda_L1
        self.lambda_perceptual = lambda_perceptual
        self.gpu_ids = gpu_ids

        self.percep_is_l1 = percep_is_l1

        # keep the first `perceptual_layers` + 1 layers of VGG-19 as the feature extractor
        vgg = models.vgg19(pretrained=True).features
        self.vgg_submodel = nn.Sequential()
        for i, layer in enumerate(list(vgg)):
            self.vgg_submodel.add_module(str(i), layer)
            if i == perceptual_layers:
                break
        self.vgg_submodel = torch.nn.DataParallel(self.vgg_submodel, device_ids=gpu_ids).cuda()

        print(self.vgg_submodel)

    def forward(self, inputs, targets):
        if self.lambda_L1 == 0 and self.lambda_perceptual == 0:
            return torch.zeros(1).cuda(), torch.zeros(1).cuda(), torch.zeros(1).cuda()
        # normal L1
        loss_l1 = F.l1_loss(inputs, targets) * self.lambda_L1

        # perceptual L1: re-normalize from [-1, 1] to ImageNet statistics
        # (torch.tensor(...).view(...) replaces the deprecated FloatTensor/resize idiom)
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()

        fake_p2_norm = (inputs + 1) / 2   # [-1, 1] => [0, 1]
        fake_p2_norm = (fake_p2_norm - mean) / std

        input_p2_norm = (targets + 1) / 2  # [-1, 1] => [0, 1]
        input_p2_norm = (input_p2_norm - mean) / std

        fake_p2_norm = self.vgg_submodel(fake_p2_norm)
        input_p2_norm = self.vgg_submodel(input_p2_norm)
        input_p2_norm_no_grad = input_p2_norm.detach()

        if self.percep_is_l1 == 1:
            # use l1 for perceptual loss
            loss_perceptual = F.l1_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual
        else:
            # use l2 for perceptual loss
            loss_perceptual = F.mse_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual

        loss = loss_l1 + loss_perceptual

        return loss, loss_l1, loss_perceptual

--------------------------------------------------------------------------------
/deepfashion/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/deepfashion/losses/__init__.py
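For reference, a minimal usage sketch of the combined loss above. The weights mirror the
defaults in train_options.py (lambda_A/lambda_B = 10.0, perceptual_layers = 3, percep_is_l1 = 1)
and the 256x176 resolution matches the DeepFashion pose maps; the tensors themselves are
illustrative stand-ins. Running it needs a CUDA device and torchvision's pretrained VGG-19 weights.

import torch
from losses.L1_plus_perceptualLoss import L1_plus_perceptualLoss

criterion = L1_plus_perceptualLoss(lambda_L1=10.0, lambda_perceptual=10.0,
                                   perceptual_layers=3, gpu_ids=[0], percep_is_l1=1)
# stand-ins for a generator output and its ground truth, both in [-1, 1]
fake = (torch.rand(1, 3, 256, 176).cuda() * 2 - 1).requires_grad_()
real = torch.rand(1, 3, 256, 176).cuda() * 2 - 1
loss, loss_l1, loss_perceptual = criterion(fake, real)
loss.backward()  # gradients flow to `fake` only; the VGG targets are detached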
--------------------------------------------------------------------------------
/deepfashion/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/deepfashion/models/__init__.py

--------------------------------------------------------------------------------
/deepfashion/models/base_model.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn


class BaseModel(nn.Module):

    def __init__(self):
        super(BaseModel, self).__init__()

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)

--------------------------------------------------------------------------------
/deepfashion/models/models.py:
--------------------------------------------------------------------------------

def create_model(opt):
    model = None
    print(opt.model)

    if opt.model == 'BiGraphGAN':
        assert opt.dataset_mode == 'keypoint'
        from .BiGraphGAN import TransferModel
        model = TransferModel()
    else:
        raise ValueError("Model [%s] not recognized." % opt.model)
    model.initialize(opt)
    print("model [%s] was created" % (model.name()))
    return model

--------------------------------------------------------------------------------
/deepfashion/models/test_model.py:
--------------------------------------------------------------------------------
from torch.autograd import Variable
from collections import OrderedDict
import util.util as util
from .base_model import BaseModel
from . import networks


class TestModel(BaseModel):
    def name(self):
        return 'TestModel'

    def initialize(self, opt):
        assert(not opt.isTrain)
        BaseModel.initialize(self, opt)
        self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)

        self.netG = networks.define_G(opt.input_nc, opt.output_nc,
                                      opt.ngf, opt.which_model_netG,
                                      opt.norm, not opt.no_dropout,
                                      opt.init_type,
                                      self.gpu_ids)
        which_epoch = opt.which_epoch
        self.load_network(self.netG, 'G', which_epoch)

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG)
        print('-----------------------------------------------')

    def set_input(self, input):
        # we need to use single_dataset mode
        input_A = input['A']
        self.input_A.resize_(input_A.size()).copy_(input_A)
        self.image_paths = input['A_paths']

    def test(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG(self.real_A)

    # get image paths
    def get_image_paths(self):
        return self.image_paths

    def get_current_visuals(self):
        real_A = util.tensor2im(self.real_A.data)
        fake_B = util.tensor2im(self.fake_B.data)
        return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])

--------------------------------------------------------------------------------
/deepfashion/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/deepfashion/options/__init__.py
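base_options.py itself is not reproduced in this listing (only its compiled artifacts were),
so the sketch below is a hypothetical, minimal shape of what TestOptions, TrainOptions, and
test.py assume it provides: a self.parser ArgumentParser, an initialize() hook that the
subclasses extend, and a parse() method that copies self.isTrain onto the parsed options.

# Hypothetical minimal BaseOptions, inferred from how it is used elsewhere in this repo:
import argparse

class BaseOptions:
    def initialize(self):
        self.parser = argparse.ArgumentParser()
        # common flags (--dataroot, --name, --model, --gpu_ids, ...) are added here

    def parse(self):
        self.initialize()                # the subclass initialize() runs and adds its own flags
        self.opt = self.parser.parse_args()
        self.opt.isTrain = self.isTrain  # set by TrainOptions / TestOptions
        return self.opt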
--------------------------------------------------------------------------------
/deepfashion/options/test_options.py:
--------------------------------------------------------------------------------
from .base_options import BaseOptions


class TestOptions(BaseOptions):
    def initialize(self):
        BaseOptions.initialize(self)
        self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
        self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
        self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
        self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--how_many', type=int, default=200, help='how many test images to run')

        self.isTrain = False

--------------------------------------------------------------------------------
/deepfashion/options/train_options.py:
--------------------------------------------------------------------------------
from .base_options import BaseOptions


class TrainOptions(BaseOptions):
    def initialize(self):
        BaseOptions.initialize(self)
        self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
        self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
        self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
        self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
        self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
        self.parser.add_argument('--save_epoch_freq', type=int, default=20, help='frequency of saving checkpoints at the end of epochs')
        self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
        self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; models are saved at <epoch_count>, <epoch_count>+<save_epoch_freq>, ...')
        self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
        self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
        self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
        self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
        self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
        self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least-squares GAN; if set, use vanilla GAN')
        self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for L1 loss')
        self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for perceptual L1 loss')
        self.parser.add_argument('--lambda_GAN', type=float, default=5.0, help='weight of GAN loss')

        self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
        self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
        self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
        self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

        self.parser.add_argument('--L1_type', type=str, default='origin', help='which kind of L1 loss to use (origin|l1_plus_perL1)')
        self.parser.add_argument('--perceptual_layers', type=int, default=3, help='index of vgg layer for extracting perceptual features.')
        self.parser.add_argument('--percep_is_l1', type=int, default=1, help='type of perceptual loss: l1 or l2')
        self.parser.add_argument('--no_dropout_D', action='store_true', help='no dropout for the discriminator')
        self.parser.add_argument('--DG_ratio', type=int, default=1, help='how many D updates per G update')

        self.isTrain = True

--------------------------------------------------------------------------------
/deepfashion/test.py:
--------------------------------------------------------------------------------
import os
import time
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from util import html

opt = TestOptions().parse()
opt.nThreads = 1   # test code only supports nThreads = 1
opt.batchSize = 1  # test code only supports batchSize = 1
opt.serial_batches = True  # no shuffle
opt.no_flip = True  # no flip

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
model = create_model(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))

webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))

print(opt.how_many)
print(len(dataset))

model = model.eval()
print(model.training)

opt.how_many = 999999  # effectively: run the whole test set
# test
for i, data in enumerate(dataset):
    print(' process %d/%d img ..' % (i, opt.how_many))
    if i >= opt.how_many:
        break
    model.set_input(data)
    startTime = time.time()
    model.test()
    endTime = time.time()
    print(endTime - startTime)
    visuals = model.get_current_visuals()
    img_path = model.get_image_paths()
    img_path = [img_path]
    print(img_path)
    visualizer.save_images(webpage, visuals, img_path)

webpage.save()

--------------------------------------------------------------------------------
/deepfashion/test_deepfashion.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES=0;
python test.py --dataroot ./SelectionGAN/person_transfer/datasets/fashion_data/ --name fashion_exp --model BiGraphGAN --phase test --dataset_mode keypoint --norm instance --batchSize 1 --resize_or_crop no --gpu_ids 0 --BP_input_nc 18 --no_flip --which_model_netG Graph --checkpoints_dir ./checkpoints --pairLst ./SelectionGAN/person_transfer/datasets/fashion_data/fasion-resize-pairs-test.csv --which_epoch 700 --results_dir ./results --display_id 0;

--------------------------------------------------------------------------------
/deepfashion/test_deepfashion_pretrained.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES=0;
python test.py --dataroot ./SelectionGAN/person_transfer/datasets/fashion_data/ --name deepfashion_pretrained --model BiGraphGAN --phase test --dataset_mode keypoint --norm instance --batchSize 1 --resize_or_crop no --gpu_ids 0 --BP_input_nc 18 --no_flip --which_model_netG Graph --checkpoints_dir ./BiGraphGAN/scripts/checkpoints --pairLst ./SelectionGAN/person_transfer/datasets/fashion_data/fasion-resize-pairs-test.csv --which_epoch latest --results_dir ./results --display_id 0;
--------------------------------------------------------------------------------
/deepfashion/tool/calPCKH_fashion.py:
--------------------------------------------------------------------------------
import pandas as pd
import json

MISSING_VALUE = -1

PARTS_SEL = [0, 1, 14, 15, 16, 17]

target_annotation = './fashion_data/fasion-resize-annotation-test.csv'
pred_annotation = './results/fashion_PATN/pckh.csv'


'''
hz: head size
alpha: norm factor
px, py: predict coords
tx, ty: target coords
'''
def isRight(px, py, tx, ty, hz, alpha):
    if px == -1 or py == -1 or tx == -1 or ty == -1:
        return 0

    if abs(px - tx) < hz[0] * alpha and abs(py - ty) < hz[1] * alpha:
        return 1
    else:
        return 0


def how_many_right_seq(px, py, tx, ty, hz, alpha):
    nRight = 0
    for i in range(len(px)):
        nRight = nRight + isRight(px[i], py[i], tx[i], ty[i], hz, alpha)

    return nRight


def ValidPoints(tx):
    nValid = 0
    for item in tx:
        if item != -1:
            nValid = nValid + 1
    return nValid


def get_head_wh(x_coords, y_coords):
    final_w, final_h = -1, -1
    component_count = 0
    save_componets = []
    for component in PARTS_SEL:
        if x_coords[component] == MISSING_VALUE or y_coords[component] == MISSING_VALUE:
            continue
        else:
            component_count += 1
            save_componets.append([x_coords[component], y_coords[component]])
    if component_count >= 2:
        x_cords = []
        y_cords = []
        for component in save_componets:
            x_cords.append(component[0])
            y_cords.append(component[1])
        xmin = min(x_cords)
        xmax = max(x_cords)
        ymin = min(y_cords)
        ymax = max(y_cords)
        final_w = xmax - xmin
        final_h = ymax - ymin
    return final_w, final_h


tAnno = pd.read_csv(target_annotation, sep=':')
pAnno = pd.read_csv(pred_annotation, sep=':')

pRows = pAnno.shape[0]

nAll = 0
nCorrect = 0
alpha = 0.5
for i in range(pRows):
    pValues = pAnno.iloc[i].values
    pname = pValues[0]
    pycords = json.loads(pValues[1])  # list of numbers
    pxcords = json.loads(pValues[2])

    if '_vis' in pname:
        tname = pname[:-8]
    else:
        tname = pname[:-4]

    if '___' in tname:
        tname = tname.split('___')[1]
    else:
        tname = tname.split('jpg_')[1]

    print(tname)
    tValues = tAnno.query('name == "%s"' % (tname)).values[0]
    tycords = json.loads(tValues[1])  # list of numbers
    txcords = json.loads(tValues[2])

    xBox, yBox = get_head_wh(txcords, tycords)
    if xBox == -1 or yBox == -1:
        continue

    head_size = (xBox, yBox)
    nAll = nAll + ValidPoints(tycords)
    nCorrect = nCorrect + how_many_right_seq(pxcords, pycords, txcords, tycords, head_size, alpha)

print('%d/%d %f' % (nCorrect, nAll, nCorrect * 1.0 / nAll))
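For intuition, the PCKh rule implemented by isRight() above counts a predicted joint as
correct when it lands within alpha times the head-box size of the target joint, per axis.
A tiny illustrative check (the numbers are invented for the example):

head = (20, 30)  # head-box width and height in pixels
# with alpha = 0.5 the tolerance is (10, 15) pixels in (x, y)
assert isRight(px=105, py=210, tx=100, ty=200, hz=head, alpha=0.5) == 1  # |dx|=5 < 10, |dy|=10 < 15
assert isRight(px=115, py=210, tx=100, ty=200, hz=head, alpha=0.5) == 0  # |dx|=15 >= 10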
--------------------------------------------------------------------------------
/deepfashion/tool/calPCKH_market.py:
--------------------------------------------------------------------------------
Identical to calPCKH_fashion.py above (the same isRight / how_many_right_seq / ValidPoints /
get_head_wh helpers and the same evaluation loop with alpha = 0.5), except for the annotation
paths, which point at the Market-1501 data:

# fix the PATH
target_annotation = './market_data/market-annotation-test.csv'
pred_annotation = './results/market_PATN/pckh.csv'  # was '/results/...', a typo for a relative path

--------------------------------------------------------------------------------
/deepfashion/tool/create_pairs_dataset.py:
--------------------------------------------------------------------------------
import pandas as pd
import pose_utils
from itertools import permutations

LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri',
          'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear']

MISSING_VALUE = -1

def give_name_to_keypoints(array):
    res = {}
    for i, name in enumerate(LABELS):
        if array[i][0] != MISSING_VALUE and array[i][1] != MISSING_VALUE:
            res[name] = array[i][::-1]
    return res


def pose_check_valid(kp_array):
    kp = give_name_to_keypoints(kp_array)
    return check_keypoints_present(kp, ['Rhip', 'Lhip', 'Lsho', 'Rsho'])


def check_keypoints_present(kp, kp_names):
    result = True
    for name in kp_names:
        result = result and (name in kp)
    return result


def filter_not_valid(df_keypoints):
    def check_valid(x):
        kp_array = pose_utils.load_pose_cords_from_strings(x['keypoints_y'], x['keypoints_x'])
        distractor = x['name'].startswith('-1') or x['name'].startswith('0000')
        return pose_check_valid(kp_array) and not distractor
    return df_keypoints[df_keypoints.apply(check_valid, axis=1)].copy()


def make_pairs(df):
    persons = df.apply(lambda x: '_'.join(x['name'].split('_')[0:1]), axis=1)
    df['person'] = persons
    fr, to = [], []
    for person in pd.unique(persons):
        # wrap zip in list(): under Python 3, zip returns an iterator,
        # which has no len() and cannot be indexed
        pairs = list(zip(*list(permutations(df[df['person'] == person]['name'], 2))))
        if len(pairs) != 0:
            fr += list(pairs[0])
            to += list(pairs[1])
    pair_df = pd.DataFrame(index=range(len(fr)))
    pair_df['from'] = fr
    pair_df['to'] = to
    return pair_df


if __name__ == "__main__":
    images_for_test = 12000

    annotations_file_train = './market_data/market-annotation-test.csv'
    pairs_file_train = './market_data/example_market-pairs-train.csv'

    df_keypoints = pd.read_csv(annotations_file_train, sep=':')
    df = filter_not_valid(df_keypoints)
    print('Compute pair dataset for train...')
    pairs_df_train = make_pairs(df)
    print('Number of pairs: %s' % len(pairs_df_train))
    pairs_df_train.to_csv(pairs_file_train, index=False)

    annotations_file_test = './market_data/market-annotation-test.csv'
    pairs_file_test = './market_data/example_market-pairs-test.csv'

    print('Compute pair dataset for test...')
    df_keypoints = pd.read_csv(annotations_file_test, sep=':')
    df = filter_not_valid(df_keypoints)
    pairs_df_test = make_pairs(df)
    pairs_df_test = pairs_df_test.sample(n=min(images_for_test, pairs_df_test.shape[0]), replace=False, random_state=0)
    print('Number of pairs: %s' % len(pairs_df_test))
    pairs_df_test.to_csv(pairs_file_test, index=False)

--------------------------------------------------------------------------------
/deepfashion/tool/crop_fashion.py:
--------------------------------------------------------------------------------
from PIL import Image
import os

img_dir = './results/fashion_PATN_test/test_latest/images'
save_dir = './results/fashion_PATN_test/test_latest/images_crop'

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

cnt = 0

for item in os.listdir(img_dir):
    if not item.endswith('.jpg') and not item.endswith('.png'):
        continue
    cnt = cnt + 1
    print('%d/8570 ...' % (cnt))
    img = Image.open(os.path.join(img_dir, item))
    # keep only the last 176x256 panel of the stitched result sheet
    imgcrop = img.crop((704, 0, 880, 256))
    imgcrop.save(os.path.join(save_dir, item))
/deepfashion/tool/crop_fashion.py:
--------------------------------------------------------------------------------
from PIL import Image
import os

img_dir = './results/fashion_PATN_test/test_latest/images'
save_dir = './results/fashion_PATN_test/test_latest/images_crop'

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

cnt = 0

for item in os.listdir(img_dir):
    if not item.endswith('.jpg') and not item.endswith('.png'):
        continue
    cnt = cnt + 1
    print('%d/8570 ...' % (cnt))
    img = Image.open(os.path.join(img_dir, item))
    imgcrop = img.crop((704, 0, 880, 256))
    imgcrop.save(os.path.join(save_dir, item))
--------------------------------------------------------------------------------
/deepfashion/tool/crop_market.py:
--------------------------------------------------------------------------------
from PIL import Image
import os

img_dir = './results/market_PATN_test/test_latest/images'
save_dir = './results/market_PATN_test/test_latest/images_crop'

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

cnt = 0
for item in os.listdir(img_dir):
    if not item.endswith('.jpg') and not item.endswith('.png'):
        continue
    cnt = cnt + 1
    print('%d/12000 ...' % (cnt))
    img = Image.open(os.path.join(img_dir, item))
    # for 5 split
    imgcrop = img.crop((256, 0, 320, 128))
    imgcrop.save(os.path.join(save_dir, item))
--------------------------------------------------------------------------------
/deepfashion/tool/generate_fashion_datasets.py:
--------------------------------------------------------------------------------
import os
import shutil
from PIL import Image

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    new_root = './fashion_data'
    if not os.path.exists(new_root):
        os.mkdir(new_root)

    train_root = './fashion_data/train'
    if not os.path.exists(train_root):
        os.mkdir(train_root)

    test_root = './fashion_data/test'
    if not os.path.exists(test_root):
        os.mkdir(test_root)

    train_images = []
    train_f = open('./fashion_data/train.lst', 'r')
    for lines in train_f:
        lines = lines.strip()
        if lines.endswith('.jpg'):
            train_images.append(lines)

    test_images = []
    test_f = open('./fashion_data/test.lst', 'r')
    for lines in test_f:
        lines = lines.strip()
        if lines.endswith('.jpg'):
            test_images.append(lines)

    print(train_images, test_images)

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                path_names = path.split('/')
                # path_names[2] = path_names[2].replace('_', '')
                path_names[3] = path_names[3].replace('_', '')
                path_names[4] = path_names[4].split('_')[0] + "_" + "".join(path_names[4].split('_')[1:])
                path_names = "".join(path_names)
                # new_path = os.path.join(root, path_names)
                new_path = path_names  # flattened file name; matched against the train.lst / test.lst entries
                img = Image.open(path)
                imgcrop = img.crop((40, 0, 216, 256))
                if new_path in train_images:
                    imgcrop.save(os.path.join(train_root, path_names))
                elif new_path in test_images:
                    imgcrop.save(os.path.join(test_root, path_names))

make_dataset('./fashion')
--------------------------------------------------------------------------------
/deepfashion/tool/generate_pose_map_fashion.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import json
import os

MISSING_VALUE = -1
# fix PATH
img_dir = 'fashion_data'  # raw image path
annotations_file = 'fashion_data/fasion-resize-annotation-train.csv'  # pose annotation path
save_path = 'fashion_data/trainK'  # path to store pose maps

def load_pose_cords_from_strings(y_str, x_str):
    y_cords = json.loads(y_str)
    x_cords = json.loads(x_str)
    return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1)

def cords_to_map(cords, img_size, sigma=6):
    # float32, not uint8: the Gaussian values lie in (0, 1] and would be
    # truncated to zero almost everywhere by an integer dtype
    result = np.zeros(img_size + cords.shape[0:1], dtype='float32')
    for i, point in enumerate(cords):
        if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE:
            continue
        xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0]))
        result[..., i] = np.exp(-((yy - point[0]) ** 2 + (xx - point[1]) ** 2) / (2 * sigma ** 2))
        # result[..., i] = np.where(((yy - point[0]) ** 2 + (xx - point[1]) ** 2) < (sigma ** 2), 1, 0)
    return result

def compute_pose(image_dir, annotations_file, savePath, sigma=6):
    annotations_file = pd.read_csv(annotations_file, sep=':')
    annotations_file = annotations_file.set_index('name')
    image_size = (256, 176)
    cnt = len(annotations_file)
    for i in range(cnt):
        print('processing %d / %d ...' % (i, cnt))
        row = annotations_file.iloc[i]
        name = row.name
        print(savePath, name)
        file_name = os.path.join(savePath, name + '.npy')
        kp_array = load_pose_cords_from_strings(row.keypoints_y, row.keypoints_x)
        pose = cords_to_map(kp_array, image_size, sigma)
        np.save(file_name, pose)
        # input()

compute_pose(img_dir, annotations_file, save_path)
--------------------------------------------------------------------------------
/deepfashion/tool/generate_pose_map_market.py:
--------------------------------------------------------------------------------
import numpy as np
import pandas as pd
import json
import os

MISSING_VALUE = -1

img_dir = 'market_data/train'  # raw image path
annotations_file = 'market_data/market-annotation-train.csv'  # pose annotation path
save_path = 'market_data/trainK'  # path to store pose maps

def load_pose_cords_from_strings(y_str, x_str):
    y_cords = json.loads(y_str)
    x_cords = json.loads(x_str)
    return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1)

def cords_to_map(cords, img_size, sigma=6):
    result = np.zeros(img_size + cords.shape[0:1], dtype='float32')
    for i, point in enumerate(cords):
        if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE:
            continue
        xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0]))
        result[..., i] = np.exp(-((yy - point[0]) ** 2 + (xx - point[1]) ** 2) / (2 * sigma ** 2))
    return result

def compute_pose(image_dir, annotations_file, savePath):
    annotations_file = pd.read_csv(annotations_file, sep=':')
    annotations_file = annotations_file.set_index('name')
    image_size = (128, 64)
    cnt = len(annotations_file)
    for i in range(cnt):
        print('processing %d / %d ...' % (i, cnt))
        row = annotations_file.iloc[i]
        name = row.name
        print(savePath, name)
        file_name = os.path.join(savePath, name + '.npy')
        kp_array = load_pose_cords_from_strings(row.keypoints_y, row.keypoints_x)
        pose = cords_to_map(kp_array, image_size)
        np.save(file_name, pose)

compute_pose(img_dir, annotations_file, save_path)
--------------------------------------------------------------------------------
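cords_to_map converts K annotated joints into a K-channel heatmap: each channel holds a 2-D Gaussian of width sigma centred on one joint, and joints marked -1 stay all-zero. A quick sketch with hypothetical coordinates:

    import numpy as np
    cords = np.array([[64, 32], [-1, -1]])            # one visible joint, one missing
    pose = cords_to_map(cords, (128, 64), sigma=6)
    print(pose.shape)                                 # (128, 64, 2)
    print(pose[64, 32, 0], pose[..., 1].max())        # 1.0 at the peak, 0.0 for the missing joint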
/deepfashion/tool/inception_score.py:
--------------------------------------------------------------------------------
# Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf
import glob
import scipy.misc
import math

# expanduser() is needed: os.path.exists() does not expand a literal '~'
MODEL_DIR = os.path.expanduser('~/models')
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
softmax = None

# Call this function with a list of images. Each element should be a
# numpy array with values ranging from 0 to 255.
def get_inception_score(images, splits=10):
    # assert(type(images) == list)
    assert(type(images[0]) == np.ndarray)
    assert(len(images[0].shape) == 3)
    assert(np.max(images[0]) > 10)
    assert(np.min(images[0]) >= 0.0)
    inps = []
    for img in images:
        img = img.astype(np.float32)
        inps.append(np.expand_dims(img, 0))
    bs = 10
    with tf.Session() as sess:
        preds = []
        n_batches = int(math.ceil(float(len(inps)) / float(bs)))
        for i in range(n_batches):
            sys.stdout.write(".")
            sys.stdout.flush()
            inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
            inp = np.concatenate(inp, 0)
            pred = sess.run(softmax, {'ExpandDims:0': inp})
            preds.append(pred)
        preds = np.concatenate(preds, 0)
        scores = []
        for i in range(splits):
            part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return np.mean(scores), np.std(scores)

# This function is called automatically.
def _init_inception():
    global softmax
    if not os.path.exists(MODEL_DIR):
        os.makedirs(MODEL_DIR)
    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(MODEL_DIR, filename)
    if not os.path.exists(filepath):
        def _progress(count, block_size, total_size):
            sys.stdout.write('\r>> Downloading %s %.1f%%' % (
                filename, float(count * block_size) / float(total_size) * 100.0))
            sys.stdout.flush()
        filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
        print()
        statinfo = os.stat(filepath)
        print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
    tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
    with tf.gfile.FastGFile(os.path.join(
            MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(graph_def, name='')
    # Works with an arbitrary minibatch size.
    with tf.Session() as sess:
        pool3 = sess.graph.get_tensor_by_name('pool_3:0')
        ops = pool3.graph.get_operations()
        for op_idx, op in enumerate(ops):
            for o in op.outputs:
                shape = o.get_shape()
                shape = [s.value for s in shape]
                new_shape = []
                for j, s in enumerate(shape):
                    if s == 1 and j == 0:
                        new_shape.append(None)
                    else:
                        new_shape.append(s)
                o._shape = tf.TensorShape(new_shape)
        w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
        logits = tf.matmul(tf.squeeze(pool3), w)
        softmax = tf.nn.softmax(logits)

if softmax is None:
    _init_inception()
--------------------------------------------------------------------------------
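The Inception Score is the exponential of the average KL divergence between each image's class posterior p(y|x) and the marginal p(y); higher means sharper and more diverse samples. A usage sketch with random placeholder images (TF1-style session code, matching the file above):

    import numpy as np
    fakes = [np.random.randint(0, 256, (128, 64, 3)).astype(np.float32) for _ in range(500)]
    mean_is, std_is = get_inception_score(fakes, splits=10)
    print('IS: %.3f +/- %.3f' % (mean_is, std_is))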
/deepfashion/tool/resize_fashion.py:
--------------------------------------------------------------------------------
from skimage.io import imread, imsave
from skimage.transform import resize
import os
import numpy as np
import pandas as pd
import json

def resize_dataset(folder, new_folder, new_size=(256, 176), crop_bord=40):
    if not os.path.exists(new_folder):
        os.makedirs(new_folder)
    for name in os.listdir(folder):
        old_name = os.path.join(folder, name)
        new_name = os.path.join(new_folder, name)

        img = imread(old_name)
        if crop_bord == 0:
            pass
        else:
            img = img[:, crop_bord:-crop_bord]

        img = resize(img, new_size, preserve_range=True).astype(np.uint8)

        imsave(new_name, img)

def resize_annotations(name, new_name, new_size=(256, 176), old_size=(256, 256), crop_bord=40):
    df = pd.read_csv(name, sep=':')

    ratio_y = new_size[0] / float(old_size[0])
    ratio_x = new_size[1] / float(old_size[1] - 2 * crop_bord)

    def modify(values, ratio, crop):
        val = np.array(json.loads(values))
        mask = val == -1
        val = ((val - crop) * ratio).astype(int)
        val[mask] = -1
        return str(list(val))

    df['keypoints_y'] = df.apply(lambda row: modify(row['keypoints_y'], ratio_y, 0), axis=1)
    df['keypoints_x'] = df.apply(lambda row: modify(row['keypoints_x'], ratio_x, crop_bord), axis=1)

    df.to_csv(new_name, sep=':', index=False)


root_dir = 'xxx'  # fix the PATH
resize_dataset(os.path.join(root_dir, 'test'), os.path.join(root_dir, 'fashion_resize/test'))
resize_annotations(os.path.join(root_dir, 'fasion-annotation-test.csv'), os.path.join(root_dir, 'fasion-resize-annotation-test.csv'))

resize_dataset(os.path.join(root_dir, 'train'), os.path.join(root_dir, 'fashion_resize/train'))
resize_annotations(os.path.join(root_dir, 'fasion-annotation-train.csv'), os.path.join(root_dir, 'fasion-resize-annotation-train.csv'))
--------------------------------------------------------------------------------
/deepfashion/tool/rm_insnorm_running_vars.py:
--------------------------------------------------------------------------------
import torch

ckp_path = './checkpoints/fashion_PATN/latest_net_netG.pth'
save_path = './checkpoints/fashion_PATN_v1.0/latest_net_netG.pth'
states_dict = torch.load(ckp_path)
states_dict_new = states_dict.copy()
for key in states_dict.keys():
    if "running_var" in key or "running_mean" in key:
        del states_dict_new[key]

torch.save(states_dict_new, save_path)
--------------------------------------------------------------------------------
/deepfashion/train.py:
--------------------------------------------------------------------------------
import time
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer

opt = TrainOptions().parse()
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)

model = create_model(opt)
visualizer = Visualizer(opt)
total_steps = 0

for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
    epoch_start_time = time.time()
    epoch_iter = 0

    for i, data in enumerate(dataset):
        iter_start_time = time.time()
        visualizer.reset()
        total_steps += opt.batchSize
        epoch_iter += opt.batchSize
        model.set_input(data)
        model.optimize_parameters()

        if total_steps % opt.display_freq == 0:
            save_result = total_steps % opt.update_html_freq == 0
            visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)

        if total_steps % opt.print_freq == 0:
            errors = model.get_current_errors()
            t = (time.time() - iter_start_time) / opt.batchSize
            visualizer.print_current_errors(epoch, epoch_iter, errors, t)
            if opt.display_id > 0:
                visualizer.plot_current_errors(epoch, float(epoch_iter) / dataset_size, opt, errors)

        if total_steps % opt.save_latest_freq == 0:
            print('saving the latest model (epoch %d, total_steps %d)' %
                  (epoch, total_steps))
            model.save('latest')

    if epoch % opt.save_epoch_freq == 0:
        print('saving the model at the end of epoch %d, iters %d' %
              (epoch, total_steps))
        model.save('latest')
        model.save(epoch)

    print('End of epoch %d / %d \t Time Taken: %d sec' %
          (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
    model.update_learning_rate()
--------------------------------------------------------------------------------
/deepfashion/train_deepfashion.sh:
--------------------------------------------------------------------------------
export CUDA_VISIBLE_DEVICES=0,1,2,3;
python train.py --dataroot ./SelectionGAN/person_transfer/datasets/fashion_data/ --name fashion_exp --model BiGraphGAN --lambda_GAN 5 --lambda_A 10 --lambda_B 10 --dataset_mode keypoint --n_layers 3 --norm instance --batchSize 32 --pool_size 0 --resize_or_crop no --BP_input_nc 18 --no_flip --which_model_netG Graph --niter 500 --niter_decay 200 --checkpoints_dir ./checkpoints --pairLst ./SelectionGAN/person_transfer/datasets/fashion_data/fasion-resize-pairs-train.csv --L1_type l1_plus_perL1 --n_layers_D 3 --with_D_PP 1 --with_D_PB 1 --display_id 0 --gpu_ids 0,1,2,3
# --continue_train --which_epoch 680 --epoch_count 681
--------------------------------------------------------------------------------
/deepfashion/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/deepfashion/util/__init__.py
--------------------------------------------------------------------------------
/deepfashion/util/html.py:
--------------------------------------------------------------------------------
import dominate
from dominate.tags import *
import os


class HTML:
    def __init__(self, web_dir, title, refresh=0):
        self.title = title
        self.web_dir = web_dir
        self.img_dir = os.path.join(self.web_dir, 'images')
        if not os.path.exists(self.web_dir):
            os.makedirs(self.web_dir)
        if not os.path.exists(self.img_dir):
            os.makedirs(self.img_dir)
        # print(self.img_dir)

        self.doc = dominate.document(title=title)
        if refresh > 0:
            with self.doc.head:
                meta(http_equiv="refresh", content=str(refresh))

    def get_image_dir(self):
        return self.img_dir

    def add_header(self, text):
        with self.doc:
            h3(text)

    def add_table(self, border=1):
        self.t = table(border=border, style="table-layout: fixed;")
        self.doc.add(self.t)

    def add_images(self, ims, txts, links, width=400):
        self.add_table()
        with self.t:
            with tr():
                for im, txt, link in zip(ims, txts, links):
                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
                        with p():
                            with a(href=os.path.join('images', link)):
                                img(style="width:%dpx" % width, src=os.path.join('images', im))
                            br()
                            p(txt)

    def save(self):
        html_file = '%s/index.html' % self.web_dir
        f = open(html_file, 'wt')
        f.write(self.doc.render())
        f.close()


if __name__ == '__main__':
    html = HTML('web/', 'test_html')
    html.add_header('hello world')

    ims = []
    txts = []
    links = []
    for n in range(4):
        ims.append('image_%d.png' % n)
        txts.append('text_%d' % n)
        links.append('image_%d.png' % n)
    html.add_images(ims, txts, links)
    html.save()
--------------------------------------------------------------------------------
/deepfashion/util/image_pool.py:
--------------------------------------------------------------------------------
import random
import torch
from torch.autograd import Variable


class ImagePool():
    def __init__(self, pool_size):
        self.pool_size = pool_size
        if self.pool_size > 0:
            self.num_imgs = 0
            self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return Variable(images)
        return_images = []
        for image in images:
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            else:
                p = random.uniform(0, 1)
                if p > 0.5:
                    random_id = random.randint(0, self.pool_size - 1)
                    tmp = self.images[random_id].clone()
                    self.images[random_id] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return_images = Variable(torch.cat(return_images, 0))
        return return_images
--------------------------------------------------------------------------------
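ImagePool keeps a buffer of past generator outputs and, half the time, shows the discriminator an old fake instead of the newest one, which damps oscillations during GAN training (note the training scripts in this repo pass --pool_size 0, which disables the buffer). A minimal sketch:

    import torch
    pool = ImagePool(pool_size=50)
    fake = torch.randn(4, 3, 128, 64)      # a batch of generated images
    fake_for_D = pool.query(fake)          # mixture of fresh and buffered fakes
    print(fake_for_D.shape)                # torch.Size([4, 3, 128, 64])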
/deepfashion/util/png.py:
--------------------------------------------------------------------------------
import struct
import zlib

def encode(buf, width, height):
    """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """
    assert (width * height * 3 == len(buf))
    bpp = 3

    def raw_data():
        # reverse the vertical line order and add null bytes at the start
        row_bytes = width * bpp
        for row_start in range((height - 1) * width * bpp, -1, -row_bytes):
            yield b'\x00'
            yield buf[row_start:row_start + row_bytes]

    def chunk(tag, data):
        return [
            struct.pack("!I", len(data)),
            tag,
            data,
            struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag)))
        ]

    SIGNATURE = b'\x89PNG\r\n\x1a\n'
    COLOR_TYPE_RGB = 2
    COLOR_TYPE_RGBA = 6
    bit_depth = 8
    return b''.join(
        [SIGNATURE] +
        chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) +
        chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) +
        chunk(b'IEND', b'')
    )
--------------------------------------------------------------------------------
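encode builds a valid RGB PNG by hand: the 8-byte signature, an IHDR chunk with the geometry, a single zlib-compressed IDAT whose scanlines each start with filter byte 0, and an empty IEND, with a CRC32 over each chunk's tag and data. Note that raw_data walks the buffer bottom-up, so rows are expected in bottom-to-top order. A tiny sketch:

    # 2x2 RGB buffer, 12 bytes; the first stored row ends up at the bottom of the PNG
    buf = bytes([255, 0, 0,  0, 255, 0,        # bottom row: red, green
                 0, 0, 255,  255, 255, 255])   # top row: blue, white
    with open('tiny.png', 'wb') as f:
        f.write(encode(buf, 2, 2))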
/facial/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/data/__init__.py
--------------------------------------------------------------------------------
/facial/data/aligned.py:
--------------------------------------------------------------------------------
import os.path
import random
import torchvision.transforms as transforms
import torch
from data.base_dataset import BaseDataset
from data.image_folder import make_dataset
from PIL import Image

class AlignedDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser

    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot
        self.dir_AB = os.path.join(opt.dataroot, opt.phase)
        self.AB_paths = sorted(make_dataset(self.dir_AB))
        # assert(opt.resize_or_crop == 'resize_and_crop')

    def __getitem__(self, index):
        AB_path = self.AB_paths[index]
        ABCD = Image.open(AB_path).convert('RGB')
        w, h = ABCD.size
        w2 = int(w / 4)
        A = ABCD.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        B = ABCD.crop((w2, 0, w2 + w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        C = ABCD.crop((w2 + w2, 0, w2 + w2 + w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
        D = ABCD.crop((w2 + w2 + w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)

        A = transforms.ToTensor()(A)
        B = transforms.ToTensor()(B)
        C = transforms.ToTensor()(C)
        D = transforms.ToTensor()(D)
        w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
        h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))

        A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
        B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
        C = C[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
        D = D[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]

        A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
        B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)
        C = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(C)
        D = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(D)

        if self.opt.which_direction == 'BtoA':
            input_nc = self.opt.output_nc
            output_nc = self.opt.input_nc
        else:
            input_nc = self.opt.input_nc
            output_nc = self.opt.output_nc

        if (not self.opt.no_flip) and random.random() < 0.5:
            idx = [i for i in range(A.size(2) - 1, -1, -1)]
            idx = torch.LongTensor(idx)
            A = A.index_select(2, idx)
            B = B.index_select(2, idx)
            # flip the landmark crops together with the photos so the pairs stay aligned
            C = C.index_select(2, idx)
            D = D.index_select(2, idx)

        if input_nc == 1:  # RGB to gray
            tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
            A = tmp.unsqueeze(0)
        if output_nc == 1:  # RGB to gray
            tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
            B = tmp.unsqueeze(0)

        return {'P1': A, 'P2': B, 'BP1': C, 'BP2': D,
                'P1_path': AB_path, 'P2_path': AB_path}

    def __len__(self):
        return len(self.AB_paths)

    def name(self):
        return 'AlignedDataset'
--------------------------------------------------------------------------------
/facial/data/base_data_loader.py:
--------------------------------------------------------------------------------

class BaseDataLoader():
    def __init__(self):
        pass

    def initialize(self, opt):
        self.opt = opt
        pass

    def load_data(self):
        return None
--------------------------------------------------------------------------------
/facial/data/base_dataset.py:
--------------------------------------------------------------------------------
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms

class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass

def get_transform(opt):
    transform_list = []
    if opt.resize_or_crop == 'resize_and_crop':
        osize = [opt.loadSize, opt.loadSize]
        # transforms.Scale is deprecated; Resize is the equivalent call
        transform_list.append(transforms.Resize(osize, Image.BICUBIC))
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'crop':
        transform_list.append(transforms.RandomCrop(opt.fineSize))
    elif opt.resize_or_crop == 'scale_width':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.fineSize)))
    elif opt.resize_or_crop == 'scale_width_and_crop':
        transform_list.append(transforms.Lambda(
            lambda img: __scale_width(img, opt.loadSize)))
        transform_list.append(transforms.RandomCrop(opt.fineSize))

    transform_list += [transforms.ToTensor(),
                       transforms.Normalize((0.5, 0.5, 0.5),
                                            (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)

def __scale_width(img, target_width):
    ow, oh = img.size
    if (ow == target_width):
        return img
    w = target_width
    h = int(target_width * oh / ow)
    return img.resize((w, h), Image.BICUBIC)
--------------------------------------------------------------------------------
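get_transform assembles the augmentation pipeline selected by --resize_or_crop and always finishes with ToTensor plus a Normalize that maps pixels into [-1, 1]. A quick sketch with a hypothetical options object:

    from types import SimpleNamespace
    from PIL import Image
    opt = SimpleNamespace(resize_or_crop='resize_and_crop', loadSize=286, fineSize=256)
    preprocess = get_transform(opt)
    x = preprocess(Image.new('RGB', (512, 512)))
    print(x.shape, x.min().item())    # torch.Size([3, 256, 256]) -1.0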
/facial/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
import torch.utils.data
from data.base_data_loader import BaseDataLoader


def CreateDataset(opt):
    dataset = None

    if opt.dataset_mode == 'keypoint':
        from data.keypoint import KeyDataset
        dataset = KeyDataset()
    elif opt.dataset_mode == 'aligned':
        from data.aligned import AlignedDataset
        dataset = AlignedDataset()
    else:
        raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)

    print("dataset [%s] was created" % (dataset.name()))
    dataset.initialize(opt)
    return dataset


class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)
        self.dataset = CreateDataset(opt)
        self.dataloader = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        for i, data in enumerate(self.dataloader):
            if i >= self.opt.max_dataset_size:
                break
            yield data
--------------------------------------------------------------------------------
/facial/data/data_loader.py:
--------------------------------------------------------------------------------

def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())
    data_loader.initialize(opt)
    return data_loader
--------------------------------------------------------------------------------
/facial/data/image_folder.py:
--------------------------------------------------------------------------------
###############################################################################
# Code from
# https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
# Modified the original code so that it also loads images from the current
# directory as well as the subdirectories
###############################################################################

import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)

    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
--------------------------------------------------------------------------------
/facial/losses/L1_plus_perceptualLoss.py:
--------------------------------------------------------------------------------
from __future__ import absolute_import

import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
import torchvision.models as models

class L1_plus_perceptualLoss(nn.Module):
    def __init__(self, lambda_L1, lambda_perceptual, perceptual_layers, gpu_ids, percep_is_l1):
        super(L1_plus_perceptualLoss, self).__init__()

        self.lambda_L1 = lambda_L1
        self.lambda_perceptual = lambda_perceptual
        self.gpu_ids = gpu_ids

        self.percep_is_l1 = percep_is_l1

        vgg = models.vgg19(pretrained=True).features
        self.vgg_submodel = nn.Sequential()
        for i, layer in enumerate(list(vgg)):
            self.vgg_submodel.add_module(str(i), layer)
            if i == perceptual_layers:
                break
        self.vgg_submodel = torch.nn.DataParallel(self.vgg_submodel, device_ids=gpu_ids).cuda()

        print(self.vgg_submodel)

    def forward(self, inputs, targets):
        if self.lambda_L1 == 0 and self.lambda_perceptual == 0:
            return torch.zeros(1).cuda(), torch.zeros(1), torch.zeros(1)
        # normal L1
        loss_l1 = F.l1_loss(inputs, targets) * self.lambda_L1

        # perceptual L1: ImageNet mean/std for the VGG input
        mean = torch.FloatTensor(3)
        mean[0] = 0.485
        mean[1] = 0.456
        mean[2] = 0.406
        mean = mean.view(1, 3, 1, 1).cuda()

        std = torch.FloatTensor(3)
        std[0] = 0.229
        std[1] = 0.224
        std[2] = 0.225
        std = std.view(1, 3, 1, 1).cuda()

        fake_p2_norm = (inputs + 1) / 2  # [-1, 1] => [0, 1]
        fake_p2_norm = (fake_p2_norm - mean) / std

        input_p2_norm = (targets + 1) / 2  # [-1, 1] => [0, 1]
        input_p2_norm = (input_p2_norm - mean) / std

        fake_p2_norm = self.vgg_submodel(fake_p2_norm)
        input_p2_norm = self.vgg_submodel(input_p2_norm)
        input_p2_norm_no_grad = input_p2_norm.detach()

        if self.percep_is_l1 == 1:
            # use l1 for perceptual loss
            loss_perceptual = F.l1_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual
        else:
            # use l2 for perceptual loss
            loss_perceptual = F.mse_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual

        loss = loss_l1 + loss_perceptual

        return loss, loss_l1, loss_perceptual
--------------------------------------------------------------------------------
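The forward pass returns the weighted total together with its L1 and perceptual components, so callers can log each term separately. A minimal single-GPU sketch with hypothetical weights:

    import torch
    criterion = L1_plus_perceptualLoss(lambda_L1=10, lambda_perceptual=10,
                                       perceptual_layers=3, gpu_ids=[0], percep_is_l1=1)
    fake = torch.tanh(torch.randn(2, 3, 256, 176)).cuda()   # generator output in [-1, 1]
    real = torch.tanh(torch.randn(2, 3, 256, 176)).cuda()
    loss, loss_l1, loss_percep = criterion(fake, real)
    loss.backward()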
/facial/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/losses/__init__.py
--------------------------------------------------------------------------------
/facial/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/models/__init__.py
--------------------------------------------------------------------------------
/facial/models/base_model.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn as nn


class BaseModel(nn.Module):

    def __init__(self):
        super(BaseModel, self).__init__()

    def name(self):
        return 'BaseModel'

    def initialize(self, opt):
        self.opt = opt
        self.gpu_ids = opt.gpu_ids
        self.isTrain = opt.isTrain
        self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
        self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)

    def set_input(self, input):
        self.input = input

    def forward(self):
        pass

    # used in test time, no backprop
    def test(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, network_label, epoch_label, gpu_ids):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        torch.save(network.cpu().state_dict(), save_path)
        if len(gpu_ids) and torch.cuda.is_available():
            network.cuda(gpu_ids[0])

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        network.load_state_dict(torch.load(save_path))

    # update learning rate (called once every epoch)
    def update_learning_rate(self):
        for scheduler in self.schedulers:
            scheduler.step()
        lr = self.optimizers[0].param_groups[0]['lr']
        print('learning rate = %.7f' % lr)
--------------------------------------------------------------------------------
/facial/models/models.py:
--------------------------------------------------------------------------------

def create_model(opt):
    model = None
    print(opt.model)

    if opt.model == 'BiGraphGAN':
        # assert opt.dataset_mode == 'keypoint'
        from .BiGraphGAN import TransferModel
        model = TransferModel()

    else:
        raise ValueError("Model [%s] not recognized." % opt.model)
    model.initialize(opt)
    print("model [%s] was created" % (model.name()))
    return model
--------------------------------------------------------------------------------
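create_model is a small factory keyed on the --model flag: it imports the requested class lazily, instantiates it, and calls initialize(opt). A hypothetical call site (opt from TrainOptions().parse(), batch from the data loader):

    model = create_model(opt)        # opt.model == 'BiGraphGAN'; anything else raises ValueError
    model.set_input(batch)
    model.optimize_parameters()
    print(model.get_current_errors())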
TransferModel() 10 | 11 | else: 12 | raise ValueError("Model [%s] not recognized." % opt.model) 13 | model.initialize(opt) 14 | print("model [%s] was created" % (model.name())) 15 | return model 16 | -------------------------------------------------------------------------------- /facial/models/test_model.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | from collections import OrderedDict 3 | import util.util as util 4 | from .base_model import BaseModel 5 | from . import networks 6 | 7 | 8 | class TestModel(BaseModel): 9 | def name(self): 10 | return 'TestModel' 11 | 12 | def initialize(self, opt): 13 | assert(not opt.isTrain) 14 | BaseModel.initialize(self, opt) 15 | self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) 16 | 17 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, 18 | opt.ngf, opt.which_model_netG, 19 | opt.norm, not opt.no_dropout, 20 | opt.init_type, 21 | self.gpu_ids) 22 | which_epoch = opt.which_epoch 23 | self.load_network(self.netG, 'G', which_epoch) 24 | 25 | print('---------- Networks initialized -------------') 26 | networks.print_network(self.netG) 27 | print('-----------------------------------------------') 28 | 29 | def set_input(self, input): 30 | # we need to use single_dataset mode 31 | input_A = input['A'] 32 | self.input_A.resize_(input_A.size()).copy_(input_A) 33 | self.image_paths = input['A_paths'] 34 | 35 | def test(self): 36 | self.real_A = Variable(self.input_A) 37 | self.fake_B = self.netG(self.real_A) 38 | 39 | # get image paths 40 | def get_image_paths(self): 41 | return self.image_paths 42 | 43 | def get_current_visuals(self): 44 | real_A = util.tensor2im(self.real_A.data) 45 | fake_B = util.tensor2im(self.fake_B.data) 46 | return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) 47 | -------------------------------------------------------------------------------- /facial/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__init__.py -------------------------------------------------------------------------------- /facial/options/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__init__.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/base_options.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/base_options.cpython-35.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/base_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/base_options.cpython-36.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/base_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/base_options.cpython-37.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/base_options.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/base_options.cpython-38.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/test_options.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/test_options.cpython-35.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/test_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/test_options.cpython-36.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/test_options.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/test_options.cpython-38.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/train_options.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/train_options.cpython-35.pyc -------------------------------------------------------------------------------- 
/facial/options/__pycache__/train_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/train_options.cpython-37.pyc -------------------------------------------------------------------------------- /facial/options/__pycache__/train_options.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/__pycache__/train_options.cpython-38.pyc -------------------------------------------------------------------------------- /facial/options/base_options.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/base_options.pyc -------------------------------------------------------------------------------- /facial/options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 8 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 9 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') 12 | self.parser.add_argument('--how_many', type=int, default=200, help='how many test images to run') 13 | 14 | self.isTrain = False 15 | -------------------------------------------------------------------------------- /facial/options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 8 | self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with a certain number of images per row.') 9 | self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 10 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 11 | self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 12 | self.parser.add_argument('--save_epoch_freq', type=int, default=20, help='frequency of saving checkpoints at the end of epochs') 13 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 14 | self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') 15 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 16 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 17 | self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 18 | self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') 19 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 20 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 21 | self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least squares GAN; when this flag is absent, LSGAN is used instead of vanilla GAN') 22 | self.parser.add_argument('--lambda_A', type=float, default=10.0, help='weight for L1 loss') 23 | self.parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for perceptual L1 loss') 24 | self.parser.add_argument('--lambda_GAN', type=float, default=5.0, help='weight of GAN loss') 25 | 26 | self.parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 27 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 28 | self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') 29 | self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 30 | 31 | self.parser.add_argument('--L1_type', type=str, default='origin', help='use which kind of L1 loss. 
(origin|l1_plus_perL1)') 32 | self.parser.add_argument('--perceptual_layers', type=int, default=3, help='index of vgg layer for extracting perceptual features.') 33 | self.parser.add_argument('--percep_is_l1', type=int, default=1, help='type of perceptual loss: l1 or l2') 34 | self.parser.add_argument('--no_dropout_D', action='store_true', help='no dropout for the discriminator') 35 | self.parser.add_argument('--DG_ratio', type=int, default=1, help='how many times for D training after training G once') 36 | 37 | 38 | 39 | 40 | self.isTrain = True 41 | -------------------------------------------------------------------------------- /facial/options/train_options.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/options/train_options.pyc -------------------------------------------------------------------------------- /facial/scripts/download_bigraphgan_model.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | echo "Note: available models are facial" 4 | echo "Specified [$FILE]" 5 | 6 | URL=http://disi.unitn.it/~hao.tang/uploads/models/BiGraphGAN/${FILE}_pretrained.tar.gz 7 | TAR_FILE=../checkpoints/${FILE}_pretrained.tar.gz 8 | TARGET_DIR=../checkpoints/${FILE}_pretrained/ 9 | 10 | wget -N $URL -O $TAR_FILE 11 | 12 | mkdir -p $TARGET_DIR 13 | tar -zxvf $TAR_FILE -C ../checkpoints/ 14 | rm $TAR_FILE 15 | -------------------------------------------------------------------------------- /facial/scripts/download_bigraphgan_result.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | echo "Note: available results are facial" 4 | echo "Specified [$FILE]" 5 | 6 | URL=http://disi.unitn.it/~hao.tang/uploads/results/BiGraphGAN/${FILE}_results.tar.gz 7 | TAR_FILE=../results_by_author/${FILE}_results.tar.gz 8 | TARGET_DIR=../results_by_author/${FILE}_results/ 9 | 10 | wget -N $URL -O $TAR_FILE 11 | 12 | mkdir -p $TARGET_DIR 13 | tar -zxvf $TAR_FILE -C ../results_by_author/ 14 | rm $TAR_FILE 15 | -------------------------------------------------------------------------------- /facial/test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | from options.test_options import TestOptions 4 | from data.data_loader import CreateDataLoader 5 | from models.models import create_model 6 | from util.visualizer import Visualizer 7 | from util import html 8 | import time 9 | 10 | opt = TestOptions().parse() 11 | opt.nThreads = 1 # test code only supports nThreads = 1 12 | opt.batchSize = 1 # test code only supports batchSize = 1 13 | opt.serial_batches = True # no shuffle 14 | opt.no_flip = True # no flip 15 | 16 | data_loader = CreateDataLoader(opt) 17 | dataset = data_loader.load_data() 18 | model = create_model(opt) 19 | visualizer = Visualizer(opt) 20 | # create website 21 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 22 | 23 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 24 | 25 | print(opt.how_many) 26 | print(len(dataset)) 27 | 28 | model = model.eval() 29 | print(model.training) 30 | 31 | opt.how_many = 999999 32 | # test 33 | for i, data in enumerate(dataset): 34 | print(' process %d/%d img ..'%(i,opt.how_many)) 35 | if i >= opt.how_many: 36 | break 37 | model.set_input(data) 38 | startTime = 
time.time() 39 | model.test() 40 | endTime = time.time() 41 | print(endTime-startTime) 42 | visuals = model.get_current_visuals() 43 | img_path = model.get_image_paths() 44 | #img_path = [img_path] 45 | print(img_path) 46 | #img_path.replace("/home/ht1/Radboud_selectiongan/test", "/home/ht1/BiGraphGAN_otherTask/market_1501/results/facial_exp") 47 | img_path = img_path[0].split('/')[-1] 48 | #print('img_path', img_path) 49 | img_path = [img_path] 50 | visualizer.save_images(webpage, visuals, img_path) 51 | 52 | webpage.save() 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /facial/test_facial.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0; 2 | python test.py --dataroot /home/ht1/Radboud_selectiongan/ --name facial_exp --model BiGraphGAN --phase test --dataset_mode aligned --norm batch --batchSize 1 --resize_or_crop no --gpu_ids 0 --BP_input_nc 3 --no_flip --which_model_netG Graph --checkpoints_dir ./checkpoints --which_epoch 200 --results_dir ./results/ --display_id 0; 3 | 4 | -------------------------------------------------------------------------------- /facial/tool/calPCKH_fashion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pandas as pd 3 | import numpy as np 4 | import json 5 | 6 | MISSING_VALUE = -1 7 | 8 | PARTS_SEL = [0, 1, 14, 15, 16, 17] 9 | 10 | target_annotation = './fashion_data/fasion-resize-annotation-test.csv' 11 | pred_annotation = './results/fashion_PATN/pckh.csv' 12 | 13 | 14 | ''' 15 | hz: head size 16 | alpha: norm factor 17 | px, py: predict coords 18 | tx, ty: target coords 19 | ''' 20 | def isRight(px, py, tx, ty, hz, alpha): 21 | if px == -1 or py == -1 or tx == -1 or ty == -1: 22 | return 0 23 | 24 | if abs(px - tx) < hz[0] * alpha and abs(py - ty) < hz[1] * alpha: 25 | return 1 26 | else: 27 | return 0 28 | 29 | 30 | def how_many_right_seq(px, py, tx, ty, hz, alpha): 31 | nRight = 0 32 | for i in range(len(px)): 33 | nRight = nRight + isRight(px[i], py[i], tx[i], ty[i], hz, alpha) 34 | 35 | return nRight 36 | 37 | 38 | def ValidPoints(tx): 39 | nValid = 0 40 | for item in tx: 41 | if item != -1: 42 | nValid = nValid + 1 43 | return nValid 44 | 45 | 46 | def get_head_wh(x_coords, y_coords): 47 | final_w, final_h = -1, -1 48 | component_count = 0 49 | save_componets = [] 50 | for component in PARTS_SEL: 51 | if x_coords[component] == MISSING_VALUE or y_coords[component] == MISSING_VALUE: 52 | continue 53 | else: 54 | component_count += 1 55 | save_componets.append([x_coords[component], y_coords[component]]) 56 | if component_count >= 2: 57 | x_cords = [] 58 | y_cords = [] 59 | for component in save_componets: 60 | x_cords.append(component[0]) 61 | y_cords.append(component[1]) 62 | xmin = min(x_cords) 63 | xmax = max(x_cords) 64 | ymin = min(y_cords) 65 | ymax = max(y_cords) 66 | final_w = xmax - xmin 67 | final_h = ymax - ymin 68 | return final_w, final_h 69 | 70 | 71 | tAnno = pd.read_csv(target_annotation, sep=':') 72 | pAnno = pd.read_csv(pred_annotation, sep=':') 73 | 74 | pRows = pAnno.shape[0] 75 | 76 | nAll = 0 77 | nCorrect = 0 78 | alpha = 0.5 79 | for i in range(pRows): 80 | pValues = pAnno.iloc[i].values 81 | pname = pValues[0] 82 | pycords = json.loads(pValues[1]) # list of numbers 83 | pxcords = json.loads(pValues[2]) 84 | 85 | if '_vis' in pname: 86 | tname = pname[:-8] 87 | else: 88 | tname = pname[:-4] 89 | 90 | if '___' in tname: 91 | tname = 
tname.split('___')[1] 92 | else: 93 | tname = tname.split('jpg_')[1] 94 | 95 | print(tname) 96 | tValues = tAnno.query('name == "%s"' % (tname)).values[0] 97 | tycords = json.loads(tValues[1]) # list of numbers 98 | txcords = json.loads(tValues[2]) 99 | 100 | 101 | xBox, yBox = get_head_wh(txcords, tycords) 102 | if xBox == -1 or yBox == -1: 103 | continue 104 | 105 | head_size = (xBox, yBox) 106 | nAll = nAll + ValidPoints(tycords) 107 | nCorrect = nCorrect + how_many_right_seq(pxcords, pycords, txcords, tycords, head_size, alpha) 108 | 109 | print('%d/%d %f' % (nCorrect, nAll, nCorrect * 1.0 / nAll)) 110 | -------------------------------------------------------------------------------- /facial/tool/calPCKH_market.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pandas as pd 3 | import numpy as np 4 | import json 5 | 6 | MISSING_VALUE = -1 7 | 8 | PARTS_SEL = [0, 1, 14, 15, 16, 17] 9 | 10 | # fix the PATH 11 | target_annotation = './market_data/market-annotation-test.csv' 12 | pred_annotation = '/results/market_PATN/pckh.csv' 13 | 14 | 15 | ''' 16 | hz: head size 17 | alpha: norm factor 18 | px, py: predict coords 19 | tx, ty: target coords 20 | ''' 21 | def isRight(px, py, tx, ty, hz, alpha): 22 | if px == -1 or py == -1 or tx == -1 or ty == -1: 23 | return 0 24 | 25 | if abs(px-tx) < hz[0]*alpha and abs(py-ty) < hz[1]*alpha: 26 | return 1 27 | else: 28 | return 0 29 | 30 | def how_many_right_seq(px, py, tx, ty, hz, alpha): 31 | nRight = 0 32 | for i in range(len(px)): 33 | nRight = nRight + isRight(px[i], py[i], tx[i], ty[i], hz, alpha) 34 | 35 | return nRight 36 | 37 | def ValidPoints(tx): 38 | nValid = 0 39 | for item in tx: 40 | if item != -1: 41 | nValid = nValid + 1 42 | return nValid 43 | 44 | def get_head_wh(x_coords, y_coords): 45 | final_w, final_h = -1, -1 46 | component_count = 0 47 | save_componets = [] 48 | for component in PARTS_SEL: 49 | if x_coords[component] == MISSING_VALUE or y_coords[component] == MISSING_VALUE: 50 | continue 51 | else: 52 | component_count += 1 53 | save_componets.append([x_coords[component], y_coords[component]]) 54 | if component_count >= 2: 55 | x_cords = [] 56 | y_cords = [] 57 | for component in save_componets: 58 | x_cords.append(component[0]) 59 | y_cords.append(component[1]) 60 | xmin = min(x_cords) 61 | xmax = max(x_cords) 62 | ymin = min(y_cords) 63 | ymax = max(y_cords) 64 | final_w = xmax - xmin 65 | final_h = ymax - ymin 66 | return final_w, final_h 67 | 68 | 69 | 70 | 71 | 72 | tAnno = pd.read_csv(target_annotation, sep=':') 73 | pAnno = pd.read_csv(pred_annotation, sep=':') 74 | 75 | pRows = pAnno.shape[0] 76 | 77 | nAll = 0 78 | nCorrect = 0 79 | alpha = 0.5 80 | for i in range(pRows): 81 | pValues = pAnno.iloc[i].values 82 | pname = pValues[0] 83 | pycords = json.loads(pValues[1]) #list of numbers 84 | pxcords = json.loads(pValues[2]) 85 | 86 | if '_vis' in pname: 87 | tname = pname[:-8] 88 | else: 89 | tname = pname[:-4] 90 | 91 | if '___' in tname: 92 | tname = tname.split('___')[1] 93 | else: 94 | tname = tname.split('jpg_')[1] 95 | 96 | print(tname) 97 | tValues = tAnno.query('name == "%s"' %(tname)).values[0] 98 | tycords = json.loads(tValues[1]) #list of numbers 99 | txcords = json.loads(tValues[2]) 100 | 101 | xBox, yBox = get_head_wh(txcords, tycords) 102 | if xBox == -1 or yBox == -1: 103 | continue 104 | 105 | head_size = (xBox, yBox) 106 | nAll = nAll + ValidPoints(tycords) 107 | nCorrect = nCorrect + how_many_right_seq(pxcords, pycords, txcords, tycords, 
head_size, alpha) 108 | 109 | 110 | print('%d/%d %f' %(nCorrect, nAll, nCorrect*1.0/nAll)) 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /facial/tool/create_pairs_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pose_utils 3 | from itertools import permutations 4 | 5 | LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri', 6 | 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear'] 7 | 8 | MISSING_VALUE = -1 9 | 10 | def give_name_to_keypoints(array): 11 | res = {} 12 | for i, name in enumerate(LABELS): 13 | if array[i][0] != MISSING_VALUE and array[i][1] != MISSING_VALUE: 14 | res[name] = array[i][::-1] 15 | return res 16 | 17 | 18 | def pose_check_valid(kp_array): 19 | kp = give_name_to_keypoints(kp_array) 20 | return check_keypoints_present(kp, ['Rhip', 'Lhip', 'Lsho', 'Rsho']) 21 | 22 | 23 | def check_keypoints_present(kp, kp_names): 24 | result = True 25 | for name in kp_names: 26 | result = result and (name in kp) 27 | return result 28 | 29 | def filter_not_valid(df_keypoints): 30 | def check_valid(x): 31 | kp_array = pose_utils.load_pose_cords_from_strings(x['keypoints_y'], x['keypoints_x']) 32 | distractor = x['name'].startswith('-1') or x['name'].startswith('0000') 33 | return pose_check_valid(kp_array) and not distractor 34 | return df_keypoints[df_keypoints.apply(check_valid, axis=1)].copy() 35 | 36 | 37 | def make_pairs(df): 38 | persons = df.apply(lambda x: '_'.join(x['name'].split('_')[0:1]), axis=1) 39 | df['person'] = persons 40 | fr, to = [], [] 41 | for person in pd.unique(persons): 42 | pairs = list(zip(*list(permutations(df[df['person'] == person]['name'], 2)))) # list() so len() and indexing below work under Python 3, where zip is lazy 43 | if len(pairs) != 0: 44 | fr += list(pairs[0]) 45 | to += list(pairs[1]) 46 | pair_df = pd.DataFrame(index=range(len(fr))) 47 | pair_df['from'] = fr 48 | pair_df['to'] = to 49 | return pair_df 50 | 51 | 52 | if __name__ == "__main__": 53 | images_for_test = 12000 54 | 55 | annotations_file_train = './market_data/market-annotation-test.csv' 56 | pairs_file_train = './market_data/example_market-pairs-train.csv' 57 | 58 | df_keypoints = pd.read_csv(annotations_file_train, sep=':') 59 | df = filter_not_valid(df_keypoints) 60 | print ('Compute pair dataset for train...') 61 | pairs_df_train = make_pairs(df) 62 | print ('Number of pairs: %s' % len(pairs_df_train)) 63 | pairs_df_train.to_csv(pairs_file_train, index=False) 64 | 65 | annotations_file_test = './market_data/market-annotation-test.csv' 66 | pairs_file_test = './market_data/example_market-pairs-test.csv' 67 | 68 | print ('Compute pair dataset for test...') 69 | df_keypoints = pd.read_csv(annotations_file_test, sep=':') 70 | df = filter_not_valid(df_keypoints) 71 | pairs_df_test = make_pairs(df) 72 | pairs_df_test = pairs_df_test.sample(n=min(images_for_test, pairs_df_test.shape[0]), replace=False, random_state=0) 73 | print ('Number of pairs: %s' % len(pairs_df_test)) 74 | pairs_df_test.to_csv(pairs_file_test, index=False) 75 | 76 | --------------------------------------------------------------------------------
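Before the cropping utilities, a quick worked check of the PCKH@alpha rule that both calPCKH scripts above implement: a predicted joint counts as correct only when its horizontal error is below alpha times the head-box width and its vertical error is below alpha times the head-box height. A minimal sketch with made-up coordinates (none of these numbers come from the repository):

    head_size = (20, 24)   # hypothetical (w, h) returned by get_head_wh for one image
    alpha = 0.5            # the threshold used by both calPCKH scripts
    px, py, tx, ty = 105, 63, 110, 58   # hypothetical predicted / target coordinates
    # same test as isRight: both axis errors must fall inside the scaled head box
    hit = abs(px - tx) < head_size[0] * alpha and abs(py - ty) < head_size[1] * alpha
    print(hit)             # True, since 5 < 10 and 5 < 12

The final score printed by each script is simply the number of such hits divided by the number of valid (non-missing) target joints.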
/facial/tool/crop_fashion.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | 4 | img_dir = './results/fashion_PATN_test/test_latest/images' 5 | save_dir = './results/fashion_PATN_test/test_latest/images_crop' 6 | 7 | if not os.path.exists(save_dir): 8 | os.mkdir(save_dir) 9 | 10 | cnt = 0 11 | 12 | for item in os.listdir(img_dir): 13 | if not item.endswith('.jpg') and not item.endswith('.png'): 14 | continue 15 | cnt = cnt + 1 16 | print('%d/8570 ...' %(cnt)) 17 | img = Image.open(os.path.join(img_dir, item)) 18 | imgcrop = img.crop((704, 0, 880, 256)) 19 | imgcrop.save(os.path.join(save_dir, item)) 20 | -------------------------------------------------------------------------------- /facial/tool/crop_market.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | 4 | img_dir = './results/market_PATN_test/test_latest/images' 5 | save_dir = './results/market_PATN_test/test_latest/images_crop' 6 | 7 | if not os.path.exists(save_dir): 8 | os.mkdir(save_dir) 9 | 10 | cnt = 0 11 | for item in os.listdir(img_dir): 12 | if not item.endswith('.jpg') and not item.endswith('.png'): 13 | continue 14 | cnt = cnt + 1 15 | print('%d/12000 ...' %(cnt)) 16 | img = Image.open(os.path.join(img_dir, item)) 17 | # for 5 split 18 | imgcrop = img.crop((256, 0, 320, 128)) 19 | imgcrop.save(os.path.join(save_dir, item)) 20 | -------------------------------------------------------------------------------- /facial/tool/generate_fashion_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from PIL import Image 4 | 5 | IMG_EXTENSIONS = [ 6 | '.jpg', '.JPG', '.jpeg', '.JPEG', 7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 8 | ] 9 | 10 | def is_image_file(filename): 11 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 12 | 13 | def make_dataset(dir): 14 | images = [] 15 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 16 | new_root = './fashion_data' 17 | if not os.path.exists(new_root): 18 | os.mkdir(new_root) 19 | 20 | train_root = './fashion_data/train' 21 | if not os.path.exists(train_root): 22 | os.mkdir(train_root) 23 | 24 | test_root = './fashion_data/test' 25 | if not os.path.exists(test_root): 26 | os.mkdir(test_root) 27 | 28 | train_images = [] 29 | train_f = open('./fashion_data/train.lst', 'r') 30 | for lines in train_f: 31 | lines = lines.strip() 32 | if lines.endswith('.jpg'): 33 | train_images.append(lines) 34 | 35 | test_images = [] 36 | test_f = open('./fashion_data/test.lst', 'r') 37 | for lines in test_f: 38 | lines = lines.strip() 39 | if lines.endswith('.jpg'): 40 | test_images.append(lines) 41 | 42 | print(train_images, test_images) 43 | 44 | 45 | for root, _, fnames in sorted(os.walk(dir)): 46 | for fname in fnames: 47 | if is_image_file(fname): 48 | path = os.path.join(root, fname) 49 | path_names = path.split('/') 50 | # path_names[2] = path_names[2].replace('_', '') 51 | path_names[3] = path_names[3].replace('_', '') 52 | path_names[4] = path_names[4].split('_')[0] + "_" + "".join(path_names[4].split('_')[1:]) 53 | path_names = "".join(path_names) 54 | new_path = os.path.join(root, path_names) # restored from the commented-out line: new_path is used below but was never defined; check that it matches the names stored in train.lst/test.lst 55 | img = Image.open(path) 56 | imgcrop = img.crop((40, 0, 216, 256)) 57 | if new_path in train_images: 58 | imgcrop.save(os.path.join(train_root, path_names)) 59 | elif new_path in test_images: 60 | imgcrop.save(os.path.join(test_root, path_names)) 61 | 62 | make_dataset('./fashion') 63 | -------------------------------------------------------------------------------- /facial/tool/generate_pose_map_fashion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import json 4 | 
import os 5 | 6 | MISSING_VALUE = -1 7 | # fix PATH 8 | img_dir = 'fashion_data' #raw image path 9 | annotations_file = 'fashion_data/fasion-resize-annotation-train.csv' #pose annotation path 10 | save_path = 'fashion_data/trainK' #path to store pose maps 11 | 12 | def load_pose_cords_from_strings(y_str, x_str): 13 | y_cords = json.loads(y_str) 14 | x_cords = json.loads(x_str) 15 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1) 16 | 17 | def cords_to_map(cords, img_size, sigma=6): 18 | result = np.zeros(img_size + cords.shape[0:1], dtype='float32') # float32, as in the market version: with 'uint8' the Gaussian values in (0, 1) truncate to zero 19 | for i, point in enumerate(cords): 20 | if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE: 21 | continue 22 | xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0])) 23 | result[..., i] = np.exp(-((yy - point[0]) ** 2 + (xx - point[1]) ** 2) / (2 * sigma ** 2)) 24 | # result[..., i] = np.where(((yy - point[0]) ** 2 + (xx - point[1]) ** 2) < (sigma ** 2), 1, 0) 25 | return result 26 | 27 | def compute_pose(image_dir, annotations_file, savePath, sigma=6): # default added: the call at the bottom passes no sigma 28 | annotations_file = pd.read_csv(annotations_file, sep=':') 29 | annotations_file = annotations_file.set_index('name') 30 | image_size = (256, 176) 31 | cnt = len(annotations_file) 32 | for i in range(cnt): 33 | print('processing %d / %d ...' %(i, cnt)) 34 | row = annotations_file.iloc[i] 35 | name = row.name 36 | print(savePath, name) 37 | file_name = os.path.join(savePath, name + '.npy') 38 | kp_array = load_pose_cords_from_strings(row.keypoints_y, row.keypoints_x) 39 | pose = cords_to_map(kp_array, image_size, sigma) 40 | np.save(file_name, pose) 41 | # input() 42 | 43 | compute_pose(img_dir, annotations_file, save_path) 44 | 45 | 46 | --------------------------------------------------------------------------------
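The cords_to_map helper in both pose-map scripts rasterizes each annotated keypoint into one channel of a dense heatmap: channel $i$ holds an unnormalized Gaussian centred on keypoint $(y_i, x_i)$,

$$H_i(x, y) = \exp\left(-\frac{(y - y_i)^2 + (x - x_i)^2}{2\sigma^2}\right),$$

with $\sigma = 6$ px by default, and joints marked missing (coordinate $-1$) left as all-zero channels. The stack of one such channel per annotated keypoint is what gets saved as the .npy pose map.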
/facial/tool/generate_pose_map_market.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import json 4 | import os 5 | 6 | MISSING_VALUE = -1 7 | 8 | img_dir = 'market_data/train' #raw image path 9 | annotations_file = 'market_data/market-annotation-train.csv' #pose annotation path 10 | save_path = 'market_data/trainK' #path to store pose maps 11 | 12 | def load_pose_cords_from_strings(y_str, x_str): 13 | y_cords = json.loads(y_str) 14 | x_cords = json.loads(x_str) 15 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1) 16 | 17 | def cords_to_map(cords, img_size, sigma=6): 18 | result = np.zeros(img_size + cords.shape[0:1], dtype='float32') 19 | for i, point in enumerate(cords): 20 | if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE: 21 | continue 22 | xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0])) 23 | result[..., i] = np.exp(-((yy - point[0]) ** 2 + (xx - point[1]) ** 2) / (2 * sigma ** 2)) 24 | return result 25 | 26 | def compute_pose(image_dir, annotations_file, savePath): 27 | annotations_file = pd.read_csv(annotations_file, sep=':') 28 | annotations_file = annotations_file.set_index('name') 29 | image_size = (128, 64) 30 | cnt = len(annotations_file) 31 | for i in range(cnt): 32 | print('processing %d / %d ...' %(i, cnt)) 33 | row = annotations_file.iloc[i] 34 | name = row.name 35 | print(savePath, name) 36 | file_name = os.path.join(savePath, name + '.npy') 37 | kp_array = load_pose_cords_from_strings(row.keypoints_y, row.keypoints_x) 38 | pose = cords_to_map(kp_array, image_size) 39 | np.save(file_name, pose) 40 | 41 | compute_pose(img_dir, annotations_file, save_path) 42 | 43 | --------------------------------------------------------------------------------
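The inception_score.py tool that follows implements the standard Inception Score: with $p(y \mid x)$ the Inception-v3 class posterior for a generated image $x$ and $p(y)$ the marginal over the evaluated split,

$$\mathrm{IS} = \exp\big(\mathbb{E}_{x}\,\mathrm{KL}\big(p(y \mid x)\,\|\,p(y)\big)\big),$$

which is exactly what the split loop computes: the kl accumulation is the per-image KL divergence summed over classes and averaged over the split, exponentiated per split, with the mean and standard deviation reported across the 10 splits.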
/facial/tool/inception_score.py: -------------------------------------------------------------------------------- 1 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os.path 7 | import sys 8 | import tarfile 9 | 10 | import numpy as np 11 | from six.moves import urllib 12 | import tensorflow as tf 13 | import glob 14 | import scipy.misc 15 | import math 16 | import sys 17 | 18 | MODEL_DIR = os.path.expanduser('~/models') # expanduser: a bare '~' is not expanded by os.path.exists/os.makedirs 19 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 20 | softmax = None 21 | 22 | # Call this function with list of images. Each of elements should be a 23 | # numpy array with values ranging from 0 to 255. 24 | def get_inception_score(images, splits=10): 25 | #assert(type(images) == list) 26 | assert(type(images[0]) == np.ndarray) 27 | assert(len(images[0].shape) == 3) 28 | assert(np.max(images[0]) > 10) 29 | assert(np.min(images[0]) >= 0.0) 30 | inps = [] 31 | for img in images: 32 | img = img.astype(np.float32) 33 | inps.append(np.expand_dims(img, 0)) 34 | bs = 10 35 | with tf.Session() as sess: 36 | preds = [] 37 | n_batches = int(math.ceil(float(len(inps)) / float(bs))) 38 | for i in range(n_batches): 39 | sys.stdout.write(".") 40 | sys.stdout.flush() 41 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))] 42 | inp = np.concatenate(inp, 0) 43 | pred = sess.run(softmax, {'ExpandDims:0': inp}) 44 | preds.append(pred) 45 | preds = np.concatenate(preds, 0) 46 | scores = [] 47 | for i in range(splits): 48 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 49 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 50 | kl = np.mean(np.sum(kl, 1)) 51 | scores.append(np.exp(kl)) 52 | return np.mean(scores), np.std(scores) 53 | 54 | # This function is called automatically. 55 | def _init_inception(): 56 | global softmax 57 | if not os.path.exists(MODEL_DIR): 58 | os.makedirs(MODEL_DIR) 59 | filename = DATA_URL.split('/')[-1] 60 | filepath = os.path.join(MODEL_DIR, filename) 61 | if not os.path.exists(filepath): 62 | def _progress(count, block_size, total_size): 63 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 64 | filename, float(count * block_size) / float(total_size) * 100.0)) 65 | sys.stdout.flush() 66 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 67 | print() 68 | statinfo = os.stat(filepath) 69 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 70 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 71 | with tf.gfile.FastGFile(os.path.join( 72 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 73 | graph_def = tf.GraphDef() 74 | graph_def.ParseFromString(f.read()) 75 | _ = tf.import_graph_def(graph_def, name='') 76 | # Works with an arbitrary minibatch size. 77 | with tf.Session() as sess: 78 | pool3 = sess.graph.get_tensor_by_name('pool_3:0') 79 | ops = pool3.graph.get_operations() 80 | for op_idx, op in enumerate(ops): 81 | for o in op.outputs: 82 | shape = o.get_shape() 83 | shape = [s.value for s in shape] 84 | new_shape = [] 85 | for j, s in enumerate(shape): 86 | if s == 1 and j == 0: 87 | new_shape.append(None) 88 | else: 89 | new_shape.append(s) 90 | o._shape = tf.TensorShape(new_shape) 91 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] 92 | logits = tf.matmul(tf.squeeze(pool3), w) 93 | softmax = tf.nn.softmax(logits) 94 | 95 | if softmax is None: 96 | _init_inception() 97 | -------------------------------------------------------------------------------- /facial/tool/resize_fashion.py: -------------------------------------------------------------------------------- 1 | from skimage.io import imread, imsave 2 | from skimage.transform import resize 3 | import os 4 | import numpy as np 5 | import pandas as pd 6 | import json 7 | 8 | def resize_dataset(folder, new_folder, new_size = (256, 176), crop_bord=40): 9 | if not os.path.exists(new_folder): 10 | os.makedirs(new_folder) 11 | for name in os.listdir(folder): 12 | old_name = os.path.join(folder, name) 13 | new_name = os.path.join(new_folder, name) 14 | 15 | img = imread(old_name) 16 | if crop_bord == 0: 17 | pass 18 | else: 19 | img = img[:, crop_bord:-crop_bord] 20 | 21 | img = resize(img, new_size, preserve_range=True).astype(np.uint8) 22 | 23 | imsave(new_name, img) 24 | 25 | def resize_annotations(name, new_name, new_size = (256, 176), old_size = (256, 256), crop_bord=40): 26 | df = pd.read_csv(name, sep=':') 27 | 28 | ratio_y = new_size[0] / float(old_size[0]) 29 | ratio_x = new_size[1] / float(old_size[1] - 2 * crop_bord) 30 | 31 | def modify(values, ratio, crop): 32 | val = np.array(json.loads(values)) 33 | mask = val == -1 34 | val = ((val - crop) * ratio).astype(int) 35 | val[mask] = -1 36 | return str(list(val)) 37 | 38 | df['keypoints_y'] = df.apply(lambda row: modify(row['keypoints_y'], ratio_y, 0), axis=1) 39 | df['keypoints_x'] = df.apply(lambda row: modify(row['keypoints_x'], ratio_x, crop_bord), axis=1) 40 | 41 | df.to_csv(new_name, sep=':', index=False) 42 | 43 | 44 | root_dir = 'xxx' 45 | resize_dataset(root_dir + '/test', root_dir + 'fashion_resize/test') 46 | resize_annotations(root_dir + 'fasion-annotation-test.csv', root_dir + 'fasion-resize-annotation-test.csv') 47 | 48 | resize_dataset(root_dir + '/train', root_dir + 'fashion_resize/train') 49 | resize_annotations(root_dir + 'fasion-annotation-train.csv', root_dir + 'fasion-resize-annotation-train.csv') 50 | 51 | 52 | -------------------------------------------------------------------------------- /facial/tool/rm_insnorm_running_vars.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ckp_path = './checkpoints/fashion_PATN/latest_net_netG.pth' 4 | save_path = './checkpoints/fashion_PATN_v1.0/latest_net_netG.pth' 5 | states_dict = torch.load(ckp_path) 6 | states_dict_new = states_dict.copy() 7 | for key in states_dict.keys(): 8 | if "running_var" in key or "running_mean" in key: 9 | del states_dict_new[key] 10 | 11 | torch.save(states_dict_new, save_path) -------------------------------------------------------------------------------- /facial/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data.data_loader import 
CreateDataLoader 4 | from models.models import create_model 5 | from util.visualizer import Visualizer 6 | 7 | opt = TrainOptions().parse() 8 | data_loader = CreateDataLoader(opt) 9 | dataset = data_loader.load_data() 10 | dataset_size = len(data_loader) 11 | print('#training images = %d' % dataset_size) 12 | 13 | model = create_model(opt) 14 | visualizer = Visualizer(opt) 15 | total_steps = 0 16 | 17 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): 18 | epoch_start_time = time.time() 19 | epoch_iter = 0 20 | 21 | for i, data in enumerate(dataset): 22 | iter_start_time = time.time() 23 | visualizer.reset() 24 | total_steps += opt.batchSize 25 | epoch_iter += opt.batchSize 26 | model.set_input(data) 27 | model.optimize_parameters() 28 | 29 | if total_steps % opt.display_freq == 0: 30 | save_result = total_steps % opt.update_html_freq == 0 31 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 32 | 33 | if total_steps % opt.print_freq == 0: 34 | errors = model.get_current_errors() 35 | t = (time.time() - iter_start_time) / opt.batchSize 36 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 37 | if opt.display_id > 0: 38 | visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors) 39 | 40 | if total_steps % opt.save_latest_freq == 0: 41 | print('saving the latest model (epoch %d, total_steps %d)' % 42 | (epoch, total_steps)) 43 | model.save('latest') 44 | 45 | if epoch % opt.save_epoch_freq == 0: 46 | print('saving the model at the end of epoch %d, iters %d' % 47 | (epoch, total_steps)) 48 | model.save('latest') 49 | model.save(epoch) 50 | 51 | print('End of epoch %d / %d \t Time Taken: %d sec' % 52 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 53 | model.update_learning_rate() 54 | -------------------------------------------------------------------------------- /facial/train_facial.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2; 2 | python train.py --dataroot /home/ht1/Radboud_selectiongan/ --name facial_exp --model BiGraphGAN --lambda_GAN 5 --lambda_A 10 --lambda_B 10 --dataset_mode aligned --no_lsgan --n_layers 3 --norm batch --batchSize 24 --resize_or_crop no --gpu_ids 0,1,2 --BP_input_nc 3 --no_flip --which_model_netG Graph --niter 100 --niter_decay 100 --checkpoints_dir ./checkpoints --L1_type l1_plus_perL1 --n_layers_D 3 --with_D_PP 1 --with_D_PB 1 --display_id 0 --save_epoch_freq 50 3 | #--continue_train --which_epoch 640 --epoch_count 641 4 | -------------------------------------------------------------------------------- /facial/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__init__.py -------------------------------------------------------------------------------- /facial/util/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__init__.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/__init__.cpython-35.pyc 
-------------------------------------------------------------------------------- /facial/util/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/html.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/html.cpython-35.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/html.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/html.cpython-36.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/html.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/html.cpython-38.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/image_pool.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/image_pool.cpython-35.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/image_pool.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/image_pool.cpython-36.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/image_pool.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/image_pool.cpython-38.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/util.cpython-35.pyc 
-------------------------------------------------------------------------------- /facial/util/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/visualizer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/visualizer.cpython-35.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/visualizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/visualizer.cpython-36.pyc -------------------------------------------------------------------------------- /facial/util/__pycache__/visualizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/__pycache__/visualizer.cpython-38.pyc -------------------------------------------------------------------------------- /facial/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | # print(self.img_dir) 16 | 17 | self.doc = dominate.document(title=title) 18 | if reflesh > 0: 19 | with self.doc.head: 20 | meta(http_equiv="refresh", content=str(reflesh)) # 'refresh' is the valid http-equiv value; the misspelt 'reflesh' would never trigger auto-reload 21 | 22 | def get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, str): 26 | with self.doc: 27 | h3(str) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, width=400): 34 | self.add_table() 35 | with self.t: 36 | with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with a(href=os.path.join('images', link)): 41 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 42 | br() 43 | p(txt) 44 | 45 | def save(self): 
46 | html_file = '%s/index.html' % self.web_dir 47 | f = open(html_file, 'wt') 48 | f.write(self.doc.render()) 49 | f.close() 50 | 51 | 52 | if __name__ == '__main__': 53 | html = HTML('web/', 'test_html') 54 | html.add_header('hello world') 55 | 56 | ims = [] 57 | txts = [] 58 | links = [] 59 | for n in range(4): 60 | ims.append('image_%d.png' % n) 61 | txts.append('text_%d' % n) 62 | links.append('image_%d.png' % n) 63 | html.add_images(ims, txts, links) 64 | html.save() 65 | -------------------------------------------------------------------------------- /facial/util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | 7 | class ImagePool(): 8 | def __init__(self, pool_size): 9 | self.pool_size = pool_size 10 | if self.pool_size > 0: 11 | self.num_imgs = 0 12 | self.images = [] 13 | 14 | def query(self, images): 15 | if self.pool_size == 0: 16 | return Variable(images) 17 | return_images = [] 18 | for image in images: 19 | image = torch.unsqueeze(image, 0) 20 | if self.num_imgs < self.pool_size: 21 | self.num_imgs = self.num_imgs + 1 22 | self.images.append(image) 23 | return_images.append(image) 24 | else: 25 | p = random.uniform(0, 1) 26 | if p > 0.5: 27 | random_id = random.randint(0, self.pool_size-1) 28 | tmp = self.images[random_id].clone() 29 | self.images[random_id] = image 30 | return_images.append(tmp) 31 | else: 32 | return_images.append(image) 33 | return_images = Variable(torch.cat(return_images, 0)) 34 | return return_images 35 | --------------------------------------------------------------------------------
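ImagePool.query implements the history buffer used by pix2pix-style training (popularized by SimGAN and CycleGAN): once the pool is full, each incoming fake image has a 50% chance of being swapped with a randomly stored one, so the discriminator sees a mix of current and older generator outputs, which damps oscillation. A minimal usage sketch; the tensor shape and the netD step are hypothetical placeholders, not taken from this repository:

    import torch
    from util.image_pool import ImagePool

    pool = ImagePool(pool_size=50)          # matches the --pool_size default
    fake_B = torch.randn(4, 3, 128, 64)     # hypothetical generator batch
    fake_for_D = pool.query(fake_B)         # mix of fresh and buffered fakes
    # pred_fake = netD(fake_for_D.detach()) # a discriminator step would consume this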
""" 6 | assert (width * height * 3 == len(buf)) 7 | bpp = 3 8 | 9 | def raw_data(): 10 | # reverse the vertical line order and add null bytes at the start 11 | row_bytes = width * bpp 12 | for row_start in range((height - 1) * width * bpp, -1, -row_bytes): 13 | yield b'\x00' 14 | yield buf[row_start:row_start + row_bytes] 15 | 16 | def chunk(tag, data): 17 | return [ 18 | struct.pack("!I", len(data)), 19 | tag, 20 | data, 21 | struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag))) 22 | ] 23 | 24 | SIGNATURE = b'\x89PNG\r\n\x1a\n' 25 | COLOR_TYPE_RGB = 2 26 | COLOR_TYPE_RGBA = 6 27 | bit_depth = 8 28 | return b''.join( 29 | [ SIGNATURE ] + 30 | chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) + 31 | chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) + 32 | chunk(b'IEND', b'') 33 | ) 34 | -------------------------------------------------------------------------------- /facial/util/util.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/facial/util/util.pyc -------------------------------------------------------------------------------- /imgs/face_results.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/imgs/face_results.jpeg -------------------------------------------------------------------------------- /imgs/fashion_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/imgs/fashion_results.jpg -------------------------------------------------------------------------------- /imgs/market_results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/imgs/market_results.jpg -------------------------------------------------------------------------------- /imgs/method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/imgs/method.jpg -------------------------------------------------------------------------------- /imgs/motivation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/imgs/motivation.jpg -------------------------------------------------------------------------------- /market_1501/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__init__.py -------------------------------------------------------------------------------- /market_1501/data/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/base_data_loader.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/base_data_loader.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/base_data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/base_data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/base_dataset.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/base_dataset.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/base_dataset.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/base_dataset.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/custom_dataset_data_loader.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/custom_dataset_data_loader.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/custom_dataset_data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/custom_dataset_data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/data_loader.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/data_loader.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/data_loader.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/data_loader.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/image_folder.cpython-35.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/image_folder.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/image_folder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/image_folder.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/keypoint.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/keypoint.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/data/__pycache__/keypoint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/data/__pycache__/keypoint.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseDataLoader(): 3 | def __init__(self): 4 | pass 5 | 6 | def initialize(self, opt): 7 | self.opt = opt 8 | pass 9 | 10 | def load_data(self): # fixed: the missing 'self' made any call on an instance raise TypeError 11 | return None 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /market_1501/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | class BaseDataset(data.Dataset): 6 | def __init__(self): 7 | super(BaseDataset, self).__init__() 8 | 9 | def name(self): 10 | return 'BaseDataset' 11 | 12 | def initialize(self, opt): 13 | pass 14 | 15 | def get_transform(opt): 16 | transform_list = [] 17 | if opt.resize_or_crop == 'resize_and_crop': 18 | osize = [opt.loadSize, opt.loadSize] 19 | transform_list.append(transforms.Resize(osize, Image.BICUBIC)) # Resize replaces transforms.Scale, which newer torchvision deprecated and then removed 20 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 21 | elif opt.resize_or_crop == 'crop': 22 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 23 | elif opt.resize_or_crop == 'scale_width': 24 | transform_list.append(transforms.Lambda( 25 | lambda img: __scale_width(img, opt.fineSize))) 26 | elif opt.resize_or_crop == 'scale_width_and_crop': 27 | transform_list.append(transforms.Lambda( 28 | lambda img: __scale_width(img, opt.loadSize))) 29 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 30 | 31 | transform_list += [transforms.ToTensor(), 32 | transforms.Normalize((0.5, 0.5, 0.5), 33 | (0.5, 0.5, 0.5))] 34 | return transforms.Compose(transform_list) 35 | 36 | def __scale_width(img, target_width): 37 | ow, oh = img.size 38 | if (ow == target_width): 39 | return img 40 | w = target_width 41 | h = int(target_width * oh / ow) 42 | return img.resize((w, h), Image.BICUBIC) 43 | -------------------------------------------------------------------------------- /market_1501/data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from
data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | 8 | if opt.dataset_mode == 'keypoint': 9 | from data.keypoint import KeyDataset 10 | dataset = KeyDataset() 11 | 12 | else: 13 | raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode) 14 | 15 | print("dataset [%s] was created" % (dataset.name())) 16 | dataset.initialize(opt) 17 | return dataset 18 | 19 | 20 | class CustomDatasetDataLoader(BaseDataLoader): 21 | def name(self): 22 | return 'CustomDatasetDataLoader' 23 | 24 | def initialize(self, opt): 25 | BaseDataLoader.initialize(self, opt) 26 | self.dataset = CreateDataset(opt) 27 | self.dataloader = torch.utils.data.DataLoader( 28 | self.dataset, 29 | batch_size=opt.batchSize, 30 | shuffle=not opt.serial_batches, 31 | num_workers=int(opt.nThreads)) 32 | 33 | def load_data(self): 34 | return self 35 | 36 | def __len__(self): 37 | return min(len(self.dataset), self.opt.max_dataset_size) 38 | 39 | def __iter__(self): 40 | for i, data in enumerate(self.dataloader): 41 | if i >= self.opt.max_dataset_size: 42 | break 43 | yield data 44 | -------------------------------------------------------------------------------- /market_1501/data/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | def CreateDataLoader(opt): 3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 4 | data_loader = CustomDatasetDataLoader() 5 | print(data_loader.name()) 6 | data_loader.initialize(opt) 7 | return data_loader 8 | -------------------------------------------------------------------------------- /market_1501/data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | 
if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /market_1501/losses/L1_plus_perceptualLoss.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import torchvision.models as models 8 | 9 | class L1_plus_perceptualLoss(nn.Module): 10 | def __init__(self, lambda_L1, lambda_perceptual, perceptual_layers, gpu_ids, percep_is_l1): 11 | super(L1_plus_perceptualLoss, self).__init__() 12 | 13 | self.lambda_L1 = lambda_L1 14 | self.lambda_perceptual = lambda_perceptual 15 | self.gpu_ids = gpu_ids 16 | 17 | self.percep_is_l1 = percep_is_l1 18 | 19 | vgg = models.vgg19(pretrained=True).features 20 | self.vgg_submodel = nn.Sequential() 21 | for i, layer in enumerate(list(vgg)): 22 | self.vgg_submodel.add_module(str(i), layer) 23 | if i == perceptual_layers: 24 | break 25 | self.vgg_submodel = torch.nn.DataParallel(self.vgg_submodel, device_ids=gpu_ids).cuda() 26 | 27 | print(self.vgg_submodel) 28 | 29 | def forward(self, inputs, targets): 30 | if self.lambda_L1 == 0 and self.lambda_perceptual == 0: 31 | return torch.zeros(1).cuda(), torch.zeros(1).cuda(), torch.zeros(1).cuda() # keep all three on the same device; mixing CUDA and CPU tensors breaks downstream arithmetic 32 | # normal L1 33 | loss_l1 = F.l1_loss(inputs, targets) * self.lambda_L1 34 | 35 | # perceptual L1 36 | mean = torch.FloatTensor(3) 37 | mean[0] = 0.485 38 | mean[1] = 0.456 39 | mean[2] = 0.406 40 | mean = mean.view(1, 3, 1, 1).cuda() # view() replaces the deprecated Tensor.resize(); shape (1, 3, 1, 1) broadcasts over NCHW batches 41 | 42 | std = torch.FloatTensor(3) 43 | std[0] = 0.229 44 | std[1] = 0.224 45 | std[2] = 0.225 46 | std = std.view(1, 3, 1, 1).cuda() 47 | 48 | fake_p2_norm = (inputs + 1)/2 # [-1, 1] => [0, 1] 49 | fake_p2_norm = (fake_p2_norm - mean)/std 50 | 51 | input_p2_norm = (targets + 1)/2 # [-1, 1] => [0, 1] 52 | input_p2_norm = (input_p2_norm - mean)/std 53 | 54 | 55 | fake_p2_norm = self.vgg_submodel(fake_p2_norm) 56 | input_p2_norm = self.vgg_submodel(input_p2_norm) 57 | input_p2_norm_no_grad = input_p2_norm.detach() 58 | 59 | if self.percep_is_l1 == 1: 60 | # use l1 for perceptual loss 61 | loss_perceptual = F.l1_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual 62 | else: 63 | # use l2 for perceptual loss 64 | loss_perceptual = F.mse_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual 65 | 66 | loss = loss_l1 + loss_perceptual 67 | 68 | return loss, loss_l1, loss_perceptual 69 | 70 | -------------------------------------------------------------------------------- /market_1501/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/losses/__init__.py -------------------------------------------------------------------------------- /market_1501/losses/__pycache__/L1_plus_perceptualLoss.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/losses/__pycache__/L1_plus_perceptualLoss.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/losses/__pycache__/L1_plus_perceptualLoss.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/losses/__pycache__/L1_plus_perceptualLoss.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/losses/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/losses/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/losses/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/losses/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__init__.py -------------------------------------------------------------------------------- /market_1501/models/__pycache__/PATN.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/PATN.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/models/__pycache__/PATN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/PATN.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/models/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/models/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/models/__pycache__/base_model.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/base_model.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/models/__pycache__/base_model.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/base_model.cpython-36.pyc -------------------------------------------------------------------------------- 
/market_1501/models/__pycache__/model_variants.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/model_variants.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/models/__pycache__/model_variants.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/model_variants.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/models/__pycache__/models.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/models.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/models/__pycache__/models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/models/__pycache__/networks.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/networks.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/models/__pycache__/networks.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/models/__pycache__/networks.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BaseModel(nn.Module): 7 | 8 | def __init__(self): 9 | super(BaseModel, self).__init__() 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, opt): 15 | self.opt = opt 16 | self.gpu_ids = opt.gpu_ids 17 | self.isTrain = opt.isTrain 18 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 19 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 20 | 21 | def set_input(self, input): 22 | self.input = input 23 | 24 | def forward(self): 25 | pass 26 | 27 | # used in test time, no backprop 28 | def test(self): 29 | pass 30 | 31 | def get_image_paths(self): 32 | pass 33 | 34 | def optimize_parameters(self): 35 | pass 36 | 37 | def get_current_visuals(self): 38 | return self.input 39 | 40 | def get_current_errors(self): 41 | return {} 42 | 43 | def save(self, label): 44 | pass 45 | 46 | # helper saving function that can be used by subclasses 47 | def save_network(self, network, network_label, epoch_label, gpu_ids): 48 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 49 | save_path = os.path.join(self.save_dir, save_filename) 50 | 
torch.save(network.cpu().state_dict(), save_path) 51 | if len(gpu_ids) and torch.cuda.is_available(): 52 | network.cuda(gpu_ids[0]) 53 | 54 | # helper loading function that can be used by subclasses 55 | def load_network(self, network, network_label, epoch_label): 56 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 57 | save_path = os.path.join(self.save_dir, save_filename) 58 | network.load_state_dict(torch.load(save_path)) 59 | 60 | # update learning rate (called once every epoch) 61 | def update_learning_rate(self): 62 | for scheduler in self.schedulers: 63 | scheduler.step() 64 | lr = self.optimizers[0].param_groups[0]['lr'] 65 | print('learning rate = %.7f' % lr) 66 | -------------------------------------------------------------------------------- /market_1501/models/models.py: -------------------------------------------------------------------------------- 1 | 2 | def create_model(opt): 3 | model = None 4 | print(opt.model) 5 | 6 | if opt.model == 'BiGraphGAN': 7 | assert opt.dataset_mode == 'keypoint' 8 | from .BiGraphGAN import TransferModel 9 | model = TransferModel() 10 | 11 | else: 12 | raise ValueError("Model [%s] not recognized." % opt.model) 13 | model.initialize(opt) 14 | print("model [%s] was created" % (model.name())) 15 | return model 16 | -------------------------------------------------------------------------------- /market_1501/models/test_model.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Variable 2 | from collections import OrderedDict 3 | import util.util as util 4 | from .base_model import BaseModel 5 | from . import networks 6 | 7 | 8 | class TestModel(BaseModel): 9 | def name(self): 10 | return 'TestModel' 11 | 12 | def initialize(self, opt): 13 | assert(not opt.isTrain) 14 | BaseModel.initialize(self, opt) 15 | self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize) 16 | 17 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, 18 | opt.ngf, opt.which_model_netG, 19 | opt.norm, not opt.no_dropout, 20 | opt.init_type, 21 | self.gpu_ids) 22 | which_epoch = opt.which_epoch 23 | self.load_network(self.netG, 'G', which_epoch) 24 | 25 | print('---------- Networks initialized -------------') 26 | networks.print_network(self.netG) 27 | print('-----------------------------------------------') 28 | 29 | def set_input(self, input): 30 | # we need to use single_dataset mode 31 | input_A = input['A'] 32 | self.input_A.resize_(input_A.size()).copy_(input_A) 33 | self.image_paths = input['A_paths'] 34 | 35 | def test(self): 36 | self.real_A = Variable(self.input_A) 37 | self.fake_B = self.netG(self.real_A) 38 | 39 | # get image paths 40 | def get_image_paths(self): 41 | return self.image_paths 42 | 43 | def get_current_visuals(self): 44 | real_A = util.tensor2im(self.real_A.data) 45 | fake_B = util.tensor2im(self.fake_B.data) 46 | return OrderedDict([('real_A', real_A), ('fake_B', fake_B)]) 47 | -------------------------------------------------------------------------------- /market_1501/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__init__.py -------------------------------------------------------------------------------- /market_1501/options/__init__.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__init__.pyc -------------------------------------------------------------------------------- /market_1501/options/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/options/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/options/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /market_1501/options/__pycache__/base_options.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__pycache__/base_options.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/options/__pycache__/base_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__pycache__/base_options.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/options/__pycache__/base_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__pycache__/base_options.cpython-37.pyc -------------------------------------------------------------------------------- /market_1501/options/__pycache__/test_options.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__pycache__/test_options.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/options/__pycache__/test_options.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__pycache__/test_options.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/options/__pycache__/train_options.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__pycache__/train_options.cpython-35.pyc 
-------------------------------------------------------------------------------- /market_1501/options/__pycache__/train_options.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/__pycache__/train_options.cpython-37.pyc -------------------------------------------------------------------------------- /market_1501/options/base_options.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/base_options.pyc -------------------------------------------------------------------------------- /market_1501/options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self): 6 | BaseOptions.initialize(self) 7 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 8 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 9 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 12 | self.parser.add_argument('--how_many', type=int, default=200, help='how many test images to run') 13 | 14 | self.isTrain = False 15 | -------------------------------------------------------------------------------- /market_1501/options/train_options.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/options/train_options.pyc -------------------------------------------------------------------------------- /market_1501/test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | from options.test_options import TestOptions 4 | from data.data_loader import CreateDataLoader 5 | from models.models import create_model 6 | from util.visualizer import Visualizer 7 | from util import html 8 | import time 9 | 10 | opt = TestOptions().parse() 11 | opt.nThreads = 1 # test code only supports nThreads = 1 12 | opt.batchSize = 1 # test code only supports batchSize = 1 13 | opt.serial_batches = True # no shuffle 14 | opt.no_flip = True # no flip 15 | 16 | data_loader = CreateDataLoader(opt) 17 | dataset = data_loader.load_data() 18 | model = create_model(opt) 19 | visualizer = Visualizer(opt) 20 | # create website 21 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 22 | 23 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 24 | 25 | print(opt.how_many) 26 | print(len(dataset)) 27 | 28 | model = model.eval() 29 | print(model.training) 30 | 31 | opt.how_many = 999999 32 | # test 33 | for i, data in enumerate(dataset): 34 | print(' process %d/%d img ..'%(i,opt.how_many)) 35 | if i >= opt.how_many: 36 | break 37 | model.set_input(data) 38 | startTime = time.time() 39 | model.test() 40 | endTime = time.time() 41 | 
print(endTime-startTime) 42 | visuals = model.get_current_visuals() 43 | img_path = model.get_image_paths() 44 | img_path = [img_path] 45 | print(img_path) 46 | visualizer.save_images(webpage, visuals, img_path) 47 | 48 | webpage.save() 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /market_1501/test_market.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0; 2 | python test.py --dataroot ./SelectionGAN/person_transfer/datasets/market_data/ --name market_exp --model BiGraphGAN --phase test --dataset_mode keypoint --norm batch --batchSize 1 --resize_or_crop no --gpu_ids 0 --BP_input_nc 18 --no_flip --which_model_netG Graph --checkpoints_dir ./checkpoints --pairLst ./SelectionGAN/person_transfer/datasets/market_data/market-pairs-test.csv --which_epoch 700 --results_dir ./results/ --display_id 0; 3 | -------------------------------------------------------------------------------- /market_1501/test_market_pretrained.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0; 2 | python test.py --dataroot ./SelectionGAN/person_transfer/datasets/market_data/ --name market_pretrained --model BiGraphGAN --phase test --dataset_mode keypoint --norm batch --batchSize 1 --resize_or_crop no --gpu_ids 0 --BP_input_nc 18 --no_flip --which_model_netG Graph --checkpoints_dir ./BiGraphGAN/scripts/checkpoints --pairLst ./SelectionGAN/person_transfer/datasets/market_data/market-pairs-test.csv --which_epoch latest --results_dir ./results/ --display_id 0; -------------------------------------------------------------------------------- /market_1501/tool/calPCKH_fashion.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pandas as pd 3 | import numpy as np 4 | import json 5 | 6 | MISSING_VALUE = -1 7 | 8 | PARTS_SEL = [0, 1, 14, 15, 16, 17] 9 | 10 | target_annotation = './fashion_data/fasion-resize-annotation-test.csv' 11 | pred_annotation = './results/fashion_PATN/pckh.csv' 12 | 13 | 14 | ''' 15 | hz: head size 16 | alpha: norm factor 17 | px, py: predict coords 18 | tx, ty: target coords 19 | ''' 20 | def isRight(px, py, tx, ty, hz, alpha): 21 | if px == -1 or py == -1 or tx == -1 or ty == -1: 22 | return 0 23 | 24 | if abs(px - tx) < hz[0] * alpha and abs(py - ty) < hz[1] * alpha: 25 | return 1 26 | else: 27 | return 0 28 | 29 | 30 | def how_many_right_seq(px, py, tx, ty, hz, alpha): 31 | nRight = 0 32 | for i in range(len(px)): 33 | nRight = nRight + isRight(px[i], py[i], tx[i], ty[i], hz, alpha) 34 | 35 | return nRight 36 | 37 | 38 | def ValidPoints(tx): 39 | nValid = 0 40 | for item in tx: 41 | if item != -1: 42 | nValid = nValid + 1 43 | return nValid 44 | 45 | 46 | def get_head_wh(x_coords, y_coords): 47 | final_w, final_h = -1, -1 48 | component_count = 0 49 | save_componets = [] 50 | for component in PARTS_SEL: 51 | if x_coords[component] == MISSING_VALUE or y_coords[component] == MISSING_VALUE: 52 | continue 53 | else: 54 | component_count += 1 55 | save_componets.append([x_coords[component], y_coords[component]]) 56 | if component_count >= 2: 57 | x_cords = [] 58 | y_cords = [] 59 | for component in save_componets: 60 | x_cords.append(component[0]) 61 | y_cords.append(component[1]) 62 | xmin = min(x_cords) 63 | xmax = max(x_cords) 64 | ymin = min(y_cords) 65 | ymax = max(y_cords) 66 | final_w = xmax - xmin 67 | final_h = ymax - ymin 68 | return final_w, 
final_h 69 | 70 | 71 | tAnno = pd.read_csv(target_annotation, sep=':') 72 | pAnno = pd.read_csv(pred_annotation, sep=':') 73 | 74 | pRows = pAnno.shape[0] 75 | 76 | nAll = 0 77 | nCorrect = 0 78 | alpha = 0.5 79 | for i in range(pRows): 80 | pValues = pAnno.iloc[i].values 81 | pname = pValues[0] 82 | pycords = json.loads(pValues[1]) # list of numbers 83 | pxcords = json.loads(pValues[2]) 84 | 85 | if '_vis' in pname: 86 | tname = pname[:-8] 87 | else: 88 | tname = pname[:-4] 89 | 90 | if '___' in tname: 91 | tname = tname.split('___')[1] 92 | else: 93 | tname = tname.split('jpg_')[1] 94 | 95 | print(tname) 96 | tValues = tAnno.query('name == "%s"' % (tname)).values[0] 97 | tycords = json.loads(tValues[1]) # list of numbers 98 | txcords = json.loads(tValues[2]) 99 | 100 | 101 | xBox, yBox = get_head_wh(txcords, tycords) 102 | if xBox == -1 or yBox == -1: 103 | continue 104 | 105 | head_size = (xBox, yBox) 106 | nAll = nAll + ValidPoints(tycords) 107 | nCorrect = nCorrect + how_many_right_seq(pxcords, pycords, txcords, tycords, head_size, alpha) 108 | 109 | print('%d/%d %f' % (nCorrect, nAll, nCorrect * 1.0 / nAll)) 110 | -------------------------------------------------------------------------------- /market_1501/tool/calPCKH_market.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pandas as pd 3 | import numpy as np 4 | import json 5 | 6 | MISSING_VALUE = -1 7 | 8 | PARTS_SEL = [0, 1, 14, 15, 16, 17] 9 | 10 | # fix the PATH 11 | target_annotation = './market_data/market-annotation-test.csv' 12 | pred_annotation = '/results/market_PATN/pckh.csv' 13 | 14 | 15 | ''' 16 | hz: head size 17 | alpha: norm factor 18 | px, py: predict coords 19 | tx, ty: target coords 20 | ''' 21 | def isRight(px, py, tx, ty, hz, alpha): 22 | if px == -1 or py == -1 or tx == -1 or ty == -1: 23 | return 0 24 | 25 | if abs(px-tx) < hz[0]*alpha and abs(py-ty) < hz[1]*alpha: 26 | return 1 27 | else: 28 | return 0 29 | 30 | def how_many_right_seq(px, py, tx, ty, hz, alpha): 31 | nRight = 0 32 | for i in range(len(px)): 33 | nRight = nRight + isRight(px[i], py[i], tx[i], ty[i], hz, alpha) 34 | 35 | return nRight 36 | 37 | def ValidPoints(tx): 38 | nValid = 0 39 | for item in tx: 40 | if item != -1: 41 | nValid = nValid + 1 42 | return nValid 43 | 44 | def get_head_wh(x_coords, y_coords): 45 | final_w, final_h = -1, -1 46 | component_count = 0 47 | save_componets = [] 48 | for component in PARTS_SEL: 49 | if x_coords[component] == MISSING_VALUE or y_coords[component] == MISSING_VALUE: 50 | continue 51 | else: 52 | component_count += 1 53 | save_componets.append([x_coords[component], y_coords[component]]) 54 | if component_count >= 2: 55 | x_cords = [] 56 | y_cords = [] 57 | for component in save_componets: 58 | x_cords.append(component[0]) 59 | y_cords.append(component[1]) 60 | xmin = min(x_cords) 61 | xmax = max(x_cords) 62 | ymin = min(y_cords) 63 | ymax = max(y_cords) 64 | final_w = xmax - xmin 65 | final_h = ymax - ymin 66 | return final_w, final_h 67 | 68 | 69 | 70 | 71 | 72 | tAnno = pd.read_csv(target_annotation, sep=':') 73 | pAnno = pd.read_csv(pred_annotation, sep=':') 74 | 75 | pRows = pAnno.shape[0] 76 | 77 | nAll = 0 78 | nCorrect = 0 79 | alpha = 0.5 80 | for i in range(pRows): 81 | pValues = pAnno.iloc[i].values 82 | pname = pValues[0] 83 | pycords = json.loads(pValues[1]) #list of numbers 84 | pxcords = json.loads(pValues[2]) 85 | 86 | if '_vis' in pname: 87 | tname = pname[:-8] 88 | else: 89 | tname = pname[:-4] 90 | 91 | if '___' in tname: 92 
| tname = tname.split('___')[1] 93 | else: 94 | tname = tname.split('jpg_')[1] 95 | 96 | print(tname) 97 | tValues = tAnno.query('name == "%s"' %(tname)).values[0] 98 | tycords = json.loads(tValues[1]) #list of numbers 99 | txcords = json.loads(tValues[2]) 100 | 101 | xBox, yBox = get_head_wh(txcords, tycords) 102 | if xBox == -1 or yBox == -1: 103 | continue 104 | 105 | head_size = (xBox, yBox) 106 | nAll = nAll + ValidPoints(tycords) 107 | nCorrect = nCorrect + how_many_right_seq(pxcords, pycords, txcords, tycords, head_size, alpha) 108 | 109 | 110 | print('%d/%d %f' %(nCorrect, nAll, nCorrect*1.0/nAll)) 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /market_1501/tool/create_pairs_dataset.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pose_utils 3 | from itertools import permutations 4 | 5 | LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri', 6 | 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear'] 7 | 8 | MISSING_VALUE = -1 9 | 10 | def give_name_to_keypoints(array): 11 | res = {} 12 | for i, name in enumerate(LABELS): 13 | if array[i][0] != MISSING_VALUE and array[i][1] != MISSING_VALUE: 14 | res[name] = array[i][::-1] 15 | return res 16 | 17 | 18 | def pose_check_valid(kp_array): 19 | kp = give_name_to_keypoints(kp_array) 20 | return check_keypoints_present(kp, ['Rhip', 'Lhip', 'Lsho', 'Rsho']) 21 | 22 | 23 | def check_keypoints_present(kp, kp_names): 24 | result = True 25 | for name in kp_names: 26 | result = result and (name in kp) 27 | return result 28 | 29 | def filter_not_valid(df_keypoints): 30 | def check_valid(x): 31 | kp_array = pose_utils.load_pose_cords_from_strings(x['keypoints_y'], x['keypoints_x']) 32 | distractor = x['name'].startswith('-1') or x['name'].startswith('0000') 33 | return pose_check_valid(kp_array) and not distractor 34 | return df_keypoints[df_keypoints.apply(check_valid, axis=1)].copy() 35 | 36 | 37 | def make_pairs(df): 38 | persons = df.apply(lambda x: '_'.join(x['name'].split('_')[0:1]), axis=1) 39 | df['person'] = persons 40 | fr, to = [], [] 41 | for person in pd.unique(persons): 42 | pairs = list(zip(*list(permutations(df[df['person'] == person]['name'], 2)))) # list() added: in Python 3 zip returns an iterator with no len() and no indexing 43 | if len(pairs) != 0: 44 | fr += list(pairs[0]) 45 | to += list(pairs[1]) 46 | pair_df = pd.DataFrame(index=range(len(fr))) 47 | pair_df['from'] = fr 48 | pair_df['to'] = to 49 | return pair_df 50 | 51 | 52 | if __name__ == "__main__": 53 | images_for_test = 12000 54 | 55 | annotations_file_train = './market_data/market-annotation-test.csv' 56 | pairs_file_train = './market_data/example_market-pairs-train.csv' 57 | 58 | df_keypoints = pd.read_csv(annotations_file_train, sep=':') 59 | df = filter_not_valid(df_keypoints) 60 | print ('Compute pair dataset for train...') 61 | pairs_df_train = make_pairs(df) 62 | print ('Number of pairs: %s' % len(pairs_df_train)) 63 | pairs_df_train.to_csv(pairs_file_train, index=False) 64 | 65 | annotations_file_test = './market_data/market-annotation-test.csv' 66 | pairs_file_test = './market_data/example_market-pairs-test.csv' 67 | 68 | print ('Compute pair dataset for test...') 69 | df_keypoints = pd.read_csv(annotations_file_test, sep=':') 70 | df = filter_not_valid(df_keypoints) 71 | pairs_df_test = make_pairs(df) 72 | pairs_df_test = pairs_df_test.sample(n=min(images_for_test, pairs_df_test.shape[0]), replace=False, random_state=0) 73 | print ('Number of
pairs: %s' % len(pairs_df_test)) 74 | pairs_df_test.to_csv(pairs_file_test, index=False) 75 | 76 | -------------------------------------------------------------------------------- /market_1501/tool/crop_fashion.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | 4 | img_dir = './results/fashion_PATN_test/test_latest/images' 5 | save_dir = './results/fashion_PATN_test/test_latest/images_crop' 6 | 7 | if not os.path.exists(save_dir): 8 | os.mkdir(save_dir) 9 | 10 | cnt = 0 11 | 12 | for item in os.listdir(img_dir): 13 | if not item.endswith('.jpg') and not item.endswith('.png'): 14 | continue 15 | cnt = cnt + 1 16 | print('%d/8570 ...' %(cnt)) 17 | img = Image.open(os.path.join(img_dir, item)) 18 | imgcrop = img.crop((704, 0, 880, 256)) 19 | imgcrop.save(os.path.join(save_dir, item)) 20 | -------------------------------------------------------------------------------- /market_1501/tool/crop_market.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | 4 | img_dir = './results/market_PATN_test/test_latest/images' 5 | save_dir = './results/market_PATN_test/test_latest/images_crop' 6 | 7 | if not os.path.exists(save_dir): 8 | os.mkdir(save_dir) 9 | 10 | cnt = 0 11 | for item in os.listdir(img_dir): 12 | if not item.endswith('.jpg') and not item.endswith('.png'): 13 | continue 14 | cnt = cnt + 1 15 | print('%d/12000 ...' %(cnt)) 16 | img = Image.open(os.path.join(img_dir, item)) 17 | # for 5 split 18 | imgcrop = img.crop((256, 0, 320, 128)) 19 | imgcrop.save(os.path.join(save_dir, item)) 20 | -------------------------------------------------------------------------------- /market_1501/tool/generate_fashion_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from PIL import Image 4 | 5 | IMG_EXTENSIONS = [ 6 | '.jpg', '.JPG', '.jpeg', '.JPEG', 7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 8 | ] 9 | 10 | def is_image_file(filename): 11 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 12 | 13 | def make_dataset(dir): 14 | images = [] 15 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 16 | new_root = './fashion_data' 17 | if not os.path.exists(new_root): 18 | os.mkdir(new_root) 19 | 20 | train_root = './fashion_data/train' 21 | if not os.path.exists(train_root): 22 | os.mkdir(train_root) 23 | 24 | test_root = './fashion_data/test' 25 | if not os.path.exists(test_root): 26 | os.mkdir(test_root) 27 | 28 | train_images = [] 29 | train_f = open('./fashion_data/train.lst', 'r') 30 | for lines in train_f: 31 | lines = lines.strip() 32 | if lines.endswith('.jpg'): 33 | train_images.append(lines) 34 | 35 | test_images = [] 36 | test_f = open('./fashion_data/test.lst', 'r') 37 | for lines in test_f: 38 | lines = lines.strip() 39 | if lines.endswith('.jpg'): 40 | test_images.append(lines) 41 | 42 | print(train_images, test_images) 43 | 44 | 45 | for root, _, fnames in sorted(os.walk(dir)): 46 | for fname in fnames: 47 | if is_image_file(fname): 48 | path = os.path.join(root, fname) 49 | path_names = path.split('/') 50 | # path_names[2] = path_names[2].replace('_', '') 51 | path_names[3] = path_names[3].replace('_', '') 52 | path_names[4] = path_names[4].split('_')[0] + "_" + "".join(path_names[4].split('_')[1:]) 53 | path_names = "".join(path_names) 54 | # new_path = os.path.join(root, path_names) 55 | img = Image.open(path) 56 | 
imgcrop = img.crop((40, 0, 216, 256)) 57 | if path_names in train_images: # was 'new_path', which is never defined (its assignment is commented out above); the flattened name is what train.lst/test.lst store 58 | imgcrop.save(os.path.join(train_root, path_names)) 59 | elif path_names in test_images: 60 | imgcrop.save(os.path.join(test_root, path_names)) 61 | 62 | make_dataset('./fashion') 63 | -------------------------------------------------------------------------------- /market_1501/tool/generate_pose_map_fashion.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import json 4 | import os 5 | 6 | MISSING_VALUE = -1 7 | # fix PATH 8 | img_dir = 'fashion_data' #raw image path 9 | annotations_file = 'fashion_data/fasion-resize-annotation-train.csv' #pose annotation path 10 | save_path = 'fashion_data/trainK' #path to store pose maps 11 | 12 | def load_pose_cords_from_strings(y_str, x_str): 13 | y_cords = json.loads(y_str) 14 | x_cords = json.loads(x_str) 15 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1) 16 | 17 | def cords_to_map(cords, img_size, sigma=6): 18 | result = np.zeros(img_size + cords.shape[0:1], dtype='float32') # was 'uint8', which truncates the [0, 1] Gaussian heatmap to all zeros; the market script already uses float32 19 | for i, point in enumerate(cords): 20 | if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE: 21 | continue 22 | xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0])) 23 | result[..., i] = np.exp(-((yy - point[0]) ** 2 + (xx - point[1]) ** 2) / (2 * sigma ** 2)) 24 | # result[..., i] = np.where(((yy - point[0]) ** 2 + (xx - point[1]) ** 2) < (sigma ** 2), 1, 0) 25 | return result 26 | 27 | def compute_pose(image_dir, annotations_file, savePath, sigma=6): # default added: the call at the bottom of this file passes only three arguments 28 | annotations_file = pd.read_csv(annotations_file, sep=':') 29 | annotations_file = annotations_file.set_index('name') 30 | image_size = (256, 176) 31 | cnt = len(annotations_file) 32 | for i in range(cnt): 33 | print('processing %d / %d ...'
%(i, cnt)) 34 | row = annotations_file.iloc[i] 35 | name = row.name 36 | print(savePath, name) 37 | file_name = os.path.join(savePath, name + '.npy') 38 | kp_array = load_pose_cords_from_strings(row.keypoints_y, row.keypoints_x) 39 | pose = cords_to_map(kp_array, image_size, sigma) 40 | np.save(file_name, pose) 41 | # input() 42 | 43 | compute_pose(img_dir, annotations_file, save_path) 44 | 45 | 46 | 47 | -------------------------------------------------------------------------------- /market_1501/tool/generate_pose_map_market.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import json 4 | import os 5 | 6 | MISSING_VALUE = -1 7 | 8 | img_dir = 'market_data/train' #raw image path 9 | annotations_file = 'market_data/market-annotation-train.csv' #pose annotation path 10 | save_path = 'market_data/trainK' #path to store pose maps 11 | 12 | def load_pose_cords_from_strings(y_str, x_str): 13 | y_cords = json.loads(y_str) 14 | x_cords = json.loads(x_str) 15 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1) 16 | 17 | def cords_to_map(cords, img_size, sigma=6): 18 | result = np.zeros(img_size + cords.shape[0:1], dtype='float32') 19 | for i, point in enumerate(cords): 20 | if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE: 21 | continue 22 | xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0])) 23 | result[..., i] = np.exp(-((yy - point[0]) ** 2 + (xx - point[1]) ** 2) / (2 * sigma ** 2)) 24 | return result 25 | 26 | def compute_pose(image_dir, annotations_file, savePath): 27 | annotations_file = pd.read_csv(annotations_file, sep=':') 28 | annotations_file = annotations_file.set_index('name') 29 | image_size = (128, 64) 30 | cnt = len(annotations_file) 31 | for i in range(cnt): 32 | print('processing %d / %d ...' 
%(i, cnt)) 33 | row = annotations_file.iloc[i] 34 | name = row.name 35 | print(savePath, name) 36 | file_name = os.path.join(savePath, name + '.npy') 37 | kp_array = load_pose_cords_from_strings(row.keypoints_y, row.keypoints_x) 38 | pose = cords_to_map(kp_array, image_size) 39 | np.save(file_name, pose) 40 | 41 | compute_pose(img_dir, annotations_file, save_path) 42 | 43 | -------------------------------------------------------------------------------- /market_1501/tool/resize_fashion.py: -------------------------------------------------------------------------------- 1 | from skimage.io import imread, imsave 2 | from skimage.transform import resize 3 | import os 4 | import numpy as np 5 | import pandas as pd 6 | import json 7 | 8 | def resize_dataset(folder, new_folder, new_size = (256, 176), crop_bord=40): 9 | if not os.path.exists(new_folder): 10 | os.makedirs(new_folder) 11 | for name in os.listdir(folder): 12 | old_name = os.path.join(folder, name) 13 | new_name = os.path.join(new_folder, name) 14 | 15 | img = imread(old_name) 16 | if crop_bord == 0: 17 | pass 18 | else: 19 | img = img[:, crop_bord:-crop_bord] 20 | 21 | img = resize(img, new_size, preserve_range=True).astype(np.uint8) 22 | 23 | imsave(new_name, img) 24 | 25 | def resize_annotations(name, new_name, new_size = (256, 176), old_size = (256, 256), crop_bord=40): 26 | df = pd.read_csv(name, sep=':') 27 | 28 | ratio_y = new_size[0] / float(old_size[0]) 29 | ratio_x = new_size[1] / float(old_size[1] - 2 * crop_bord) 30 | 31 | def modify(values, ratio, crop): 32 | val = np.array(json.loads(values)) 33 | mask = val == -1 34 | val = ((val - crop) * ratio).astype(int) 35 | val[mask] = -1 36 | return str(list(val)) 37 | 38 | df['keypoints_y'] = df.apply(lambda row: modify(row['keypoints_y'], ratio_y, 0), axis=1) 39 | df['keypoints_x'] = df.apply(lambda row: modify(row['keypoints_x'], ratio_x, crop_bord), axis=1) 40 | 41 | df.to_csv(new_name, sep=':', index=False) 42 | 43 | 44 | root_dir = 'xxx' 45 | resize_dataset(root_dir + '/test', root_dir + 'fashion_resize/test') 46 | resize_annotations(root_dir + 'fasion-annotation-test.csv', root_dir + 'fasion-resize-annotation-test.csv') 47 | 48 | resize_dataset(root_dir + '/train', root_dir + 'fashion_resize/train') 49 | resize_annotations(root_dir + 'fasion-annotation-train.csv', root_dir + 'fasion-resize-annotation-train.csv') 50 | 51 | 52 | -------------------------------------------------------------------------------- /market_1501/tool/rm_insnorm_running_vars.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | ckp_path = './checkpoints/fashion_PATN/latest_net_netG.pth' 4 | save_path = './checkpoints/fashion_PATN_v1.0/latest_net_netG.pth' 5 | states_dict = torch.load(ckp_path) 6 | states_dict_new = states_dict.copy() 7 | for key in states_dict.keys(): 8 | if "running_var" in key or "running_mean" in key: 9 | del states_dict_new[key] 10 | 11 | torch.save(states_dict_new, save_path) -------------------------------------------------------------------------------- /market_1501/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data.data_loader import CreateDataLoader 4 | from models.models import create_model 5 | from util.visualizer import Visualizer 6 | 7 | opt = TrainOptions().parse() 8 | data_loader = CreateDataLoader(opt) 9 | dataset = data_loader.load_data() 10 | dataset_size = len(data_loader) 11 | 
print('#training images = %d' % dataset_size) 12 | 13 | model = create_model(opt) 14 | visualizer = Visualizer(opt) 15 | total_steps = 0 16 | 17 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): 18 | epoch_start_time = time.time() 19 | epoch_iter = 0 20 | 21 | for i, data in enumerate(dataset): 22 | iter_start_time = time.time() 23 | visualizer.reset() 24 | total_steps += opt.batchSize 25 | epoch_iter += opt.batchSize 26 | model.set_input(data) 27 | model.optimize_parameters() 28 | 29 | if total_steps % opt.display_freq == 0: 30 | save_result = total_steps % opt.update_html_freq == 0 31 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 32 | 33 | if total_steps % opt.print_freq == 0: 34 | errors = model.get_current_errors() 35 | t = (time.time() - iter_start_time) / opt.batchSize 36 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 37 | if opt.display_id > 0: 38 | visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors) 39 | 40 | if total_steps % opt.save_latest_freq == 0: 41 | print('saving the latest model (epoch %d, total_steps %d)' % 42 | (epoch, total_steps)) 43 | model.save('latest') 44 | 45 | if epoch % opt.save_epoch_freq == 0: 46 | print('saving the model at the end of epoch %d, iters %d' % 47 | (epoch, total_steps)) 48 | model.save('latest') 49 | model.save(epoch) 50 | 51 | print('End of epoch %d / %d \t Time Taken: %d sec' % 52 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 53 | model.update_learning_rate() 54 | -------------------------------------------------------------------------------- /market_1501/train_market.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0; 2 | python train.py --dataroot ./SelectionGAN/person_transfer/datasets/market_data/ --name market_exp --model BiGraphGAN --lambda_GAN 5 --lambda_A 10 --lambda_B 10 --dataset_mode keypoint --no_lsgan --n_layers 3 --norm batch --batchSize 32 --resize_or_crop no --gpu_ids 0 --BP_input_nc 18 --no_flip --which_model_netG Graph --niter 500 --niter_decay 200 --checkpoints_dir ./checkpoints --pairLst ./SelectionGAN/person_transfer/datasets/market_data/market-pairs-train.csv --L1_type l1_plus_perL1 --n_layers_D 3 --with_D_PP 1 --with_D_PB 1 --display_id 0 3 | #--continue_train --which_epoch 640 --epoch_count 641 -------------------------------------------------------------------------------- /market_1501/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__init__.py -------------------------------------------------------------------------------- /market_1501/util/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__init__.pyc -------------------------------------------------------------------------------- /market_1501/util/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/util/__pycache__/__init__.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/util/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /market_1501/util/__pycache__/html.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/html.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/util/__pycache__/html.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/html.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/util/__pycache__/image_pool.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/image_pool.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/util/__pycache__/image_pool.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/image_pool.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/util/__pycache__/util.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/util.cpython-35.pyc -------------------------------------------------------------------------------- /market_1501/util/__pycache__/util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/util.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/util/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /market_1501/util/__pycache__/visualizer.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/visualizer.cpython-35.pyc 
-------------------------------------------------------------------------------- /market_1501/util/__pycache__/visualizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/__pycache__/visualizer.cpython-36.pyc -------------------------------------------------------------------------------- /market_1501/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, refresh=0): # renamed from the 'reflesh' typo; browsers only honor the 'refresh' http-equiv header 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | # print(self.img_dir) 16 | 17 | self.doc = dominate.document(title=title) 18 | if refresh > 0: 19 | with self.doc.head: 20 | meta(http_equiv="refresh", content=str(refresh)) 21 | 22 | def get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, text): # parameter renamed from 'str', which shadowed the builtin 26 | with self.doc: 27 | h3(text) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, width=400): 34 | self.add_table() 35 | with self.t: 36 | with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with a(href=os.path.join('images', link)): 41 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 42 | br() 43 | p(txt) 44 | 45 | def save(self): 46 | html_file = '%s/index.html' % self.web_dir 47 | f = open(html_file, 'wt') 48 | f.write(self.doc.render()) 49 | f.close() 50 | 51 | 52 | if __name__ == '__main__': 53 | html = HTML('web/', 'test_html') 54 | html.add_header('hello world') 55 | 56 | ims = [] 57 | txts = [] 58 | links = [] 59 | for n in range(4): 60 | ims.append('image_%d.png' % n) 61 | txts.append('text_%d' % n) 62 | links.append('image_%d.png' % n) 63 | html.add_images(ims, txts, links) 64 | html.save() 65 | -------------------------------------------------------------------------------- /market_1501/util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | from torch.autograd import Variable 5 | 6 | 7 | class ImagePool(): 8 | def __init__(self, pool_size): 9 | self.pool_size = pool_size 10 | if self.pool_size > 0: 11 | self.num_imgs = 0 12 | self.images = [] 13 | 14 | def query(self, images): 15 | if self.pool_size == 0: 16 | return Variable(images) 17 | return_images = [] 18 | for image in images: 19 | image = torch.unsqueeze(image, 0) 20 | if self.num_imgs < self.pool_size: 21 | self.num_imgs = self.num_imgs + 1 22 | self.images.append(image) 23 | return_images.append(image) 24 | else: 25 | p = random.uniform(0, 1) 26 | if p > 0.5: 27 | random_id = random.randint(0, self.pool_size-1) 28 | tmp = self.images[random_id].clone() 29 | self.images[random_id] = image 30 | return_images.append(tmp) 31 | else: 32 | return_images.append(image) 33 | return_images = Variable(torch.cat(return_images, 0)) 34 | return return_images 35 | -------------------------------------------------------------------------------- /market_1501/util/png.py:
-------------------------------------------------------------------------------- 1 | import struct 2 | import zlib 3 | 4 | def encode(buf, width, height): 5 | """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """ 6 | assert (width * height * 3 == len(buf)) 7 | bpp = 3 8 | 9 | def raw_data(): 10 | # reverse the vertical line order and add null bytes at the start 11 | row_bytes = width * bpp 12 | for row_start in range((height - 1) * width * bpp, -1, -row_bytes): 13 | yield b'\x00' 14 | yield buf[row_start:row_start + row_bytes] 15 | 16 | def chunk(tag, data): 17 | return [ 18 | struct.pack("!I", len(data)), 19 | tag, 20 | data, 21 | struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag))) 22 | ] 23 | 24 | SIGNATURE = b'\x89PNG\r\n\x1a\n' 25 | COLOR_TYPE_RGB = 2 26 | COLOR_TYPE_RGBA = 6 27 | bit_depth = 8 28 | return b''.join( 29 | [ SIGNATURE ] + 30 | chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) + 31 | chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) + 32 | chunk(b'IEND', b'') 33 | ) 34 | -------------------------------------------------------------------------------- /market_1501/util/util.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ha0Tang/BiGraphGAN/fffb7a210a8f4849ea6f1add382199445d794345/market_1501/util/util.pyc -------------------------------------------------------------------------------- /scripts/download_bigraphgan_model.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | echo "Note: available models are market and deepfashion" 4 | echo "Specified [$FILE]" 5 | 6 | URL=http://disi.unitn.it/~hao.tang/uploads/models/BiGraphGAN/${FILE}_pretrained.tar.gz 7 | TAR_FILE=./checkpoints/${FILE}_pretrained.tar.gz 8 | TARGET_DIR=./checkpoints/${FILE}_pretrained/ 9 | 10 | wget -N $URL -O $TAR_FILE 11 | 12 | mkdir -p $TARGET_DIR 13 | tar -zxvf $TAR_FILE -C ./checkpoints/ 14 | rm $TAR_FILE -------------------------------------------------------------------------------- /scripts/download_bigraphgan_result.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | echo "Note: available results are market and deepfashion" 4 | echo "Specified [$FILE]" 5 | 6 | URL=http://disi.unitn.it/~hao.tang/uploads/results/BiGraphGAN/${FILE}_results.tar.gz 7 | TAR_FILE=./results_by_author/${FILE}_results.tar.gz 8 | TARGET_DIR=./results_by_author/${FILE}_results/ 9 | 10 | wget -N $URL -O $TAR_FILE 11 | 12 | mkdir -p $TARGET_DIR 13 | tar -zxvf $TAR_FILE -C ./results_by_author/ 14 | rm $TAR_FILE --------------------------------------------------------------------------------
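
A minimal quick-start sketch, assuming the commands run from the repository root and that the hard-coded --dataroot and --pairLst paths in the test scripts have been edited to match the local dataset layout (everything below is an illustration derived from the scripts above, not part of the repository):

    # fetch the authors' pretrained weights and generated results ('market' or 'deepfashion')
    sh scripts/download_bigraphgan_model.sh market      # extracts under ./checkpoints/
    sh scripts/download_bigraphgan_result.sh market     # extracts under ./results_by_author/

    # run the pretrained test script for the Market-1501 variant
    cd market_1501
    sh test_market_pretrained.sh

Note that test_market_pretrained.sh as shipped points --checkpoints_dir at ./BiGraphGAN/scripts/checkpoints while the download script extracts into ./checkpoints/, so the two locations have to be reconciled by hand (move the unpacked market_pretrained folder, or edit the flag).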