├── README.md ├── calc_metrics.py ├── dataset_tool.py ├── dnnlib ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── util.cpython-37.pyc │ └── util.cpython-38.pyc └── util.py ├── human_colormap.mat ├── legacy.py ├── metrics ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── frechet_inception_distance.cpython-37.pyc │ ├── frechet_inception_distance.cpython-38.pyc │ ├── inception_score.cpython-37.pyc │ ├── inception_score.cpython-38.pyc │ ├── kernel_inception_distance.cpython-37.pyc │ ├── kernel_inception_distance.cpython-38.pyc │ ├── metric_main.cpython-37.pyc │ ├── metric_main.cpython-38.pyc │ ├── metric_utils.cpython-37.pyc │ ├── metric_utils.cpython-38.pyc │ ├── perceptual_path_length.cpython-37.pyc │ ├── perceptual_path_length.cpython-38.pyc │ ├── precision_recall.cpython-37.pyc │ └── precision_recall.cpython-38.pyc ├── frechet_inception_distance.py ├── inception_score.py ├── kernel_inception_distance.py ├── metric_main.py ├── metric_utils.py ├── perceptual_path_length.py └── precision_recall.py ├── test.py ├── test.sh ├── test_datas ├── garment_parsing │ ├── 16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4.png │ ├── 16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4.png │ ├── 16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4.png │ ├── 16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4.png │ ├── 16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4.png │ ├── 16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4.png │ ├── 16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4.png │ ├── 17_13524AA81E7E87GS_2532771_nike-5825-1772352-4.png │ ├── 17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4.png │ ├── 17_81ADCAABC27079GS_1995579_gap-3776-9755991-4.png │ ├── 17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4.png │ ├── 17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4.png │ ├── 24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4.png │ ├── 24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4.png │ ├── 24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4.png │ ├── 24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4.png │ ├── 24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4.png │ ├── 24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4.png │ ├── 24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4.png │ ├── 24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4.png │ ├── 24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4.png │ ├── 24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4.png │ ├── 24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4.png │ ├── 24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4.png │ └── 24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4.png ├── image │ ├── 16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4.jpg │ ├── 16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4.jpg │ ├── 16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4.jpg │ ├── 16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4.jpg │ ├── 16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4.jpg │ ├── 16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4.jpg │ ├── 16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4.jpg │ ├── 17_13524AA81E7E87GS_2532771_nike-5825-1772352-4.jpg │ ├── 17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4.jpg │ ├── 17_81ADCAABC27079GS_1995579_gap-3776-9755991-4.jpg │ ├── 17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4.jpg │ ├── 17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4.jpg │ ├── 24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4.jpg │ ├── 
24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4.jpg │ ├── 24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4.jpg │ ├── 24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4.jpg │ ├── 24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4.jpg │ ├── 24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4.jpg │ ├── 24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4.jpg │ ├── 24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4.jpg │ ├── 24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4.jpg │ ├── 24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4.jpg │ ├── 24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4.jpg │ ├── 24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4.jpg │ └── 24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4.jpg ├── keypoints │ ├── 16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4_keypoints.json │ ├── 16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4_keypoints.json │ ├── 16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4_keypoints.json │ ├── 16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4_keypoints.json │ ├── 16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4_keypoints.json │ ├── 16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4_keypoints.json │ ├── 16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4_keypoints.json │ ├── 17_13524AA81E7E87GS_2532771_nike-5825-1772352-4_keypoints.json │ ├── 17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4_keypoints.json │ ├── 17_81ADCAABC27079GS_1995579_gap-3776-9755991-4_keypoints.json │ ├── 17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4_keypoints.json │ ├── 17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4_keypoints.json │ ├── 24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4_keypoints.json │ ├── 24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4_keypoints.json │ ├── 24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4_keypoints.json │ ├── 24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4_keypoints.json │ ├── 24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4_keypoints.json │ ├── 24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4_keypoints.json │ ├── 24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4_keypoints.json │ ├── 24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4_keypoints.json │ ├── 24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4_keypoints.json │ ├── 24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4_keypoints.json │ ├── 24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4_keypoints.json │ ├── 24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4_keypoints.json │ └── 24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4_keypoints.json ├── parsing │ ├── 16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4.png │ ├── 16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4.png │ ├── 16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4.png │ ├── 16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4.png │ ├── 16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4.png │ ├── 16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4.png │ ├── 16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4.png │ ├── 17_13524AA81E7E87GS_2532771_nike-5825-1772352-4.png │ ├── 17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4.png │ ├── 17_81ADCAABC27079GS_1995579_gap-3776-9755991-4.png │ ├── 17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4.png │ ├── 17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4.png │ ├── 24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4.png │ ├── 24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4.png │ ├── 
24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4.png │ ├── 24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4.png │ ├── 24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4.png │ ├── 24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4.png │ ├── 24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4.png │ ├── 24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4.png │ ├── 24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4.png │ ├── 24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4.png │ ├── 24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4.png │ ├── 24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4.png │ └── 24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4.png └── test_pairs.txt ├── torch_utils ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── custom_ops.cpython-37.pyc │ ├── custom_ops.cpython-38.pyc │ ├── misc.cpython-37.pyc │ ├── misc.cpython-38.pyc │ ├── persistence.cpython-37.pyc │ ├── persistence.cpython-38.pyc │ ├── training_stats.cpython-37.pyc │ └── training_stats.cpython-38.pyc ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── bias_act.cpython-37.pyc │ │ ├── bias_act.cpython-38.pyc │ │ ├── conv2d_gradfix.cpython-37.pyc │ │ ├── conv2d_gradfix.cpython-38.pyc │ │ ├── conv2d_resample.cpython-37.pyc │ │ ├── conv2d_resample.cpython-38.pyc │ │ ├── fma.cpython-37.pyc │ │ ├── fma.cpython-38.pyc │ │ ├── grid_sample_gradfix.cpython-37.pyc │ │ ├── grid_sample_gradfix.cpython-38.pyc │ │ ├── upfirdn2d.cpython-37.pyc │ │ └── upfirdn2d.cpython-38.pyc │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py └── training_stats.py ├── train.py ├── train.sh ├── training ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── dataset.cpython-37.pyc │ ├── networks.cpython-37.pyc │ └── utils.cpython-37.pyc ├── augment.py ├── dataset.py ├── loss_fullbody.py ├── networks.py ├── training_loop_fullbody.py └── utils.py ├── util_classes.py └── util_functions.py /README.md: -------------------------------------------------------------------------------- 1 | # Versatile Unpaired Virtual Try-on via Patch-Routed Spatially-Adaptive GAN++ 2 | 3 | Official implementation of "Versatile Unpaired Virtual Try-on via Patch-Routed Spatially-Adaptive GAN++". 4 | 5 | ## Requirements 6 | 7 | Create a virtual environment: 8 | ``` 9 | virtualenv pasta --python=python3.7 10 | source pasta/bin/activate 11 | ``` 12 | Install required packages: 13 | ``` 14 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html 15 | pip install click requests tqdm pyspng ninja imageio-ffmpeg==0.4.3 16 | pip install psutil scipy matplotlib opencv-python scikit-image==0.18.3 pycocotools 17 | apt install libgl1-mesa-glx 18 | ``` 19 | 20 | ## Running Inference 21 | We provide the [pre-trained models](https://drive.google.com/file/d/1oESyGm1Zcz2lWUO6AvKlj-pXWtvIRGZd/view?usp=sharing) of PASTA-GAN++ which are trained by using the full UPT dataset (i.e., our newly collected data, data from Deepfashion dataset, data from MPV dataset) with the resolution of 512 separately. 
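For reference, the test command in the next section (and the provided `test.sh`) assumes the downloaded network pickle is placed under a `checkpoints` directory at the repository root, roughly as follows; if you store the checkpoint elsewhere, adjust the `--network` path accordingly:
```
checkpoints/
└── pasta-gan++/
    └── network-snapshot-004408.pkl
```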
22 | 23 | We provide some test data under the directory `test_datas`, along with a simple script for testing the pre-trained model provided above on the UPT dataset, as follows: 24 | ``` 25 | CUDA_VISIBLE_DEVICES=0 python3 -W ignore test.py \ 26 | --dataroot test_datas --testtxt test_pairs.txt \ 27 | --network checkpoints/pasta-gan++/network-snapshot-004408.pkl \ 28 | --outdir test_results/upper \ 29 | --batchsize 1 --testpart upper 30 | ``` 31 | Alternatively, you can run the bash script with the following command: 32 | ``` 33 | bash test.sh 1 34 | ``` 35 | 36 | Note that, in the testing script, `--network` is the path to the pre-trained model, `--outdir` is the directory for the generated results, `--dataroot` is the path to the data root, `--testtxt` is the list of garment-person test pairs, and `--testpart` is the garment part on which PASTA-GAN++ performs the garment transfer. The flag `--use-sleeve-mask` controls whether the sleeve mask is used during data preprocessing (if the sleeve mask is unavailable, simply omit this flag). For the exact configuration of these parameters, please refer to `test.sh`. 37 | -------------------------------------------------------------------------------- /calc_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Calculate quality metrics for previous training run or pretrained network pickle.""" 10 | 11 | import os 12 | import click 13 | import json 14 | import tempfile 15 | import copy 16 | import torch 17 | import dnnlib 18 | 19 | import legacy 20 | from metrics import metric_main 21 | from metrics import metric_utils 22 | from torch_utils import training_stats 23 | from torch_utils import custom_ops 24 | from torch_utils import misc 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def subprocess_fn(rank, args, temp_dir): 29 | dnnlib.util.Logger(should_flush=True) 30 | 31 | # Init torch.distributed. 32 | if args.num_gpus > 1: 33 | init_file = os.path.abspath(os.path.join(temp_dir, '.torch_distributed_init')) 34 | if os.name == 'nt': 35 | init_method = 'file:///' + init_file.replace('\\', '/') 36 | torch.distributed.init_process_group(backend='gloo', init_method=init_method, rank=rank, world_size=args.num_gpus) 37 | else: 38 | init_method = f'file://{init_file}' 39 | torch.distributed.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=args.num_gpus) 40 | 41 | # Init torch_utils. 42 | sync_device = torch.device('cuda', rank) if args.num_gpus > 1 else None 43 | training_stats.init_multiprocessing(rank=rank, sync_device=sync_device) 44 | if rank != 0 or not args.verbose: 45 | custom_ops.verbosity = 'none' 46 | 47 | # Print network summary.
48 | device = torch.device('cuda', rank) 49 | torch.backends.cudnn.benchmark = True 50 | torch.backends.cuda.matmul.allow_tf32 = False 51 | torch.backends.cudnn.allow_tf32 = False 52 | G = copy.deepcopy(args.G).eval().requires_grad_(False).to(device) 53 | if rank == 0 and args.verbose: 54 | z = torch.empty([1, G.z_dim], device=device) 55 | c = torch.empty([1, G.c_dim], device=device) 56 | misc.print_module_summary(G, [z, c]) 57 | 58 | # Calculate each metric. 59 | for metric in args.metrics: 60 | if rank == 0 and args.verbose: 61 | print(f'Calculating {metric}...') 62 | progress = metric_utils.ProgressMonitor(verbose=args.verbose) 63 | result_dict = metric_main.calc_metric(metric=metric, G=G, dataset_kwargs=args.dataset_kwargs, 64 | num_gpus=args.num_gpus, rank=rank, device=device, progress=progress) 65 | if rank == 0: 66 | metric_main.report_metric(result_dict, run_dir=args.run_dir, snapshot_pkl=args.network_pkl) 67 | if rank == 0 and args.verbose: 68 | print() 69 | 70 | # Done. 71 | if rank == 0 and args.verbose: 72 | print('Exiting...') 73 | 74 | #---------------------------------------------------------------------------- 75 | 76 | class CommaSeparatedList(click.ParamType): 77 | name = 'list' 78 | 79 | def convert(self, value, param, ctx): 80 | _ = param, ctx 81 | if value is None or value.lower() == 'none' or value == '': 82 | return [] 83 | return value.split(',') 84 | 85 | #---------------------------------------------------------------------------- 86 | 87 | @click.command() 88 | @click.pass_context 89 | @click.option('network_pkl', '--network', help='Network pickle filename or URL', metavar='PATH', required=True) 90 | @click.option('--metrics', help='Comma-separated list or "none"', type=CommaSeparatedList(), default='fid50k_full', show_default=True) 91 | @click.option('--data', help='Dataset to evaluate metrics against (directory or zip) [default: same as training data]', metavar='PATH') 92 | @click.option('--mirror', help='Whether the dataset was augmented with x-flips during training [default: look up]', type=bool, metavar='BOOL') 93 | @click.option('--gpus', help='Number of GPUs to use', type=int, default=1, metavar='INT', show_default=True) 94 | @click.option('--verbose', help='Print optional information', type=bool, default=True, metavar='BOOL', show_default=True) 95 | 96 | def calc_metrics(ctx, network_pkl, metrics, data, mirror, gpus, verbose): 97 | """Calculate quality metrics for previous training run or pretrained network pickle. 98 | 99 | Examples: 100 | 101 | \b 102 | # Previous training run: look up options automatically, save result to JSONL file. 103 | python calc_metrics.py --metrics=pr50k3_full \\ 104 | --network=~/training-runs/00000-ffhq10k-res64-auto1/network-snapshot-000000.pkl 105 | 106 | \b 107 | # Pre-trained network pickle: specify dataset explicitly, print result to stdout. 108 | python calc_metrics.py --metrics=fid50k_full --data=~/datasets/ffhq.zip --mirror=1 \\ 109 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl 110 | 111 | Available metrics: 112 | 113 | \b 114 | ADA paper: 115 | fid50k_full Frechet inception distance against the full dataset. 116 | kid50k_full Kernel inception distance against the full dataset. 117 | pr50k3_full Precision and recall againt the full dataset. 118 | is50k Inception score for CIFAR-10. 119 | 120 | \b 121 | StyleGAN and StyleGAN2 papers: 122 | fid50k Frechet inception distance against 50k real images. 123 | kid50k Kernel inception distance against 50k real images. 
124 | pr50k3 Precision and recall against 50k real images. 125 | ppl2_wend Perceptual path length in W at path endpoints against full image. 126 | ppl_zfull Perceptual path length in Z for full paths against cropped image. 127 | ppl_wfull Perceptual path length in W for full paths against cropped image. 128 | ppl_zend Perceptual path length in Z at path endpoints against cropped image. 129 | ppl_wend Perceptual path length in W at path endpoints against cropped image. 130 | """ 131 | dnnlib.util.Logger(should_flush=True) 132 | 133 | # Validate arguments. 134 | args = dnnlib.EasyDict(metrics=metrics, num_gpus=gpus, network_pkl=network_pkl, verbose=verbose) 135 | if not all(metric_main.is_valid_metric(metric) for metric in args.metrics): 136 | ctx.fail('\n'.join(['--metrics can only contain the following values:'] + metric_main.list_valid_metrics())) 137 | if not args.num_gpus >= 1: 138 | ctx.fail('--gpus must be at least 1') 139 | 140 | # Load network. 141 | if not dnnlib.util.is_url(network_pkl, allow_file_urls=True) and not os.path.isfile(network_pkl): 142 | ctx.fail('--network must point to a file or URL') 143 | if args.verbose: 144 | print(f'Loading network from "{network_pkl}"...') 145 | with dnnlib.util.open_url(network_pkl, verbose=args.verbose) as f: 146 | network_dict = legacy.load_network_pkl(f) 147 | args.G = network_dict['G_ema'] # subclass of torch.nn.Module 148 | 149 | # Initialize dataset options. 150 | if data is not None: 151 | args.dataset_kwargs = dnnlib.EasyDict(class_name='training.dataset.ImageFolderDataset', path=data) 152 | elif network_dict['training_set_kwargs'] is not None: 153 | args.dataset_kwargs = dnnlib.EasyDict(network_dict['training_set_kwargs']) 154 | else: 155 | ctx.fail('Could not look up dataset options; please specify --data') 156 | 157 | # Finalize dataset options. 158 | args.dataset_kwargs.resolution = args.G.img_resolution 159 | args.dataset_kwargs.use_labels = (args.G.c_dim != 0) 160 | if mirror is not None: 161 | args.dataset_kwargs.xflip = mirror 162 | 163 | # Print dataset options. 164 | if args.verbose: 165 | print('Dataset options:') 166 | print(json.dumps(args.dataset_kwargs, indent=2)) 167 | 168 | # Locate run dir. 169 | args.run_dir = None 170 | if os.path.isfile(network_pkl): 171 | pkl_dir = os.path.dirname(network_pkl) 172 | if os.path.isfile(os.path.join(pkl_dir, 'training_options.json')): 173 | args.run_dir = pkl_dir 174 | 175 | # Launch processes. 176 | if args.verbose: 177 | print('Launching processes...') 178 | torch.multiprocessing.set_start_method('spawn') 179 | with tempfile.TemporaryDirectory() as temp_dir: 180 | if args.num_gpus == 1: 181 | subprocess_fn(rank=0, args=args, temp_dir=temp_dir) 182 | else: 183 | torch.multiprocessing.spawn(fn=subprocess_fn, args=(args, temp_dir), nprocs=args.num_gpus) 184 | 185 | #---------------------------------------------------------------------------- 186 | 187 | if __name__ == "__main__": 188 | calc_metrics() # pylint: disable=no-value-for-parameter 189 | 190 | #---------------------------------------------------------------------------- 191 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /dnnlib/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/dnnlib/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/dnnlib/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /dnnlib/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/dnnlib/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /dnnlib/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/dnnlib/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /human_colormap.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/human_colormap.mat -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /metrics/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/frechet_inception_distance.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/frechet_inception_distance.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/frechet_inception_distance.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/frechet_inception_distance.cpython-38.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/inception_score.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/inception_score.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/inception_score.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/inception_score.cpython-38.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/kernel_inception_distance.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/kernel_inception_distance.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/kernel_inception_distance.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/kernel_inception_distance.cpython-38.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/metric_main.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/metric_main.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/metric_main.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/metric_main.cpython-38.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/metric_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/metric_utils.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/metric_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/metric_utils.cpython-38.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/perceptual_path_length.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/perceptual_path_length.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/perceptual_path_length.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/perceptual_path_length.cpython-38.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/precision_recall.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/precision_recall.cpython-37.pyc -------------------------------------------------------------------------------- /metrics/__pycache__/precision_recall.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/metrics/__pycache__/precision_recall.cpython-38.pyc -------------------------------------------------------------------------------- /metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from . 
import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_fid(opts, max_real, num_gen): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 24 | 25 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 28 | 29 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 32 | 33 | if opts.rank != 0: 34 | return float('nan') 35 | 36 | m = np.square(mu_gen - mu_real).sum() 37 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 38 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 39 | return float(fid) 40 | 41 | #---------------------------------------------------------------------------- 42 | -------------------------------------------------------------------------------- /metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Inception Score (IS) from the paper "Improved techniques for training 10 | GANs". Matches the original implementation by Salimans et al. at 11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_is(opts, num_gen, num_splits): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 
22 | 23 | gen_probs = metric_utils.compute_feature_stats_for_generator( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | capture_all=True, max_items=num_gen).get_all() 26 | 27 | if opts.rank != 0: 28 | return float('nan'), float('nan') 29 | 30 | scores = [] 31 | for i in range(num_splits): 32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 34 | kl = np.mean(np.sum(kl, axis=1)) 35 | scores.append(np.exp(kl)) 36 | return float(np.mean(scores)), float(np.std(scores)) 37 | 38 | #---------------------------------------------------------------------------- 39 | -------------------------------------------------------------------------------- /metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 10 | GANs". Matches the original implementation by Binkowski et al. at 11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 12 | 13 | import numpy as np 14 | from . import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt' 21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 22 | 23 | real_features = metric_utils.compute_feature_stats_for_dataset( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 26 | 27 | gen_features = metric_utils.compute_feature_stats_for_generator( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 30 | 31 | if opts.rank != 0: 32 | return float('nan') 33 | 34 | n = real_features.shape[1] 35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 36 | t = 0 37 | for _subset_idx in range(num_subsets): 38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 41 | b = (x @ y.T / n + 1) ** 3 42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 43 | kid = t / num_subsets / m 44 | return float(kid) 45 | 46 | #---------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /metrics/metric_main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import time 11 | import json 12 | import torch 13 | import dnnlib 14 | 15 | from . import metric_utils 16 | from . import frechet_inception_distance 17 | from . import kernel_inception_distance 18 | from . import precision_recall 19 | from . import perceptual_path_length 20 | from . import inception_score 21 | 22 | #---------------------------------------------------------------------------- 23 | 24 | _metric_dict = dict() # name => fn 25 | 26 | def register_metric(fn): 27 | assert callable(fn) 28 | _metric_dict[fn.__name__] = fn 29 | return fn 30 | 31 | def is_valid_metric(metric): 32 | return metric in _metric_dict 33 | 34 | def list_valid_metrics(): 35 | return list(_metric_dict.keys()) 36 | 37 | #---------------------------------------------------------------------------- 38 | 39 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 40 | assert is_valid_metric(metric) 41 | opts = metric_utils.MetricOptions(**kwargs) 42 | 43 | # Calculate. 44 | start_time = time.time() 45 | results = _metric_dict[metric](opts) 46 | total_time = time.time() - start_time 47 | 48 | # Broadcast results. 49 | for key, value in list(results.items()): 50 | if opts.num_gpus > 1: 51 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 52 | torch.distributed.broadcast(tensor=value, src=0) 53 | value = float(value.cpu()) 54 | results[key] = value 55 | 56 | # Decorate with metadata. 57 | return dnnlib.EasyDict( 58 | results = dnnlib.EasyDict(results), 59 | metric = metric, 60 | total_time = total_time, 61 | total_time_str = dnnlib.util.format_time(total_time), 62 | num_gpus = opts.num_gpus, 63 | ) 64 | 65 | #---------------------------------------------------------------------------- 66 | 67 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 68 | metric = result_dict['metric'] 69 | assert is_valid_metric(metric) 70 | if run_dir is not None and snapshot_pkl is not None: 71 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 72 | 73 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 74 | print(jsonl_line) 75 | if run_dir is not None and os.path.isdir(run_dir): 76 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 77 | f.write(jsonl_line + '\n') 78 | 79 | #---------------------------------------------------------------------------- 80 | # Primary metrics. 
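# Each metric below is registered under its own function name by @register_metric and is
# selected from calc_metrics.py via --metrics=<name>[,<name>,...]. A custom metric can be
# registered the same way; for example, a hypothetical smaller FID variant (illustrative
# sketch only, following the pattern of fid50k_full above) might look like:
#
#   @register_metric
#   def fid10k(opts):
#       opts.dataset_kwargs.update(max_size=None, xflip=False)
#       fid = frechet_inception_distance.compute_fid(opts, max_real=10000, num_gen=10000)
#       return dict(fid10k=fid)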
81 | 82 | @register_metric 83 | def fid50k_full(opts): 84 | opts.dataset_kwargs.update(max_size=None, xflip=False) 85 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 86 | return dict(fid50k_full=fid) 87 | 88 | @register_metric 89 | def kid50k_full(opts): 90 | opts.dataset_kwargs.update(max_size=None, xflip=False) 91 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 92 | return dict(kid50k_full=kid) 93 | 94 | @register_metric 95 | def pr50k3_full(opts): 96 | opts.dataset_kwargs.update(max_size=None, xflip=False) 97 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 98 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 99 | 100 | @register_metric 101 | def ppl2_wend(opts): 102 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) 103 | return dict(ppl2_wend=ppl) 104 | 105 | @register_metric 106 | def is50k(opts): 107 | opts.dataset_kwargs.update(max_size=None, xflip=False) 108 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 109 | return dict(is50k_mean=mean, is50k_std=std) 110 | 111 | #---------------------------------------------------------------------------- 112 | # Legacy metrics. 113 | 114 | @register_metric 115 | def fid50k(opts): 116 | opts.dataset_kwargs.update(max_size=None) 117 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 118 | return dict(fid50k=fid) 119 | 120 | @register_metric 121 | def kid50k(opts): 122 | opts.dataset_kwargs.update(max_size=None) 123 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 124 | return dict(kid50k=kid) 125 | 126 | @register_metric 127 | def pr50k3(opts): 128 | opts.dataset_kwargs.update(max_size=None) 129 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 130 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 131 | 132 | @register_metric 133 | def ppl_zfull(opts): 134 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='full', crop=True, batch_size=2) 135 | return dict(ppl_zfull=ppl) 136 | 137 | @register_metric 138 | def ppl_wfull(opts): 139 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='full', crop=True, batch_size=2) 140 | return dict(ppl_wfull=ppl) 141 | 142 | @register_metric 143 | def ppl_zend(opts): 144 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='z', sampling='end', crop=True, batch_size=2) 145 | return dict(ppl_zend=ppl) 146 | 147 | @register_metric 148 | def ppl_wend(opts): 149 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=True, batch_size=2) 150 | return dict(ppl_wend=ppl) 151 | 152 | #---------------------------------------------------------------------------- 153 | -------------------------------------------------------------------------------- /metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 10 | Architecture for Generative Adversarial Networks". Matches the original 11 | implementation by Karras et al. at 12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 13 | 14 | import copy 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | from . import metric_utils 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | # Spherical interpolation of a batch of vectors. 23 | def slerp(a, b, t): 24 | a = a / a.norm(dim=-1, keepdim=True) 25 | b = b / b.norm(dim=-1, keepdim=True) 26 | d = (a * b).sum(dim=-1, keepdim=True) 27 | p = t * torch.acos(d) 28 | c = b - d * a 29 | c = c / c.norm(dim=-1, keepdim=True) 30 | d = a * torch.cos(p) + c * torch.sin(p) 31 | d = d / d.norm(dim=-1, keepdim=True) 32 | return d 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | class PPLSampler(torch.nn.Module): 37 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 38 | assert space in ['z', 'w'] 39 | assert sampling in ['full', 'end'] 40 | super().__init__() 41 | self.G = copy.deepcopy(G) 42 | self.G_kwargs = G_kwargs 43 | self.epsilon = epsilon 44 | self.space = space 45 | self.sampling = sampling 46 | self.crop = crop 47 | self.vgg16 = copy.deepcopy(vgg16) 48 | 49 | def forward(self, c): 50 | # Generate random latents and interpolation t-values. 51 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 52 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 53 | 54 | # Interpolate in W or Z. 55 | if self.space == 'w': 56 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 57 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 58 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 59 | else: # space == 'z' 60 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 61 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 62 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 63 | 64 | # Randomize noise buffers. 65 | for name, buf in self.G.named_buffers(): 66 | if name.endswith('.noise_const'): 67 | buf.copy_(torch.randn_like(buf)) 68 | 69 | # Generate images. 70 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 71 | 72 | # Center crop. 73 | if self.crop: 74 | assert img.shape[2] == img.shape[3] 75 | c = img.shape[2] // 8 76 | img = img[:, :, c*3 : c*7, c*2 : c*6] 77 | 78 | # Downsample to 256x256. 79 | factor = self.G.img_resolution // 256 80 | if factor > 1: 81 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 82 | 83 | # Scale dynamic range from [-1,1] to [0,255]. 84 | img = (img + 1) * (255 / 2) 85 | if self.G.img_channels == 1: 86 | img = img.repeat([1, 3, 1, 1]) 87 | 88 | # Evaluate differential LPIPS. 
89 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 90 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 91 | return dist 92 | 93 | #---------------------------------------------------------------------------- 94 | 95 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size, jit=False): 96 | dataset = dnnlib.util.construct_class_by_name(**opts.dataset_kwargs) 97 | vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 98 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 99 | 100 | # Setup sampler. 101 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 102 | sampler.eval().requires_grad_(False).to(opts.device) 103 | if jit: 104 | c = torch.zeros([batch_size, opts.G.c_dim], device=opts.device) 105 | sampler = torch.jit.trace(sampler, [c], check_trace=False) 106 | 107 | # Sampling loop. 108 | dist = [] 109 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 110 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 111 | progress.update(batch_start) 112 | c = [dataset.get_label(np.random.randint(len(dataset))) for _i in range(batch_size)] 113 | # Stack the per-sample labels into a single array before converting to a tensor. 114 | c = torch.from_numpy(np.stack(c)).pin_memory().to(opts.device) 115 | x = sampler(c) 116 | for src in range(opts.num_gpus): 117 | y = x.clone() 118 | if opts.num_gpus > 1: 119 | torch.distributed.broadcast(y, src=src) 120 | dist.append(y) 121 | progress.update(num_samples) 122 | 123 | # Compute PPL. 124 | if opts.rank != 0: 125 | return float('nan') 126 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 127 | lo = np.percentile(dist, 1, interpolation='lower') 128 | hi = np.percentile(dist, 99, interpolation='higher') 129 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 130 | return float(ppl) 131 | 132 | #---------------------------------------------------------------------------- 133 | -------------------------------------------------------------------------------- /metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from .
import metric_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 20 | assert 0 <= rank < num_gpus 21 | num_cols = col_features.shape[0] 22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 24 | dist_batches = [] 25 | for col_batch in col_batches[rank :: num_gpus]: 26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 27 | for src in range(num_gpus): 28 | dist_broadcast = dist_batch.clone() 29 | if num_gpus > 1: 30 | torch.distributed.broadcast(dist_broadcast, src=src) 31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 37 | detector_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 38 | detector_kwargs = dict(return_features=True) 39 | 40 | real_features = metric_utils.compute_feature_stats_for_dataset( 41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 43 | 44 | gen_features = metric_utils.compute_feature_stats_for_generator( 45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 46 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 47 | 48 | results = dict() 49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 50 | kth = [] 51 | for manifold_batch in manifold.split(row_batch_size): 52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 54 | kth = torch.cat(kth) if opts.rank == 0 else None 55 | pred = [] 56 | for probes_batch in probes.split(row_batch_size): 57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 60 | return results['precision'], results['recall'] 61 | 62 | #---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Generate images using pretrained network pickle.""" 10 | 11 | import os 12 | import re 13 | from typing import List, Optional 14 | 15 | import click 16 | import dnnlib 17 | import numpy as np 18 | import PIL.Image 19 | import torch 20 | import torch.nn as nn 21 | 22 | from training import dataset as custom_dataset 23 | 24 | import legacy 25 | import cv2 26 | import tqdm 27 | 28 | import scipy.io as sio 29 | import tqdm 30 | 31 | CMAP = sio.loadmat('human_colormap.mat')['colormap'] 32 | CMAP = (CMAP * 256).astype(np.uint8) 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def num_range(s: str) -> List[int]: 37 | '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.''' 38 | 39 | range_re = re.compile(r'^(\d+)-(\d+)$') 40 | m = range_re.match(s) 41 | if m: 42 | return list(range(int(m.group(1)), int(m.group(2))+1)) 43 | vals = s.split(',') 44 | return [int(x) for x in vals] 45 | 46 | #---------------------------------------------------------------------------- 47 | 48 | @click.command() 49 | @click.pass_context 50 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 51 | @click.option('--seeds', type=num_range, help='List of random seeds') 52 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 53 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') 54 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) 55 | @click.option('--projected-w', help='Projection result file', type=str, metavar='FILE') 56 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 57 | @click.option('--dataroot',type=str) 58 | @click.option('--batchsize',type=int) 59 | @click.option('--testpart',type=str) 60 | @click.option('--testtxt',type=str) 61 | @click.option('--use-sleeve-mask', is_flag=True) 62 | def generate_images( 63 | ctx: click.Context, 64 | network_pkl: str, 65 | seeds: Optional[List[int]], 66 | truncation_psi: float, 67 | noise_mode: str, 68 | outdir: str, 69 | class_idx: Optional[int], 70 | projected_w: Optional[str], 71 | dataroot: str, 72 | batchsize: str, 73 | testpart: str, 74 | testtxt: str, 75 | use_sleeve_mask: boolean 76 | ): 77 | """Generate images using pretrained network pickle. 78 | 79 | Examples: 80 | 81 | \b 82 | # Generate curated MetFaces images without truncation (Fig.10 left) 83 | python generate.py --outdir=out --trunc=1 --seeds=85,265,297,849 \\ 84 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 85 | 86 | \b 87 | # Generate uncurated MetFaces images with truncation (Fig.12 upper left) 88 | python generate.py --outdir=out --trunc=0.7 --seeds=600-605 \\ 89 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 90 | 91 | \b 92 | # Generate class conditional CIFAR-10 images (Fig.17 left, Car) 93 | python generate.py --outdir=out --seeds=0-35 --class=1 \\ 94 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl 95 | 96 | \b 97 | # Render an image from projected W 98 | python generate.py --outdir=out --projected_w=projected_w.npz \\ 99 | --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl 100 | """ 101 | 102 | print('Loading networks from "%s"...' 
% network_pkl) 103 | device = torch.device('cuda') 104 | with dnnlib.util.open_url(network_pkl) as f: 105 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 106 | 107 | os.makedirs(outdir, exist_ok=True) 108 | 109 | if testpart == 'full': 110 | dataset = custom_dataset.UvitonDatasetFull_512_test_full(path=dataroot,test_txt=testtxt,use_sleeve_mask=use_sleeve_mask, max_size=None, xflip=False) 111 | elif testpart == 'upper': 112 | dataset = custom_dataset.UvitonDatasetFull_512_test_upper(path=dataroot,test_txt=testtxt,use_sleeve_mask=use_sleeve_mask, max_size=None, xflip=False) 113 | elif testpart == 'lower': 114 | dataset = custom_dataset.UvitonDatasetFull_512_test_lower(path=dataroot,test_txt=testtxt,use_sleeve_mask=use_sleeve_mask, max_size=None, xflip=False) 115 | else: 116 | raise ValueError('Invalid value for test part!') 117 | dataloader = torch.utils.data.DataLoader(dataset,batch_size=batchsize,shuffle=False,pin_memory=True, num_workers=0) 118 | 119 | device = torch.device('cuda') 120 | 121 | for data in tqdm.tqdm(dataloader): 122 | image, clothes, pose, _, norm_img, norm_img_lower, denorm_upper_clothes, denorm_lower_clothes, \ 123 | denorm_upper_mask, denorm_lower_mask, retain_mask, skin_average, lower_label_map, lower_clothes_upper_bound, \ 124 | person_name, clothes_name = data 125 | 126 | image_tensor = image.to(device).to(torch.float32) / 127.5 - 1 127 | clothes_tensor = clothes.to(device).to(torch.float32) / 127.5 - 1 128 | pose_tensor = pose.to(device).to(torch.float32) / 127.5 - 1 129 | norm_img_tensor = norm_img.to(device).to(torch.float32) / 127.5 - 1 130 | norm_img_lower_tensor = norm_img_lower.to(device).to(torch.float32) / 127.5 - 1 131 | 132 | skin_tensor = skin_average.to(device).to(torch.float32) / 127.5 - 1 133 | lower_label_map_tensor = lower_label_map.to(device).to(torch.float32) / 127.5 - 1 134 | lower_clothes_upper_bound_tensor = lower_clothes_upper_bound.to(device).to(torch.float32) / 127.5 - 1 135 | 136 | parts_tensor = torch.cat([norm_img_tensor, norm_img_lower_tensor],dim=1) 137 | 138 | denorm_upper_clothes_tensor = denorm_upper_clothes.to(device).to(torch.float32) / 127.5 - 1 139 | denorm_upper_mask_tensor = denorm_upper_mask.to(device).to(torch.float32) 140 | 141 | denorm_lower_clothes_tensor = denorm_lower_clothes.to(device).to(torch.float32) / 127.5 - 1 142 | denorm_lower_mask_tensor = denorm_lower_mask.to(device).to(torch.float32) 143 | 144 | retain_mask_tensor = retain_mask.to(device) 145 | retain_tensor = image_tensor * retain_mask_tensor - (1-retain_mask_tensor) 146 | pose_tensor = torch.cat([pose_tensor,lower_label_map_tensor,lower_clothes_upper_bound_tensor],dim=1) 147 | retain_tensor = torch.cat([retain_tensor,skin_tensor],dim=1) 148 | gen_z = torch.randn([batchsize,0],device=device) 149 | 150 | with torch.no_grad(): 151 | gen_c, cat_feat_list = G.style_encoding(parts_tensor, retain_tensor) 152 | pose_feat = G.const_encoding(pose_tensor) 153 | ws = G.mapping(gen_z,gen_c) 154 | cat_feats = {} 155 | for cat_feat in cat_feat_list: 156 | h = cat_feat.shape[2] 157 | cat_feats[str(h)] = cat_feat 158 | gt_parsing = None 159 | _, gen_imgs, _ = G.synthesis(ws, pose_feat, cat_feats, denorm_upper_clothes_tensor, denorm_lower_clothes_tensor, \ 160 | denorm_upper_mask_tensor, denorm_lower_mask_tensor, gt_parsing) 161 | 162 | for ii in range(gen_imgs.size(0)): 163 | gen_img = gen_imgs[ii].detach().cpu().numpy() 164 | gen_img = (gen_img.transpose(1,2,0)+1.0) * 127.5 165 | gen_img = np.clip(gen_img,0,255) 166 | gen_img = 
gen_img.astype(np.uint8)[...,[2,1,0]] 167 | 168 | image_np = image_tensor[ii].detach().cpu().numpy() 169 | image_np = (image_np.transpose(1,2,0)+1.0) * 127.5 170 | image_np = image_np.astype(np.uint8)[...,[2,1,0]] 171 | 172 | clothes_np = clothes_tensor[ii].detach().cpu().numpy() 173 | clothes_np = (clothes_np.transpose(1,2,0)+1.0) * 127.5 174 | clothes_np = clothes_np.astype(np.uint8)[...,[2,1,0]] 175 | 176 | result = np.concatenate([clothes_np[:,96:416,:], image_np[:,96:416,:], \ 177 | gen_img[:,96:416,:]], axis=1) 178 | 179 | person_n = person_name[ii].split('/')[-1] 180 | clothes_n = clothes_name[ii].split('/')[-1] 181 | 182 | save_name = person_n[:-4]+'___'+clothes_n[:-4]+'.png' 183 | save_path = os.path.join(outdir, save_name) 184 | cv2.imwrite(save_path,result) 185 | 186 | print('finish') 187 | 188 | 189 | #---------------------------------------------------------------------------- 190 | 191 | if __name__ == "__main__": 192 | generate_images() # pylint: disable=no-value-for-parameter 193 | 194 | #---------------------------------------------------------------------------- 195 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | if [ "$1" = "1" ]; then 3 | CUDA_VISIBLE_DEVICES=0 python3 -W ignore test.py \ 4 | --dataroot test_datas --testtxt test_pairs.txt \ 5 | --network checkpoints/pasta-gan++/network-snapshot-004408.pkl \ 6 | --outdir test_results/upper \ 7 | --batchsize 1 --testpart upper \ 8 | --use-sleeve-mask 9 | elif [ "$1" = "2" ]; then 10 | CUDA_VISIBLE_DEVICES=0 python3 -W ignore test.py \ 11 | --dataroot test_datas --testtxt test_pairs.txt \ 12 | --network checkpoints/pasta-gan++/network-snapshot-004408.pkl \ 13 | --outdir test_results/lower \ 14 | --batchsize 1 --testpart lower \ 15 | --use-sleeve-mask 16 | elif [ "$1" = "3" ]; then 17 | CUDA_VISIBLE_DEVICES=0 python3 -W ignore test.py \ 18 | --dataroot test_datas --testtxt test_pairs.txt \ 19 | --network checkpoints/pasta-gan++/network-snapshot-004408.pkl \ 20 | --outdir test_results/full \ 21 | --batchsize 1 --testpart full \ 22 | --use-sleeve-mask 23 | fi -------------------------------------------------------------------------------- /test_datas/garment_parsing/16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4.png
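Illustrative note (not part of the repository): test.sh above picks the try-on mode by its first argument (1 for upper-body, 2 for lower-body, 3 for full-body), e.g. "bash test.sh 1", and the person/garment pairs it evaluates come from test_datas/test_pairs.txt, listed further down in this dump, where each line names one person image and one garment image. Below is a minimal sketch of reading that pairs file, assuming the test_datas layout shown in this dump; the helper name load_test_pairs is hypothetical and does not exist in the repo.

# Illustrative sketch only; not part of the repository.
# Reads test_datas/test_pairs.txt, where each line holds a person image name and
# a garment image name separated by whitespace, and resolves both names against
# the test_datas/image folder shown in this dump.
import os

def load_test_pairs(dataroot='test_datas', pairs_file='test_pairs.txt'):
    pairs = []
    with open(os.path.join(dataroot, pairs_file)) as f:
        for line in f:
            names = line.split()
            if len(names) != 2:
                continue  # skip blank or malformed lines
            person_name, clothes_name = names
            pairs.append((os.path.join(dataroot, 'image', person_name),
                          os.path.join(dataroot, 'image', clothes_name)))
    return pairs

if __name__ == '__main__':
    for person_path, clothes_path in load_test_pairs():
        print(person_path, '->', clothes_path)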
-------------------------------------------------------------------------------- /test_datas/garment_parsing/16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/17_13524AA81E7E87GS_2532771_nike-5825-1772352-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/17_13524AA81E7E87GS_2532771_nike-5825-1772352-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/17_81ADCAABC27079GS_1995579_gap-3776-9755991-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/17_81ADCAABC27079GS_1995579_gap-3776-9755991-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4.png -------------------------------------------------------------------------------- 
/test_datas/garment_parsing/17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4.png -------------------------------------------------------------------------------- 
/test_datas/garment_parsing/24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4.png -------------------------------------------------------------------------------- /test_datas/garment_parsing/24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/garment_parsing/24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4.png -------------------------------------------------------------------------------- /test_datas/image/16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4.jpg -------------------------------------------------------------------------------- /test_datas/image/16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4.jpg -------------------------------------------------------------------------------- /test_datas/image/16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4.jpg -------------------------------------------------------------------------------- /test_datas/image/16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4.jpg -------------------------------------------------------------------------------- /test_datas/image/16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4.jpg -------------------------------------------------------------------------------- /test_datas/image/16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4.jpg -------------------------------------------------------------------------------- /test_datas/image/16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4.jpg -------------------------------------------------------------------------------- /test_datas/image/17_13524AA81E7E87GS_2532771_nike-5825-1772352-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/17_13524AA81E7E87GS_2532771_nike-5825-1772352-4.jpg -------------------------------------------------------------------------------- /test_datas/image/17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4.jpg -------------------------------------------------------------------------------- /test_datas/image/17_81ADCAABC27079GS_1995579_gap-3776-9755991-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/17_81ADCAABC27079GS_1995579_gap-3776-9755991-4.jpg -------------------------------------------------------------------------------- /test_datas/image/17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4.jpg -------------------------------------------------------------------------------- /test_datas/image/17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4.jpg 
-------------------------------------------------------------------------------- /test_datas/image/24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4.jpg -------------------------------------------------------------------------------- /test_datas/image/24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/image/24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4.jpg -------------------------------------------------------------------------------- /test_datas/keypoints/16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[176.384,50.597,0.981056,153.611,100.625,0.906886,114.945,107.76,0.803526,107.792,177.857,0.896331,76.2957,235.08,0.833012,190.801,97.7254,0.814812,256.61,142.091,0.914178,197.904,164.938,0.798382,116.382,222.203,0.496886,117.752,342.396,0.751119,126.382,446.827,0.726105,166.457,223.661,0.546065,187.901,346.621,0.832972,210.804,451.099,0.816646,165.019,40.5311,0.960382,179.325,40.6214,0.940001,140.71,40.6198,0.922707,0,0,0],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- 
/test_datas/keypoints/16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[162.117,54.8367,0.866904,162.106,106.308,0.835282,120.655,99.1969,0.856827,109.259,167.88,0.820228,99.2135,233.637,0.83054,199.374,107.824,0.895181,210.795,177.886,0.847246,213.699,243.642,0.820726,132.133,227.934,0.567046,117.825,338.122,0.663147,93.4966,451.13,0.741666,186.463,232.227,0.585337,186.418,338.136,0.667372,166.519,439.684,0.789948,153.517,49.1131,0.920659,167.869,49.1081,0.902881,142.105,50.5951,0.942684,179.35,50.5909,0.9154],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[164.928,50.5027,0.926162,166.471,97.7783,0.942752,127.818,97.7213,0.844091,110.675,164.946,0.866136,97.7127,223.641,0.911312,209.346,97.825,0.87678,212.208,175.01,0.879451,222.21,236.581,0.838379,136.377,223.638,0.561052,133.539,350.979,0.827886,120.738,461.12,0.807831,189.323,225.085,0.568389,196.469,353.817,0.787896,205.04,462.499,0.782625,153.559,40.5875,0.956424,170.725,39.108,0.895871,143.555,49.0819,0.797994,186.472,40.5026,0.958935],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[169.298,51.9347,0.939046,155.003,99.2414,0.88739,116.357,97.7933,0.814092,107.74,176.387,0.825174,89.1848,245.068,0.681054,196.463,103.468,0.787907,199.335,174.981,0.809915,197.884,235.156,0.792864,122.128,239.38,0.496351,113.525,350.963,0.490865,127.775,449.696,0.649859,177.855,242.264,0.574708,200.792,345.244,0.540359,210.761,443.977,0.707047,163.518,39.1557,0.910198,179.309,41.9752,0.935969,142.085,37.7082,0.952756,183.625,49.1107,0.09192],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4_keypoints.json: -------------------------------------------------------------------------------- 1 | 
{"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[177.916,40.5311,0.978704,163.531,97.7953,0.9058,120.654,96.3419,0.857069,109.251,173.615,0.860164,106.31,246.564,0.868806,200.829,102.066,0.833897,215.076,177.823,0.85785,223.659,235.112,0.88585,137.843,227.956,0.509045,114.927,348.059,0.806835,84.9243,463.933,0.812479,192.179,233.626,0.570056,187.874,348.104,0.824201,166.413,462.539,0.838517,167.894,29.1965,0.902083,187.84,33.3981,0.891371,149.262,39.0922,0.880972,190.764,42.0451,0.0733979],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[166.421,40.5625,0.961971,164.933,97.7661,0.917976,123.539,94.8873,0.80178,119.228,173.606,0.844165,132.058,239.404,0.841301,200.724,104.896,0.869995,203.663,177.894,0.792787,212.204,249.384,0.652328,143.557,223.653,0.57249,142.135,349.543,0.799477,150.67,454.001,0.759883,190.808,225.096,0.544396,187.939,342.367,0.803811,176.397,455.405,0.733238,156.422,34.8149,0.854893,176.436,34.807,0.902539,143.546,44.826,0.881435,187.909,47.7212,0.900255],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[139.257,53.4072,0.87679,142.096,112.09,0.882388,97.8444,113.501,0.812145,83.4137,187.906,0.862637,53.4777,255.136,0.822713,182.192,110.698,0.801185,186.448,189.315,0.875283,199.281,256.614,0.875823,102.054,233.69,0.495762,96.3383,349.569,0.791952,97.7572,462.559,0.835765,155.007,240.809,0.550421,165.016,353.815,0.74884,180.754,452.532,0.762245,130.592,47.6841,0.892604,145.035,46.2791,0.85366,117.806,52.0628,0.903764,157.887,51.9762,0.826559],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/17_13524AA81E7E87GS_2532771_nike-5825-1772352-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[160.721,51.9309,0.93927,163.523,104.93,0.837661,119.216,107.746,0.850672,109.17,183.611,0.864792,96.275,256.59,0.871401,205.057,102.066,0.773866,212.288,177.852,0.893432,222.278,255.093,0.852909,130.676,226.53,0.612344,132.096,338.167,0.827548,130.646,456.827,0.777053,183.616,229.382,0.58607,209.319,343.795,0.800975,225.078,466.834,0.772996,152.133,40.6207,0.90955,167.879,40.6062,0.927976,137.849,43.4246,0.856675,182.173,42.0452,0.840106],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} 
-------------------------------------------------------------------------------- /test_datas/keypoints/17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[162.113,37.6669,0.854135,162.161,99.1838,0.872697,119.228,97.7798,0.867769,97.7495,163.545,0.88845,117.778,216.52,0.795371,202.244,102.065,0.804036,222.151,166.463,0.870425,216.525,193.612,0.726321,143.509,225.068,0.636658,97.7636,349.518,0.870813,130.729,465.389,0.814735,196.483,226.532,0.598622,189.359,349.503,0.870403,185.021,472.528,0.721446,150.722,29.1109,0.885197,167.922,27.6842,0.953834,140.657,40.551,0.77448,180.751,37.6753,0.728968],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/17_81ADCAABC27079GS_1995579_gap-3776-9755991-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[130.59,53.484,0.89873,142.165,99.1859,0.854304,106.362,99.2558,0.755377,102.078,157.85,0.698941,84.9092,199.296,0.821115,177.911,97.7317,0.818354,202.206,164.981,0.885176,183.589,215.052,0.848807,96.4023,223.635,0.675096,102.066,346.661,0.784309,110.659,443.962,0.778655,150.689,227.957,0.599233,180.719,348.095,0.838545,212.205,443.96,0.793056,119.205,49.1593,0.935203,136.394,44.854,0.849386,109.201,53.4726,0.679177,152.137,50.5308,0.91025],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[153.625,49.1115,0.929522,165.016,97.7891,0.905801,120.716,100.669,0.809627,53.4845,142.038,0.873515,109.204,170.705,0.827419,209.339,96.3139,0.799553,237.955,165.068,0.8605,206.474,225.124,0.801146,123.546,230.824,0.596395,130.724,351.019,0.822347,144.969,464.04,0.748959,177.866,235.103,0.645469,210.834,348.158,0.878659,255.075,455.411,0.715736,143.524,39.1648,0.969468,164.925,37.6768,0.901659,132.132,44.8507,0.836458,177.92,39.0692,0.941903],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4_keypoints.json: -------------------------------------------------------------------------------- 1 | 
{"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[146.44,67.7298,0.862175,166.509,114.931,0.842952,127.791,110.631,0.83379,120.645,182.173,0.829024,119.28,245.094,0.894916,210.782,117.774,0.851216,219.338,189.336,0.807715,222.226,253.676,0.806096,140.656,235.108,0.575203,109.239,348.077,0.884395,143.491,448.239,0.775911,190.791,240.815,0.569725,187.822,348.11,0.84702,187.93,451.09,0.778594,142.131,61.9363,0.936782,155.023,60.5266,0.921548,0,0,0,177.822,57.7223,0.803743],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[150.666,50.5049,0.895922,166.445,99.1831,0.896265,130.637,100.658,0.810864,69.1428,132.172,0.849863,109.21,157.867,0.862096,207.894,97.755,0.803859,262.302,142.178,0.866717,210.721,166.388,0.895299,122.1,215.069,0.629379,120.669,348.132,0.79386,120.727,461.136,0.794387,175.013,222.254,0.626254,192.175,349.503,0.766142,222.265,458.244,0.798936,139.246,42.0064,0.872101,156.42,39.0658,0.927126,132.052,53.4173,0.708098,176.395,39.1714,0.95406],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[157.886,39.1391,0.913493,153.517,94.9069,0.901482,109.276,86.3874,0.813519,106.341,170.732,0.822605,116.344,236.537,0.834682,190.739,97.842,0.868604,186.427,163.56,0.211552,0,0,0,140.69,229.385,0.595928,123.549,353.844,0.680138,109.295,475.415,0.687137,196.498,223.696,0.651168,202.187,351.01,0.732777,210.712,474.02,0.666919,152.099,27.6993,0.933803,169.323,30.5125,0.878038,132.188,29.1494,0.938011,179.29,39.1802,0.793762],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[163.543,46.2719,0.861211,163.554,97.8466,0.869092,120.609,96.3453,0.864062,106.342,166.428,0.797571,87.8087,233.69,0.81927,202.24,102.052,0.81552,212.16,170.733,0.816093,219.383,232.214,0.814754,132.076,233.677,0.613873,139.262,356.67,0.52526,142.163,466.867,0.675702,187.904,233.688,0.643885,189.318,355.25,0.535878,195.054,474.014,0.604377,153.545,37.7078,0.928125,173.557,37.6938,0.893276,142.157,40.6203,0.782881,187.852,40.6241,0.949694],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} 
-------------------------------------------------------------------------------- /test_datas/keypoints/24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[142.135,59.1176,0.879248,142.097,109.274,0.869705,104.901,107.798,0.801416,94.8774,177.858,0.722742,97.8087,243.658,0.845749,177.872,112.087,0.852213,192.177,185.047,0.654093,199.288,253.671,0.775516,122.067,235.112,0.625562,123.501,348.151,0.744242,127.833,449.676,0.673392,175.026,235.139,0.621881,166.379,349.49,0.808916,160.694,449.642,0.678756,132.148,50.5498,0.944053,150.723,50.5128,0.895811,120.715,52.0616,0.892445,164.925,51.9749,0.990315],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[147.853,40.6189,0.881923,165,107.709,0.873048,129.2,104.943,0.75638,113.504,177.906,0.802803,94.867,246.526,0.825667,202.18,107.731,0.800714,210.741,179.345,0.865114,230.781,246.566,0.834778,129.255,233.682,0.563761,139.237,348.121,0.679893,143.631,451.071,0.756756,182.159,233.657,0.594124,200.757,345.253,0.738038,232.191,449.648,0.702828,142.05,37.6429,0.881787,156.459,33.3819,0.85398,132.159,42.0389,0.307678,176.473,40.594,0.882488],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[143.634,50.5377,0.947735,143.508,103.476,0.88097,103.489,106.334,0.800173,94.8718,179.304,0.815468,86.3561,253.717,0.830612,182.155,99.1772,0.823408,190.746,167.917,0.844311,192.203,223.732,0.822,119.168,232.222,0.608975,114.961,343.808,0.631372,132.05,454.007,0.708113,173.572,232.209,0.617335,199.375,339.513,0.657483,233.701,453.963,0.745902,136.385,40.5196,0.886619,154.983,40.603,0.955164,120.736,47.7245,0.92507,166.458,50.5468,0.932592],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4_keypoints.json: -------------------------------------------------------------------------------- 1 | 
{"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[163.584,44.8363,0.877625,155.068,97.7149,0.90501,119.289,96.3723,0.848477,110.644,165.035,0.779956,110.666,226.516,0.756091,192.167,97.7293,0.861874,200.762,165.008,0.618936,213.646,227.943,0.596995,134.987,212.245,0.587817,129.203,329.515,0.784257,120.627,433.932,0.743715,187.927,213.632,0.620502,183.584,326.64,0.828503,172.14,428.265,0.724903,153.627,39.03,0.94936,172.139,39.0669,0.89454,143.505,42.0024,0.788442,182.149,44.8332,0.775969],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[173.564,39.0875,0.898902,164.951,96.35,0.877407,129.217,92.0494,0.751564,86.2865,156.421,0.871074,117.756,213.648,0.831207,199.337,100.631,0.758046,206.497,177.935,0.765548,223.665,247.959,0.749901,153.503,230.813,0.578795,123.525,349.569,0.779998,93.4811,462.522,0.737159,199.395,230.789,0.56701,195.055,348.122,0.829105,187.838,462.583,0.802724,162.147,29.1471,0.904462,179.298,29.0889,0.929798,143.589,40.5817,0.928775,189.276,39.1083,0.314911],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[152.115,49.1591,0.899822,166.39,100.628,0.834981,130.635,104.937,0.750633,112.085,176.406,0.726314,97.818,236.566,0.791778,202.177,97.7814,0.807437,235.15,164.955,0.669653,212.22,189.337,0.582964,123.538,233.655,0.599404,110.677,346.639,0.697975,119.249,462.518,0.779909,176.411,235.101,0.641647,187.927,346.675,0.701194,205.076,462.569,0.736225,142.117,40.5064,0.943487,160.713,39.0536,0.887634,132.164,47.6723,0.312239,177.84,39.1612,0.988472],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4_keypoints.json: -------------------------------------------------------------------------------- 1 | 
{"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[166.512,62.0242,0.950986,162.133,117.816,0.836491,119.282,119.269,0.857315,107.773,196.475,0.793743,96.2638,260.842,0.821747,200.835,114.915,0.756964,209.349,187.948,0.856531,212.164,256.579,0.879064,122.119,242.243,0.599036,139.275,350.971,0.720538,149.286,453.982,0.718275,176.475,243.68,0.648864,176.378,352.391,0.774569,176.374,453.949,0.68762,156.47,52.0077,0.889164,176.433,52.0127,0.962013,142.059,54.8613,0.881732,179.344,54.8717,0.193882],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[166.467,47.7258,0.925092,167.84,99.1771,0.889443,123.53,96.3357,0.850682,109.247,167.878,0.78704,110.652,233.671,0.794875,212.176,104.942,0.842698,225.064,173.597,0.767792,222.292,240.829,0.839578,139.28,212.185,0.582932,140.688,335.218,0.693033,133.562,455.402,0.704902,190.788,213.65,0.58184,177.874,336.641,0.743052,165.005,441.101,0.775097,157.854,39.0817,0.876779,176.503,39.0803,0.974004,147.838,40.5012,0.817377,189.344,40.6318,0.918332],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[164.927,50.51,0.936597,159.268,100.635,0.834949,119.278,99.2028,0.866339,107.842,174.988,0.839894,96.4017,243.697,0.877217,199.309,102.071,0.836006,202.172,174.973,0.845459,205.09,240.797,0.801329,126.368,239.405,0.587606,132.115,348.126,0.728512,145.006,453.965,0.725887,177.909,246.521,0.650673,176.372,349.53,0.75222,169.284,451.135,0.648686,153.616,40.506,0.98876,173.551,40.6037,0.877756,142.072,49.1431,0.942302,182.196,50.5774,0.88119],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} -------------------------------------------------------------------------------- /test_datas/keypoints/24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4_keypoints.json: -------------------------------------------------------------------------------- 1 | {"version":1.3,"people":[{"person_id":[-1],"pose_keypoints_2d":[102.055,42.0161,0.856876,119.224,97.7216,0.901629,76.35,104.887,0.833916,72.0454,167.877,0.890929,59.1269,210.8,0.805207,160.706,89.2176,0.744288,180.751,164.944,0.852713,164.98,223.697,0.837733,76.3333,225.079,0.625145,84.9374,353.813,0.649592,99.2347,462.567,0.736856,132.08,229.396,0.623301,154.957,349.504,0.699269,189.287,455.39,0.694973,90.6216,37.6893,0.857232,109.255,30.5723,0.913589,82.0141,44.8195,0.802946,127.835,34.8537,0.851283],"face_keypoints_2d":[],"hand_left_keypoints_2d":[],"hand_right_keypoints_2d":[],"pose_keypoints_3d":[],"face_keypoints_3d":[],"hand_left_keypoints_3d":[],"hand_right_keypoints_3d":[]}]} 
-------------------------------------------------------------------------------- /test_datas/parsing/16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4.png -------------------------------------------------------------------------------- /test_datas/parsing/16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4.png -------------------------------------------------------------------------------- /test_datas/parsing/16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4.png -------------------------------------------------------------------------------- /test_datas/parsing/16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4.png -------------------------------------------------------------------------------- /test_datas/parsing/16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4.png -------------------------------------------------------------------------------- /test_datas/parsing/16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4.png -------------------------------------------------------------------------------- /test_datas/parsing/16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4.png -------------------------------------------------------------------------------- /test_datas/parsing/17_13524AA81E7E87GS_2532771_nike-5825-1772352-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/17_13524AA81E7E87GS_2532771_nike-5825-1772352-4.png -------------------------------------------------------------------------------- /test_datas/parsing/17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4.png -------------------------------------------------------------------------------- /test_datas/parsing/17_81ADCAABC27079GS_1995579_gap-3776-9755991-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/17_81ADCAABC27079GS_1995579_gap-3776-9755991-4.png -------------------------------------------------------------------------------- /test_datas/parsing/17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4.png -------------------------------------------------------------------------------- /test_datas/parsing/17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4.png -------------------------------------------------------------------------------- /test_datas/parsing/24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/test_datas/parsing/24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4.png 
-------------------------------------------------------------------------------- /test_datas/test_pairs.txt: -------------------------------------------------------------------------------- 1 | 24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4.jpg 24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4.jpg 2 | 16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4.jpg 16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4.jpg 3 | 16_A4B3CAADFF38DEGS_2310129_zalora-basics-6049-9210132-4.jpg 24_D394CAA4C7D0BAGS_1504042_puma-8627-2404051-4.jpg 4 | 17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4.jpg 17_13524AA81E7E87GS_2532771_nike-5825-1772352-4.jpg 5 | 17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4.jpg 24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4.jpg 6 | 24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4.jpg 24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4.jpg 7 | 24_C99F9AA18E57E8GS_1590237_zalora-basics-7996-7320951-4.jpg 24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4.jpg 8 | 24_22BFFAA361A0BFGS_1707010_vero-moda-6549-0107071-4.jpg 17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4.jpg 9 | 17_AD48BAAB5D1E15GS_1803082_lc-waikiki-4177-2803081-4.jpg 16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4.jpg 10 | 16_7F4CBAA498AD56GS_2367987_niko-and-0495-7897632-4.jpg 16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4.jpg 11 | 24_A5225AAA144C67GS_951967_brave-soul-0755-769159-4.jpg 16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4.jpg 12 | 16_5DDABAA148AB70GS_2319034_forcast-6638-4309132-4.jpg 17_93BC2AA1BBA46FGS_2392311_cotton-on-body-0401-1132932-4.jpg 13 | 24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4.jpg 16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4.jpg 14 | 17_13524AA81E7E87GS_2532771_nike-5825-1772352-4.jpg 24_B4612AA480B8FCGS_1797475_desigual-8847-5747971-4.jpg 15 | 17_81ADCAABC27079GS_1995579_gap-3776-9755991-4.jpg 24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4.jpg 16 | 16_71E4DAAEED7A2AGS_2563260_uniqtee-6218-0623652-4.jpg 24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4.jpg 17 | 24_65412AAEE760D5GS_1686958_missguided-4171-8596861-4.jpg 16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4.jpg 18 | 16_8533AAA1AE4F33GS_2524699_urban-revivo-6435-9964252-4.jpg 24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4.jpg 19 | 24_5AF3DAA5F87B36GS_1615066_zalora-basics-6811-6605161-4.jpg 24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4.jpg 20 | 24_9A87AAA0D87689GS_1487834_desigual-6870-4387841-4.jpg 24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4.jpg 21 | 16_B6D3CAA13E73ADGS_1490146_hopeshow-0620-6410941-4.jpg 16_6A959AA1C79B20GS_2541597_vero-moda-8715-7951452-4.jpg 22 | 24_358BFAAD0FD0ADGS_1773796_vero-moda-5539-6973771-4.jpg 24_5A877AA4F29ABFGS_1733032_zalora-work-5797-2303371-4.jpg 23 | 24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4.jpg 17_81ADCAABC27079GS_1995579_gap-3776-9755991-4.jpg 24 | 24_1FC88AA49BD873GS_1646422_zalora-work-7767-2246461-4.jpg 17_6C790AA2F9B6D8GS_2518373_savel-5476-3738152-4.jpg 25 | 24_72650AA3CF0CB3GS_1530724_zalora-basics-9270-4270351-4.jpg 24_04BF5AA0C435A3GS_1830735_hm-8239-5370381-4.jpg 26 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/__pycache__/custom_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/__pycache__/custom_ops.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/__pycache__/custom_ops.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/__pycache__/custom_ops.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/__pycache__/misc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/__pycache__/misc.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/__pycache__/persistence.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/__pycache__/persistence.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/__pycache__/persistence.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/__pycache__/persistence.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/__pycache__/training_stats.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/__pycache__/training_stats.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/__pycache__/training_stats.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/__pycache__/training_stats.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import glob 11 | import torch 12 | import torch.utils.cpp_extension 13 | import importlib 14 | import hashlib 15 | import shutil 16 | from pathlib import Path 17 | 18 | from torch.utils.file_baton import FileBaton 19 | 20 | #---------------------------------------------------------------------------- 21 | # Global options. 22 | 23 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 24 | 25 | #---------------------------------------------------------------------------- 26 | # Internal helper funcs. 27 | 28 | def _find_compiler_bindir(): 29 | patterns = [ 30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 34 | ] 35 | for pattern in patterns: 36 | matches = sorted(glob.glob(pattern)) 37 | if len(matches): 38 | return matches[-1] 39 | return None 40 | 41 | #---------------------------------------------------------------------------- 42 | # Main entry point for compiling and loading C++/CUDA plugins. 43 | 44 | _cached_plugins = dict() 45 | 46 | def get_plugin(module_name, sources, **build_kwargs): 47 | assert verbosity in ['none', 'brief', 'full'] 48 | 49 | # Already cached? 50 | if module_name in _cached_plugins: 51 | return _cached_plugins[module_name] 52 | 53 | # Print status. 54 | if verbosity == 'full': 55 | print(f'Setting up PyTorch plugin "{module_name}"...') 56 | elif verbosity == 'brief': 57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 58 | 59 | try: # pylint: disable=too-many-nested-blocks 60 | # Make sure we can find the necessary compiler binaries. 61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 62 | compiler_bindir = _find_compiler_bindir() 63 | if compiler_bindir is None: 64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 65 | os.environ['PATH'] += ';' + compiler_bindir 66 | 67 | # Compile and load. 68 | verbose_build = (verbosity == 'full') 69 | 70 | # Incremental build md5sum trickery. 
Copies all the input source files 71 | # into a cached build directory under a combined md5 digest of the input 72 | # source files. Copying is done only if the combined digest has changed. 73 | # This keeps input file timestamps and filenames the same as in previous 74 | # extension builds, allowing for fast incremental rebuilds. 75 | # 76 | # This optimization is done only in case all the source files reside in 77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 78 | # environment variable is set (we take this as a signal that the user 79 | # actually cares about this.) 80 | source_dirs_set = set(os.path.dirname(source) for source in sources) 81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): 82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) 83 | 84 | # Compute a combined hash digest for all source files in the same 85 | # custom op directory (usually .cu, .cpp, .py and .h files). 86 | hash_md5 = hashlib.md5() 87 | for src in all_source_files: 88 | with open(src, 'rb') as f: 89 | hash_md5.update(f.read()) 90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) 92 | 93 | if not os.path.isdir(digest_build_dir): 94 | os.makedirs(digest_build_dir, exist_ok=True) 95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock')) 96 | if baton.try_acquire(): 97 | try: 98 | for src in all_source_files: 99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) 100 | finally: 101 | baton.release() 102 | else: 103 | # Someone else is copying source files under the digest dir, 104 | # wait until done and continue. 105 | baton.wait() 106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] 107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, 108 | verbose=verbose_build, sources=digest_sources, **build_kwargs) 109 | else: 110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 111 | module = importlib.import_module(module_name) 112 | 113 | except: 114 | if verbosity == 'brief': 115 | print('Failed!') 116 | raise 117 | 118 | # Print status and add to cache. 119 | if verbosity == 'full': 120 | print(f'Done setting up PyTorch plugin "{module_name}".') 121 | elif verbosity == 'brief': 122 | print('Done.') 123 | _cached_plugins[module_name] = module 124 | return module 125 | 126 | #---------------------------------------------------------------------------- 127 | -------------------------------------------------------------------------------- /torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import re 10 | import contextlib 11 | import numpy as np 12 | import torch 13 | import warnings 14 | import dnnlib 15 | 16 | #---------------------------------------------------------------------------- 17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 18 | # same constant is used multiple times. 19 | 20 | _constant_cache = dict() 21 | 22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 23 | value = np.asarray(value) 24 | if shape is not None: 25 | shape = tuple(shape) 26 | if dtype is None: 27 | dtype = torch.get_default_dtype() 28 | if device is None: 29 | device = torch.device('cpu') 30 | if memory_format is None: 31 | memory_format = torch.contiguous_format 32 | 33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 34 | tensor = _constant_cache.get(key, None) 35 | if tensor is None: 36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 37 | if shape is not None: 38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 39 | tensor = tensor.contiguous(memory_format=memory_format) 40 | _constant_cache[key] = tensor 41 | return tensor 42 | 43 | #---------------------------------------------------------------------------- 44 | # Replace NaN/Inf with specified numerical values. 45 | 46 | try: 47 | nan_to_num = torch.nan_to_num # 1.8.0a0 48 | except AttributeError: 49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 50 | assert isinstance(input, torch.Tensor) 51 | if posinf is None: 52 | if input.dtype == torch.int64: 53 | posinf = torch.iinfo(input.dtype).max 54 | else: 55 | posinf = torch.finfo(input.dtype).max 56 | if neginf is None: 57 | if input.dtype == torch.int64: 58 | neginf = torch.iinfo(input.dtype).min 59 | else: 60 | neginf = torch.finfo(input.dtype).min 61 | assert nan == 0 62 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 63 | 64 | #---------------------------------------------------------------------------- 65 | # Symbolic assert. 66 | 67 | try: 68 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 69 | except AttributeError: 70 | symbolic_assert = torch.Assert # 1.7.0 71 | 72 | #---------------------------------------------------------------------------- 73 | # Context manager to suppress known warnings in torch.jit.trace(). 74 | 75 | class suppress_tracer_warnings(warnings.catch_warnings): 76 | def __enter__(self): 77 | super().__enter__() 78 | warnings.simplefilter('ignore', category=torch.jit.TracerWarning) 79 | return self 80 | 81 | #---------------------------------------------------------------------------- 82 | # Assert that the shape of a tensor matches the given list of integers. 83 | # None indicates that the size of a dimension is allowed to vary. 84 | # Performs symbolic assertion when used in torch.jit.trace(). 
85 | 86 | def assert_shape(tensor, ref_shape): 87 | if tensor.ndim != len(ref_shape): 88 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 89 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 90 | if ref_size is None: 91 | pass 92 | elif isinstance(ref_size, torch.Tensor): 93 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 94 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 95 | elif isinstance(size, torch.Tensor): 96 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 97 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 98 | elif size != ref_size: 99 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 100 | 101 | #---------------------------------------------------------------------------- 102 | # Function decorator that calls torch.autograd.profiler.record_function(). 103 | 104 | def profiled_function(fn): 105 | def decorator(*args, **kwargs): 106 | with torch.autograd.profiler.record_function(fn.__name__): 107 | return fn(*args, **kwargs) 108 | decorator.__name__ = fn.__name__ 109 | return decorator 110 | 111 | #---------------------------------------------------------------------------- 112 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 113 | # indefinitely, shuffling items as it goes. 114 | 115 | class InfiniteSampler(torch.utils.data.Sampler): 116 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 117 | assert len(dataset) > 0 118 | assert num_replicas > 0 119 | assert 0 <= rank < num_replicas 120 | assert 0 <= window_size <= 1 121 | super().__init__(dataset) 122 | self.dataset = dataset 123 | self.rank = rank 124 | self.num_replicas = num_replicas 125 | self.shuffle = shuffle 126 | self.seed = seed 127 | self.window_size = window_size 128 | 129 | def __iter__(self): 130 | order = np.arange(len(self.dataset)) 131 | rnd = None 132 | window = 0 133 | if self.shuffle: 134 | rnd = np.random.RandomState(self.seed) 135 | rnd.shuffle(order) 136 | window = int(np.rint(order.size * self.window_size)) 137 | 138 | idx = 0 139 | while True: 140 | i = idx % order.size 141 | if idx % self.num_replicas == self.rank: 142 | yield order[i] 143 | if window >= 2: 144 | j = (i - rnd.randint(window)) % order.size 145 | order[i], order[j] = order[j], order[i] 146 | idx += 1 147 | 148 | #---------------------------------------------------------------------------- 149 | # Utilities for operating with torch.nn.Module parameters and buffers. 
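# [Editor's note -- illustrative sketch, not part of misc.py. It exercises assert_shape()
# and InfiniteSampler defined above; the toy dataset and tensor shapes are made up.]
import torch
from torch_utils import misc

x = torch.randn(8, 3, 256, 256)
misc.assert_shape(x, [None, 3, 256, 256])      # None lets the batch dimension vary

dataset = torch.utils.data.TensorDataset(torch.arange(100))
sampler = misc.InfiniteSampler(dataset, rank=0, num_replicas=1, seed=0)
loader = iter(torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=4))
first_batch = next(loader)                     # the sampler loops forever, reshuffling as it goes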
150 | 151 | def params_and_buffers(module): 152 | assert isinstance(module, torch.nn.Module) 153 | return list(module.parameters()) + list(module.buffers()) 154 | 155 | def named_params_and_buffers(module): 156 | assert isinstance(module, torch.nn.Module) 157 | return list(module.named_parameters()) + list(module.named_buffers()) 158 | 159 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 160 | assert isinstance(src_module, torch.nn.Module) 161 | assert isinstance(dst_module, torch.nn.Module) 162 | src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} 163 | for name, tensor in named_params_and_buffers(dst_module): 164 | assert (name in src_tensors) or (not require_all) 165 | if name in src_tensors: 166 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) 167 | 168 | #---------------------------------------------------------------------------- 169 | # Context manager for easily enabling/disabling DistributedDataParallel 170 | # synchronization. 171 | 172 | @contextlib.contextmanager 173 | def ddp_sync(module, sync): 174 | assert isinstance(module, torch.nn.Module) 175 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 176 | yield 177 | else: 178 | with module.no_sync(): 179 | yield 180 | 181 | #---------------------------------------------------------------------------- 182 | # Check DistributedDataParallel consistency across processes. 183 | 184 | def check_ddp_consistency(module, ignore_regex=None): 185 | assert isinstance(module, torch.nn.Module) 186 | for name, tensor in named_params_and_buffers(module): 187 | fullname = type(module).__name__ + '.' + name 188 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname): 189 | continue 190 | tensor = tensor.detach() 191 | other = tensor.clone() 192 | torch.distributed.broadcast(tensor=other, src=0) 193 | a = (nan_to_num(tensor) == nan_to_num(other)).all() 194 | if a == False: 195 | l = 1 196 | assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname 197 | 198 | #---------------------------------------------------------------------------- 199 | # Print summary table of module hierarchy. 200 | 201 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True): 202 | assert isinstance(module, torch.nn.Module) 203 | assert not isinstance(module, torch.jit.ScriptModule) 204 | assert isinstance(inputs, (tuple, list)) 205 | 206 | # Register hooks. 207 | entries = [] 208 | nesting = [0] 209 | def pre_hook(_mod, _inputs): 210 | nesting[0] += 1 211 | def post_hook(mod, _inputs, outputs): 212 | nesting[0] -= 1 213 | if nesting[0] <= max_nesting: 214 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs] 215 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)] 216 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs)) 217 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()] 218 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()] 219 | 220 | # Run module. 221 | outputs = module(*inputs) 222 | for hook in hooks: 223 | hook.remove() 224 | 225 | # Identify unique outputs, parameters, and buffers. 
226 | tensors_seen = set() 227 | for e in entries: 228 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen] 229 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen] 230 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen] 231 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs} 232 | 233 | # Filter out redundant entries. 234 | if skip_redundant: 235 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)] 236 | 237 | # Construct table. 238 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']] 239 | rows += [['---'] * len(rows[0])] 240 | param_total = 0 241 | buffer_total = 0 242 | submodule_names = {mod: name for name, mod in module.named_modules()} 243 | for e in entries: 244 | name = '' if e.mod is module else submodule_names[e.mod] 245 | param_size = sum(t.numel() for t in e.unique_params) 246 | buffer_size = sum(t.numel() for t in e.unique_buffers) 247 | output_shapes = [str(list(e.outputs[0].shape)) for t in e.outputs] 248 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs] 249 | rows += [[ 250 | name + (':0' if len(e.outputs) >= 2 else ''), 251 | str(param_size) if param_size else '-', 252 | str(buffer_size) if buffer_size else '-', 253 | (output_shapes + ['-'])[0], 254 | (output_dtypes + ['-'])[0], 255 | ]] 256 | for idx in range(1, len(e.outputs)): 257 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]] 258 | param_total += param_size 259 | buffer_total += buffer_size 260 | rows += [['---'] * len(rows[0])] 261 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']] 262 | 263 | # Print table. 264 | widths = [max(len(cell) for cell in column) for column in zip(*rows)] 265 | print() 266 | for row in rows: 267 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths))) 268 | print() 269 | return outputs 270 | 271 | #---------------------------------------------------------------------------- 272 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
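# [Editor's note -- illustrative sketch for the torch_utils/misc.py helpers above
# (print_module_summary, copy_params_and_buffers); the toy model is made up.]
import copy
import torch
from torch_utils import misc

net = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3, padding=1), torch.nn.ReLU())
net_ema = copy.deepcopy(net).eval().requires_grad_(False)

misc.print_module_summary(net, [torch.randn(1, 3, 64, 64)])     # per-module parameter/buffer/shape table
misc.copy_params_and_buffers(net, net_ema, require_all=True)    # e.g. refreshing an EMA-style snapshot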
8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/bias_act.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/bias_act.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/bias_act.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/bias_act.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/conv2d_gradfix.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/conv2d_resample.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/conv2d_resample.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/conv2d_resample.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/fma.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/fma.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/fma.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/fma.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/upfirdn2d.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/upfirdn2d.cpython-37.pyc -------------------------------------------------------------------------------- /torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/torch_utils/ops/__pycache__/upfirdn2d.cpython-38.pyc -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments.
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <c10/util/Half.h> 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template <class T> struct InternalType; 16 | template <> struct InternalType<double> { typedef double scalar_t; }; 17 | template <> struct InternalType<float> { typedef float scalar_t; }; 18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template <class T, int A> 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType<T>::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>; 155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>; 156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>; 157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>; 158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>; 159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>; 160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>; 161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>; 162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection.
35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import warnings 13 | import numpy as np 14 | import torch 15 | import dnnlib 16 | import traceback 17 | 18 | from .. import custom_ops 19 | from .. import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | activation_funcs = { 24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 33 | } 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | _inited = False 38 | _plugin = None 39 | _null_tensor = torch.empty([0]) 40 | 41 | def _init(): 42 | global _inited, _plugin 43 | if not _inited: 44 | _inited = True 45 | sources = ['bias_act.cpp', 'bias_act.cu'] 46 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources] 47 | try: 48 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math']) 49 | except: 50 | warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + traceback.format_exc()) 51 | return _plugin is not None 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 56 | r"""Fused bias and activation function. 
57 | 58 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 59 | and scales the result by `gain`. Each of the steps is optional. In most cases, 60 | the fused op is considerably more efficient than performing the same calculation 61 | using standard PyTorch ops. It supports first and second order gradients, 62 | but not third order gradients. 63 | 64 | Args: 65 | x: Input activation tensor. Can be of any shape. 66 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 67 | as `x`. The shape must be known, and it must match the dimension of `x` 68 | corresponding to `dim`. 69 | dim: The dimension in `x` corresponding to the elements of `b`. 70 | The value of `dim` is ignored if `b` is not specified. 71 | act: Name of the activation function to evaluate, or `"linear"` to disable. 72 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 73 | See `activation_funcs` for a full list. `None` is not allowed. 74 | alpha: Shape parameter for the activation function, or `None` to use the default. 75 | gain: Scaling factor for the output tensor, or `None` to use default. 76 | See `activation_funcs` for the default scaling of each activation function. 77 | If unsure, consider specifying 1. 78 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 79 | the clamping (default). 80 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 81 | 82 | Returns: 83 | Tensor of the same shape and datatype as `x`. 84 | """ 85 | assert isinstance(x, torch.Tensor) 86 | assert impl in ['ref', 'cuda'] 87 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 88 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 89 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 90 | 91 | #---------------------------------------------------------------------------- 92 | 93 | @misc.profiled_function 94 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 95 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 96 | """ 97 | assert isinstance(x, torch.Tensor) 98 | assert clamp is None or clamp >= 0 99 | spec = activation_funcs[act] 100 | alpha = float(alpha if alpha is not None else spec.def_alpha) 101 | gain = float(gain if gain is not None else spec.def_gain) 102 | clamp = float(clamp if clamp is not None else -1) 103 | 104 | # Add bias. 105 | if b is not None: 106 | assert isinstance(b, torch.Tensor) and b.ndim == 1 107 | assert 0 <= dim < x.ndim 108 | assert b.shape[0] == x.shape[dim] 109 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 110 | 111 | # Evaluate activation function. 112 | alpha = float(alpha) 113 | x = spec.func(x, alpha=alpha) 114 | 115 | # Scale by gain. 116 | gain = float(gain) 117 | if gain != 1: 118 | x = x * gain 119 | 120 | # Clamp. 121 | if clamp >= 0: 122 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 123 | return x 124 | 125 | #---------------------------------------------------------------------------- 126 | 127 | _bias_act_cuda_cache = dict() 128 | 129 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 130 | """Fast CUDA implementation of `bias_act()` using custom ops. 131 | """ 132 | # Parse arguments. 
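# [Editor's note -- illustrative usage of bias_act() documented above; shapes are arbitrary.
# impl='ref' exercises the pure-PyTorch fallback so the sketch runs without the CUDA plugin;
# on a GPU the default impl='cuda' uses the fused kernel instead.]
import torch
from torch_utils.ops import bias_act as ba

x = torch.randn(4, 512, 16, 16)
b = torch.randn(512)
y = ba.bias_act(x, b, dim=1, act='lrelu', gain=1.0, impl='ref')

# The reference path is equivalent to composing standard ops by hand:
y_manual = torch.nn.functional.leaky_relu(x + b.reshape(1, -1, 1, 1), 0.2)
assert torch.allclose(y, y_manual, atol=1e-6)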
133 | assert clamp is None or clamp >= 0 134 | spec = activation_funcs[act] 135 | alpha = float(alpha if alpha is not None else spec.def_alpha) 136 | gain = float(gain if gain is not None else spec.def_gain) 137 | clamp = float(clamp if clamp is not None else -1) 138 | 139 | # Lookup from cache. 140 | key = (dim, act, alpha, gain, clamp) 141 | if key in _bias_act_cuda_cache: 142 | return _bias_act_cuda_cache[key] 143 | 144 | # Forward op. 145 | class BiasActCuda(torch.autograd.Function): 146 | @staticmethod 147 | def forward(ctx, x, b): # pylint: disable=arguments-differ 148 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format 149 | x = x.contiguous(memory_format=ctx.memory_format) 150 | b = b.contiguous() if b is not None else _null_tensor 151 | y = x 152 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 153 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 154 | ctx.save_for_backward( 155 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 156 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 157 | y if 'y' in spec.ref else _null_tensor) 158 | return y 159 | 160 | @staticmethod 161 | def backward(ctx, dy): # pylint: disable=arguments-differ 162 | dy = dy.contiguous(memory_format=ctx.memory_format) 163 | x, b, y = ctx.saved_tensors 164 | dx = None 165 | db = None 166 | 167 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 168 | dx = dy 169 | if act != 'linear' or gain != 1 or clamp >= 0: 170 | dx = BiasActCudaGrad.apply(dy, x, b, y) 171 | 172 | if ctx.needs_input_grad[1]: 173 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 174 | 175 | return dx, db 176 | 177 | # Backward op. 178 | class BiasActCudaGrad(torch.autograd.Function): 179 | @staticmethod 180 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 181 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format 182 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 183 | ctx.save_for_backward( 184 | dy if spec.has_2nd_grad else _null_tensor, 185 | x, b, y) 186 | return dx 187 | 188 | @staticmethod 189 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 190 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 191 | dy, x, b, y = ctx.saved_tensors 192 | d_dy = None 193 | d_x = None 194 | d_b = None 195 | d_y = None 196 | 197 | if ctx.needs_input_grad[0]: 198 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 199 | 200 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 201 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 202 | 203 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 204 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 205 | 206 | return d_dy, d_x, d_b, d_y 207 | 208 | # Add to cache. 209 | _bias_act_cuda_cache[key] = BiasActCuda 210 | return BiasActCuda 211 | 212 | #---------------------------------------------------------------------------- 213 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import warnings 13 | import contextlib 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 24 | 25 | @contextlib.contextmanager 26 | def no_weight_gradients(): 27 | global weight_gradients_disabled 28 | old = weight_gradients_disabled 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 54 | return True 55 | warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().') 56 | return False 57 | 58 | def _tuple_of_ints(xs, ndim): 59 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 60 | assert len(xs) == ndim 61 | assert all(isinstance(x, int) for x in xs) 62 | return xs 63 | 64 | #---------------------------------------------------------------------------- 65 | 66 | _conv2d_gradfix_cache = dict() 67 | 68 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 69 | # Parse arguments. 
70 | ndim = 2 71 | weight_shape = tuple(weight_shape) 72 | stride = _tuple_of_ints(stride, ndim) 73 | padding = _tuple_of_ints(padding, ndim) 74 | output_padding = _tuple_of_ints(output_padding, ndim) 75 | dilation = _tuple_of_ints(dilation, ndim) 76 | 77 | # Lookup from cache. 78 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 79 | if key in _conv2d_gradfix_cache: 80 | return _conv2d_gradfix_cache[key] 81 | 82 | # Validate arguments. 83 | assert groups >= 1 84 | assert len(weight_shape) == ndim + 2 85 | assert all(stride[i] >= 1 for i in range(ndim)) 86 | assert all(padding[i] >= 0 for i in range(ndim)) 87 | assert all(dilation[i] >= 0 for i in range(ndim)) 88 | if not transpose: 89 | assert all(output_padding[i] == 0 for i in range(ndim)) 90 | else: # transpose 91 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 92 | 93 | # Helpers. 94 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 95 | def calc_output_padding(input_shape, output_shape): 96 | if transpose: 97 | return [0, 0] 98 | return [ 99 | input_shape[i + 2] 100 | - (output_shape[i + 2] - 1) * stride[i] 101 | - (1 - 2 * padding[i]) 102 | - dilation[i] * (weight_shape[i + 2] - 1) 103 | for i in range(ndim) 104 | ] 105 | 106 | # Forward & backward. 107 | class Conv2d(torch.autograd.Function): 108 | @staticmethod 109 | def forward(ctx, input, weight, bias): 110 | assert weight.shape == weight_shape 111 | if not transpose: 112 | output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 113 | else: # transpose 114 | output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 115 | ctx.save_for_backward(input, weight) 116 | return output 117 | 118 | @staticmethod 119 | def backward(ctx, grad_output): 120 | input, weight = ctx.saved_tensors 121 | grad_input = None 122 | grad_weight = None 123 | grad_bias = None 124 | 125 | if ctx.needs_input_grad[0]: 126 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 127 | grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None) 128 | assert grad_input.shape == input.shape 129 | 130 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 131 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 132 | assert grad_weight.shape == weight_shape 133 | 134 | if ctx.needs_input_grad[2]: 135 | grad_bias = grad_output.sum([0, 2, 3]) 136 | 137 | return grad_input, grad_weight, grad_bias 138 | 139 | # Gradient with respect to the weights. 
140 | class Conv2dGradWeight(torch.autograd.Function): 141 | @staticmethod 142 | def forward(ctx, grad_output, input): 143 | op = torch._C._jit_get_operation('aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight') 144 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 145 | grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 146 | assert grad_weight.shape == weight_shape 147 | ctx.save_for_backward(grad_output, input) 148 | return grad_weight 149 | 150 | @staticmethod 151 | def backward(ctx, grad2_grad_weight): 152 | grad_output, input = ctx.saved_tensors 153 | grad2_grad_output = None 154 | grad2_input = None 155 | 156 | if ctx.needs_input_grad[0]: 157 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 158 | assert grad2_grad_output.shape == grad_output.shape 159 | 160 | if ctx.needs_input_grad[1]: 161 | p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape) 162 | grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None) 163 | assert grad2_input.shape == input.shape 164 | 165 | return grad2_grad_output, grad2_input 166 | 167 | _conv2d_gradfix_cache[key] = Conv2d 168 | return Conv2d 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | w = w.flip([2, 3]) 37 | 38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using 39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: 41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: 42 | if out_channels <= 4 and groups == 1: 43 | in_shape = x.shape 44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) 45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) 46 | else: 47 | x = x.to(memory_format=torch.contiguous_format) 48 | w = w.to(memory_format=torch.contiguous_format) 49 | x = conv2d_gradfix.conv2d(x, w, groups=groups) 50 | return x.to(memory_format=torch.channels_last) 51 | 52 | # Otherwise => execute using conv2d_gradfix. 53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 54 | return op(x, w, stride=stride, padding=padding, groups=groups) 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | @misc.profiled_function 59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 60 | r"""2D convolution with optional up/downsampling. 61 | 62 | Padding is performed only once at the beginning, not between the operations. 63 | 64 | Args: 65 | x: Input tensor of shape 66 | `[batch_size, in_channels, in_height, in_width]`. 67 | w: Weight tensor of shape 68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 70 | calling upfirdn2d.setup_filter(). None = identity (default). 71 | up: Integer upsampling factor (default: 1). 72 | down: Integer downsampling factor (default: 1). 73 | padding: Padding with respect to the upsampled image. Can be a single number 74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 75 | (default: 0). 76 | groups: Split input channels into N groups (default: 1). 77 | flip_weight: False = convolution, True = correlation (default: True). 78 | flip_filter: False = convolution, True = correlation (default: False). 79 | 80 | Returns: 81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 82 | """ 83 | # Validate arguments. 84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 87 | assert isinstance(up, int) and (up >= 1) 88 | assert isinstance(down, int) and (down >= 1) 89 | assert isinstance(groups, int) and (groups >= 1) 90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 91 | fw, fh = _get_filter_size(f) 92 | px0, px1, py0, py1 = _parse_padding(padding) 93 | 94 | # Adjust padding to account for up/downsampling. 95 | if up > 1: 96 | px0 += (fw + up - 1) // 2 97 | px1 += (fw - up) // 2 98 | py0 += (fh + up - 1) // 2 99 | py1 += (fh - up) // 2 100 | if down > 1: 101 | px0 += (fw - down + 1) // 2 102 | px1 += (fw - down) // 2 103 | py0 += (fh - down + 1) // 2 104 | py1 += (fh - down) // 2 105 | 106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 107 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 110 | return x 111 | 112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 
113 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 116 | return x 117 | 118 | # Fast path: downsampling only => use strided convolution. 119 | if down > 1 and up == 1: 120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 122 | return x 123 | 124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 125 | if up > 1: 126 | if groups == 1: 127 | w = w.transpose(0, 1) 128 | else: 129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 130 | w = w.transpose(1, 2) 131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 132 | px0 -= kw - 1 133 | px1 -= kw - up 134 | py0 -= kh - 1 135 | py1 -= kh - up 136 | pxt = max(min(-px0, -px1), 0) 137 | pyt = max(min(-py0, -py1), 0) 138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 140 | if down > 1: 141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 142 | return x 143 | 144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 145 | if up == 1 and down == 1: 146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 148 | 149 | # Fallback: Generic reference implementation. 150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 152 | if down > 1: 153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 154 | return x 155 | 156 | #---------------------------------------------------------------------------- 157 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
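# Usage sketch for conv2d_resample() from torch_utils/ops/conv2d_resample.py above
# (illustrative only: the tensor shapes and the setup_filter() taps are made up, and this
# is not a call site taken from this repository):
#
#     import torch
#     from torch_utils.ops import upfirdn2d, conv2d_resample
#
#     x = torch.randn(1, 64, 32, 32, device='cuda')             # [batch, in_channels, H, W]
#     w = torch.randn(128, 64, 3, 3, device='cuda')             # [out_channels, in_channels//groups, kh, kw]
#     f = upfirdn2d.setup_filter([1, 3, 3, 1], device='cuda')   # low-pass filter for resampling
#     y = conv2d_resample.conv2d_resample(x=x, w=w, f=f, up=2, padding=1)   # 2x upsampling + 3x3 conv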
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import warnings 15 | import torch 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 
24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | def grid_sample(input, grid): 28 | if _should_use_custom_op(): 29 | return _GridSample2dForward.apply(input, grid) 30 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def _should_use_custom_op(): 35 | if not enabled: 36 | return False 37 | if any(torch.__version__.startswith(x) for x in ['1.7.', '1.8.', '1.9']): 38 | return True 39 | warnings.warn(f'grid_sample_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.grid_sample().') 40 | return False 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | class _GridSample2dForward(torch.autograd.Function): 45 | @staticmethod 46 | def forward(ctx, input, grid): 47 | assert input.ndim == 4 48 | assert grid.ndim == 4 49 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 50 | ctx.save_for_backward(input, grid) 51 | return output 52 | 53 | @staticmethod 54 | def backward(ctx, grad_output): 55 | input, grid = ctx.saved_tensors 56 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 57 | return grad_input, grad_grid 58 | 59 | #---------------------------------------------------------------------------- 60 | 61 | class _GridSample2dBackward(torch.autograd.Function): 62 | @staticmethod 63 | def forward(ctx, grad_output, input, grid): 64 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <torch/extension.h> 10 | #include <ATen/cuda/CUDAContext.h> 11 | #include <c10/cuda/CUDAGuard.h> 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments.
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 25 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 26 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 27 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 28 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 29 | 30 | // Create output tensor. 31 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 32 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 33 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 34 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 35 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 36 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 37 | 38 | // Initialize CUDA kernel parameters. 39 | upfirdn2d_kernel_params p; 40 | p.x = x.data_ptr(); 41 | p.f = f.data_ptr<float>(); 42 | p.y = y.data_ptr(); 43 | p.up = make_int2(upx, upy); 44 | p.down = make_int2(downx, downy); 45 | p.pad0 = make_int2(padx0, pady0); 46 | p.flip = (flip) ? 1 : 0; 47 | p.gain = gain; 48 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 49 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 50 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 51 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 52 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 53 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 54 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 55 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 56 | 57 | // Choose CUDA kernel. 58 | upfirdn2d_kernel_spec spec; 59 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 60 | { 61 | spec = choose_upfirdn2d_kernel<scalar_t>(p); 62 | }); 63 | 64 | // Set looping options. 65 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 66 | p.loopMinor = spec.loopMinor; 67 | p.loopX = spec.loopX; 68 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 69 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 70 | 71 | // Compute grid size. 72 | dim3 blockSize, gridSize; 73 | if (spec.tileOutW < 0) // large 74 | { 75 | blockSize = dim3(4, 32, 1); 76 | gridSize = dim3( 77 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 78 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 79 | p.launchMajor); 80 | } 81 | else // small 82 | { 83 | blockSize = dim3(256, 1, 1); 84 | gridSize = dim3( 85 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 86 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 87 | p.launchMajor); 88 | } 89 | 90 | // Launch CUDA kernel.
91 | void* args[] = {&p}; 92 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 93 | return y; 94 | } 95 | 96 | //------------------------------------------------------------------------ 97 | 98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 99 | { 100 | m.def("upfirdn2d", &upfirdn2d); 101 | } 102 | 103 | //------------------------------------------------------------------------ 104 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include <cuda_runtime.h> 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling.
This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. 
A typical use case is to first unpickle a previous 83 | instance of a persistent class, and then upgrade it to use the latest 84 | version of the source code: 85 | 86 | with open('old_pickle.pkl', 'rb') as f: 87 | old_net = pickle.load(f) 88 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 90 | """ 91 | assert isinstance(orig_class, type) 92 | if is_persistent(orig_class): 93 | return orig_class 94 | 95 | assert orig_class.__module__ in sys.modules 96 | orig_module = sys.modules[orig_class.__module__] 97 | orig_module_src = _module_to_src(orig_module) 98 | 99 | class Decorator(orig_class): 100 | _orig_module_src = orig_module_src 101 | _orig_class_name = orig_class.__name__ 102 | 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | self._init_args = copy.deepcopy(args) 106 | self._init_kwargs = copy.deepcopy(kwargs) 107 | assert orig_class.__name__ in orig_module.__dict__ 108 | _check_pickleable(self.__reduce__()) 109 | 110 | @property 111 | def init_args(self): 112 | return copy.deepcopy(self._init_args) 113 | 114 | @property 115 | def init_kwargs(self): 116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 117 | 118 | def __reduce__(self): 119 | fields = list(super().__reduce__()) 120 | fields += [None] * max(3 - len(fields), 0) 121 | if fields[0] is not _reconstruct_persistent_obj: 122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 123 | fields[0] = _reconstruct_persistent_obj # reconstruct func 124 | fields[1] = (meta,) # reconstruct args 125 | fields[2] = None # state dict 126 | return tuple(fields) 127 | 128 | Decorator.__name__ = orig_class.__name__ 129 | _decorators.add(Decorator) 130 | return Decorator 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def is_persistent(obj): 135 | r"""Test whether the given object or class is persistent, i.e., 136 | whether it will save its source code when pickled. 137 | """ 138 | try: 139 | if obj in _decorators: 140 | return True 141 | except TypeError: 142 | pass 143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 144 | 145 | #---------------------------------------------------------------------------- 146 | 147 | def import_hook(hook): 148 | r"""Register an import hook that is called whenever a persistent object 149 | is being unpickled. A typical use case is to patch the pickled source 150 | code to avoid errors and inconsistencies when the API of some imported 151 | module has changed. 152 | 153 | The hook should have the following signature: 154 | 155 | hook(meta) -> modified meta 156 | 157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 158 | 159 | type: Type of the persistent object, e.g. `'class'`. 160 | version: Internal version number of `torch_utils.persistence`. 161 | module_src Original source code of the Python module. 162 | class_name: Class name in the original Python module. 163 | state: Internal state of the object. 164 | 165 | Example: 166 | 167 | @persistence.import_hook 168 | def wreck_my_network(meta): 169 | if meta.class_name == 'MyNetwork': 170 | print('MyNetwork is being imported. 
I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. 218 | """ 219 | module = _src_to_module_dict.get(src, None) 220 | if module is None: 221 | module_name = "_imported_module_" + uuid.uuid4().hex 222 | module = types.ModuleType(module_name) 223 | sys.modules[module_name] = module 224 | _module_to_src_dict[module] = src 225 | _src_to_module_dict[src] = module 226 | exec(src, module.__dict__) # pylint: disable=exec-used 227 | return module 228 | 229 | #---------------------------------------------------------------------------- 230 | 231 | def _check_pickleable(obj): 232 | r"""Check that the given object is pickleable, raising an exception if 233 | it is not. This function is expected to be considerably more efficient 234 | than actually pickling the object. 235 | """ 236 | def recurse(obj): 237 | if isinstance(obj, (list, tuple, set)): 238 | return [recurse(x) for x in obj] 239 | if isinstance(obj, dict): 240 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 242 | return None # Python primitive types are pickleable. 243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: 244 | return None # NumPy arrays and PyTorch tensors are pickleable. 245 | if is_persistent(obj): 246 | return None # Persistent objects are pickleable, by virtue of the constructor check. 247 | return obj 248 | with io.BytesIO() as f: 249 | pickle.dump(recurse(obj), f) 250 | 251 | #---------------------------------------------------------------------------- 252 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for reporting and collecting training statistics across 10 | multiple processes and devices. The interface is designed to minimize 11 | synchronization overhead as well as the amount of boilerplate in user 12 | code.""" 13 | 14 | import re 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | from . import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 25 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 26 | _rank = 0 # Rank of the current process. 27 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 28 | _sync_called = False # Has _sync() been called yet? 29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def init_multiprocessing(rank, sync_device): 35 | r"""Initializes `torch_utils.training_stats` for collecting statistics 36 | across multiple processes. 37 | 38 | This function must be called after 39 | `torch.distributed.init_process_group()` and before `Collector.update()`. 40 | The call is not necessary if multi-process collection is not needed. 41 | 42 | Args: 43 | rank: Rank of the current process. 44 | sync_device: PyTorch device to use for inter-process 45 | communication, or None to disable multi-process 46 | collection. Typically `torch.device('cuda', rank)`. 47 | """ 48 | global _rank, _sync_device 49 | assert not _sync_called 50 | _rank = rank 51 | _sync_device = sync_device 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | @misc.profiled_function 56 | def report(name, value): 57 | r"""Broadcasts the given set of scalars to all interested instances of 58 | `Collector`, across device and process boundaries. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. 75 | 76 | Returns: 77 | The same `value` that was passed in. 
78 | """ 79 | if name not in _counters: 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | moments = torch.stack([ 88 | torch.ones_like(elems).sum(), 89 | elems.sum(), 90 | elems.square().sum(), 91 | ]) 92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 93 | moments = moments.to(_counter_dtype) 94 | 95 | device = moments.device 96 | if device not in _counters[name]: 97 | _counters[name][device] = torch.zeros_like(moments) 98 | _counters[name][device].add_(moments) 99 | return value 100 | 101 | #---------------------------------------------------------------------------- 102 | 103 | def report0(name, value): 104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 105 | but ignores any scalars provided by the other processes. 106 | See `report()` for further details. 107 | """ 108 | report(name, value if _rank == 0 else []) 109 | return value 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | class Collector: 114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 115 | computes their long-term averages (mean and standard deviation) over 116 | user-defined periods of time. 117 | 118 | The averages are first collected into internal counters that are not 119 | directly visible to the user. They are then copied to the user-visible 120 | state as a result of calling `update()` and can then be queried using 121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 122 | internal counters for the next round, so that the user-visible state 123 | effectively reflects averages collected between the last two calls to 124 | `update()`. 125 | 126 | Args: 127 | regex: Regular expression defining which statistics to 128 | collect. The default is to collect everything. 129 | keep_previous: Whether to retain the previous averages if no 130 | scalars were collected on a given round 131 | (default: True). 132 | """ 133 | def __init__(self, regex='.*', keep_previous=True): 134 | self._regex = re.compile(regex) 135 | self._keep_previous = keep_previous 136 | self._cumulative = dict() 137 | self._moments = dict() 138 | self.update() 139 | self._moments.clear() 140 | 141 | def names(self): 142 | r"""Returns the names of all statistics broadcasted so far that 143 | match the regular expression specified at construction time. 144 | """ 145 | return [name for name in _counters if self._regex.fullmatch(name)] 146 | 147 | def update(self): 148 | r"""Copies current values of the internal counters to the 149 | user-visible state and resets them for the next round. 150 | 151 | If `keep_previous=True` was specified at construction time, the 152 | operation is skipped for statistics that have received no scalars 153 | since the last update, retaining their previous averages. 154 | 155 | This method performs a number of GPU-to-CPU transfers and one 156 | `torch.distributed.all_reduce()`. It is intended to be called 157 | periodically in the main training loop, typically once every 158 | N training steps. 
159 | """ 160 | if not self._keep_previous: 161 | self._moments.clear() 162 | for name, cumulative in _sync(self.names()): 163 | if name not in self._cumulative: 164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 165 | delta = cumulative - self._cumulative[name] 166 | self._cumulative[name].copy_(cumulative) 167 | if float(delta[0]) != 0: 168 | self._moments[name] = delta 169 | 170 | def _get_delta(self, name): 171 | r"""Returns the raw moments that were accumulated for the given 172 | statistic between the last two calls to `update()`, or zero if 173 | no scalars were collected. 174 | """ 175 | assert self._regex.fullmatch(name) 176 | if name not in self._moments: 177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 178 | return self._moments[name] 179 | 180 | def num(self, name): 181 | r"""Returns the number of scalars that were accumulated for the given 182 | statistic between the last two calls to `update()`, or zero if 183 | no scalars were collected. 184 | """ 185 | delta = self._get_delta(name) 186 | return int(delta[0]) 187 | 188 | def mean(self, name): 189 | r"""Returns the mean of the scalars that were accumulated for the 190 | given statistic between the last two calls to `update()`, or NaN if 191 | no scalars were collected. 192 | """ 193 | delta = self._get_delta(name) 194 | if int(delta[0]) == 0: 195 | return float('nan') 196 | return float(delta[1] / delta[0]) 197 | 198 | def std(self, name): 199 | r"""Returns the standard deviation of the scalars that were 200 | accumulated for the given statistic between the last two calls to 201 | `update()`, or NaN if no scalars were collected. 202 | """ 203 | delta = self._get_delta(name) 204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 205 | return float('nan') 206 | if int(delta[0]) == 1: 207 | return float(0) 208 | mean = float(delta[1] / delta[0]) 209 | raw_var = float(delta[2] / delta[0]) 210 | return np.sqrt(max(raw_var - np.square(mean), 0)) 211 | 212 | def as_dict(self): 213 | r"""Returns the averages accumulated between the last two calls to 214 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 215 | 216 | dnnlib.EasyDict( 217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 218 | ... 219 | ) 220 | """ 221 | stats = dnnlib.EasyDict() 222 | for name in self.names(): 223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 224 | return stats 225 | 226 | def __getitem__(self, name): 227 | r"""Convenience getter. 228 | `collector[name]` is a synonym for `collector.mean(name)`. 229 | """ 230 | return self.mean(name) 231 | 232 | #---------------------------------------------------------------------------- 233 | 234 | def _sync(names): 235 | r"""Synchronize the global cumulative counters across devices and 236 | processes. Called internally by `Collector.update()`. 237 | """ 238 | if len(names) == 0: 239 | return [] 240 | global _sync_called 241 | _sync_called = True 242 | 243 | # Collect deltas within current rank. 244 | deltas = [] 245 | device = _sync_device if _sync_device is not None else torch.device('cpu') 246 | for name in names: 247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 248 | for counter in _counters[name].values(): 249 | delta.add_(counter.to(device)) 250 | counter.copy_(torch.zeros_like(counter)) 251 | deltas.append(delta) 252 | deltas = torch.stack(deltas) 253 | 254 | # Sum deltas across ranks. 
255 | if _sync_device is not None: 256 | torch.distributed.all_reduce(deltas) 257 | 258 | # Update cumulative values. 259 | deltas = deltas.cpu() 260 | for idx, name in enumerate(names): 261 | if name not in _cumulative: 262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 263 | _cumulative[name].add_(deltas[idx]) 264 | 265 | # Return name-value pairs. 266 | return [(name, _cumulative[name]) for name in names] 267 | 268 | #---------------------------------------------------------------------------- 269 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | if [ $1 == 1 ]; then 3 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -W ignore train.py \ 4 | --outdir ./training-runs-fullbody \ 5 | --data /datazy/Datasets/UPT_512_320 \ 6 | --gpus 8 --cfg fashion \ 7 | --cond true --batch 24 --l1_weight 10 --seed 1 \ 8 | --vgg_weight 20 --use_noise_const_branch True \ 9 | --workers 4 --contextual_weight 0 --pl_weight 0 \ 10 | --mask_weight 30 11 | fi -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /training/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/training/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /training/__pycache__/dataset.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/training/__pycache__/dataset.cpython-37.pyc -------------------------------------------------------------------------------- /training/__pycache__/networks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/training/__pycache__/networks.cpython-37.pyc -------------------------------------------------------------------------------- /training/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiezhy6/PASTA-GAN-plusplus/6a4e4cb1d25bceb4c4fbf611a5111e2bc23962a0/training/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /training/utils.py: -------------------------------------------------------------------------------- 1 | # coding:utf-8 2 | 3 | import os 4 | import cv2 5 | import json 6 | import numpy as np 7 | import pycocotools.mask as maskUtils 8 | import math 9 | 10 | def get_mask_from_kps(kps, 
height, width): 11 | rles = maskUtils.frPyObjects(kps, height, width) 12 | rle = maskUtils.merge(rles) 13 | mask = maskUtils.decode(rle)[...,np.newaxis].astype(np.float32) 14 | mask = mask * 255.0 15 | return mask 16 | 17 | def get_rectangle_mask(a, b, c, d, height, width): 18 | x1 = a + (b-d)/4 19 | y1 = b + (c-a)/4 20 | x2 = a - (b-d)/4 21 | y2 = b - (c-a)/4 22 | 23 | x3 = c + (b-d)/4 24 | y3 = d + (c-a)/4 25 | x4 = c - (b-d)/4 26 | y4 = d - (c-a)/4 27 | kps = [x1,y1,x2,y2] 28 | 29 | v0_x = c-a 30 | v0_y = d-b 31 | v1_x = x3-x1 32 | v1_y = y3-y1 33 | v2_x = x4-x1 34 | v2_y = y4-y1 35 | 36 | cos1 = (v0_x*v1_x+v0_y*v1_y) / (math.sqrt(v0_x*v0_x+v0_y*v0_y)*math.sqrt(v1_x*v1_x+v1_y*v1_y)) 37 | cos2 = (v0_x*v2_x+v0_y*v2_y) / (math.sqrt(v0_x*v0_x+v0_y*v0_y)*math.sqrt(v2_x*v2_x+v2_y*v2_y)) 38 | 39 | if cos1<cos2: 40 | kps.extend([x3,y3,x4,y4]) 41 | else: 42 | kps.extend([x4,y4,x3,y3]) 43 | 44 | kps = np.array(kps).reshape(1,-1).tolist() 45 | mask = get_mask_from_kps(kps, height, width) 46 | 47 | return mask 48 | 49 | def get_hand_mask(hand_keypoints, height, width): 50 | # shoulder, elbow, wrist 51 | s_x,s_y,s_c = hand_keypoints[0] 52 | e_x,e_y,e_c = hand_keypoints[1] 53 | w_x,w_y,w_c = hand_keypoints[2] 54 | 55 | up_mask = np.ones((height,width,1),dtype=np.float32) 56 | bottom_mask = np.ones((height,width,1),dtype=np.float32) 57 | 58 | if s_c > 0.1 and e_c > 0.1: 59 | up_mask = get_rectangle_mask(s_x, s_y, e_x, e_y, height, width) 60 | # dilate the upper part to remove the gap between the two parts 61 | kernel = np.ones((20,20),np.uint8) 62 | up_mask = cv2.dilate(up_mask,kernel,iterations = 1) 63 | up_mask = (up_mask > 0).astype(np.float32)[...,np.newaxis] 64 | if e_c > 0.1 and w_c > 0.1: 65 | bottom_mask = get_rectangle_mask(e_x, e_y, w_x, w_y, height, width) 66 | bottom_mask = (bottom_mask > 0).astype(np.float32) 67 | 68 | return up_mask, bottom_mask 69 | 70 | 71 | def get_palm_mask(hand_mask, hand_up_mask, hand_bottom_mask): 72 | inter_up_mask = (hand_mask + hand_up_mask == 2).astype(np.float32) 73 | inter_bottom_mask = (hand_mask + hand_bottom_mask == 2).astype(np.float32) 74 | palm_mask = hand_mask - inter_up_mask - inter_bottom_mask 75 | 76 | return palm_mask -------------------------------------------------------------------------------- /util_classes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm 4 | 5 | 6 | class Normalize(nn.Module): 7 | def __init__(self, power=2): 8 | super(Normalize, self).__init__() 9 | self.power = power 10 | 11 | def forward(self, x): 12 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power) 13 | out = x.div(norm + 1e-7) 14 | return out 15 | 16 | 17 | def apply_offset(offset): 18 | ''' 19 | convert offset grid to location grid 20 | offset: [N, 2, H, W] for 2D or [N, 3, D, H, W] for 3D 21 | output: [N, H, W, 2] for 2D or [N, D, H, W, 3] for 3D 22 | ''' 23 | sizes = list(offset.size()[2:]) 24 | grid_list = torch.meshgrid([torch.arange(size, device=offset.device) for size in sizes]) 25 | grid_list = reversed(grid_list) 26 | # apply offset 27 | grid_list = [grid.float().unsqueeze(0) + offset[:, dim, ...]
28 | for dim, grid in enumerate(grid_list)] 29 | # normalize 30 | grid_list = [grid / ((size - 1.0) / 2.0) - 1.0 31 | for grid, size in zip(grid_list, reversed(sizes))] 32 | return torch.stack(grid_list, dim=-1) 33 | 34 | 35 | def spectral_norm(module, use_spect=True): 36 | """use spectral normal layer to stable the training process""" 37 | if use_spect: 38 | return SpectralNorm(module) 39 | else: 40 | return module 41 | 42 | 43 | class AddCoords(nn.Module): 44 | """ 45 | Add Coords to a tensor 46 | """ 47 | def __init__(self, with_r=False): 48 | super(AddCoords, self).__init__() 49 | self.with_r = with_r 50 | 51 | def forward(self, x): 52 | """ 53 | :param x: shape (batch, channel, x_dim, y_dim) 54 | :return: shape (batch, channel+2, x_dim, y_dim) 55 | """ 56 | B, _, x_dim, y_dim = x.size() 57 | 58 | # coord calculate 59 | xx_channel = torch.arange(x_dim).repeat(B, 1, y_dim, 1).type_as(x) 60 | yy_cahnnel = torch.arange(y_dim).repeat(B, 1, x_dim, 1).permute(0, 1, 3, 2).type_as(x) 61 | # normalization 62 | xx_channel = xx_channel.float() / (x_dim-1) 63 | yy_cahnnel = yy_cahnnel.float() / (y_dim-1) 64 | xx_channel = xx_channel * 2 - 1 65 | yy_cahnnel = yy_cahnnel * 2 - 1 66 | 67 | ret = torch.cat([x, xx_channel, yy_cahnnel], dim=1) 68 | 69 | if self.with_r: 70 | rr = torch.sqrt(xx_channel ** 2 + yy_cahnnel ** 2) 71 | ret = torch.cat([ret, rr], dim=1) 72 | 73 | return ret 74 | 75 | 76 | class CoordConv(nn.Module): 77 | """ 78 | CoordConv operation 79 | """ 80 | def __init__(self, input_nc, output_nc, with_r=False, use_spect=False, **kwargs): 81 | super(CoordConv, self).__init__() 82 | self.addcoords = AddCoords(with_r=with_r) 83 | input_nc = input_nc + 2 84 | if with_r: 85 | input_nc = input_nc + 1 86 | self.conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) 87 | 88 | def forward(self, x): 89 | ret = self.addcoords(x) 90 | ret = self.conv(ret) 91 | 92 | return ret 93 | 94 | 95 | def coord_conv(input_nc, output_nc, use_spect=False, use_coord=False, with_r=False, **kwargs): 96 | """use coord convolution layer to add position information""" 97 | if use_coord: 98 | return CoordConv(input_nc, output_nc, with_r, use_spect, **kwargs) 99 | else: 100 | return spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) 101 | 102 | 103 | class EncoderBlock(nn.Module): 104 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), 105 | use_spect=False, use_coord=False, downsample=True): 106 | super(EncoderBlock, self).__init__() 107 | 108 | kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1} 109 | kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1} 110 | 111 | if downsample: 112 | conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs_down) 113 | else: 114 | conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs_fine) 115 | conv2 = coord_conv(output_nc, output_nc, use_spect, use_coord, **kwargs_fine) 116 | 117 | if type(norm_layer) == type(None): 118 | self.model = nn.Sequential(nonlinearity, conv1, nonlinearity, conv2,) 119 | else: 120 | self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, conv1, 121 | norm_layer(output_nc), nonlinearity, conv2,) 122 | 123 | def forward(self, x): 124 | out = self.model(x) 125 | return out 126 | 127 | 128 | class ResBlockDecoder(nn.Module): 129 | """ 130 | Define a decoder block 131 | """ 132 | def __init__(self, input_nc, output_nc, hidden_nc=None, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(), 133 | use_spect=False, 
use_coord=False,upsample=True): 134 | super(ResBlockDecoder, self).__init__() 135 | 136 | self.upsample = upsample 137 | hidden_nc = input_nc if hidden_nc is None else hidden_nc 138 | conv1 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, kernel_size=3, stride=1, padding=1), use_spect) 139 | if upsample: 140 | conv2 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), use_spect) 141 | bypass = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), use_spect) 142 | self.shortcut = nn.Sequential(bypass) 143 | else: 144 | conv2 = spectral_norm(nn.Conv2d(hidden_nc, output_nc, kernel_size=3, stride=1, padding=1), use_spect) 145 | 146 | if type(norm_layer) == type(None): 147 | self.model = nn.Sequential(nonlinearity, conv1, nonlinearity, conv2,) 148 | else: 149 | self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, conv1, norm_layer(hidden_nc), nonlinearity, conv2,) 150 | 151 | 152 | def forward(self, x): 153 | if self.upsample: 154 | out = self.model(x) + self.shortcut(x) 155 | else: 156 | out = self.model(x)+x 157 | return out 158 | 159 | 160 | class Jump(nn.Module): 161 | """ 162 | Define the output layer 163 | """ 164 | def __init__(self, input_nc, output_nc, kernel_size = 3, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(), 165 | use_spect=False, use_coord=False): 166 | super(Jump, self).__init__() 167 | 168 | kwargs = {'kernel_size': kernel_size, 'padding':0, 'bias': True} 169 | 170 | self.conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs) 171 | 172 | if type(norm_layer) == type(None): 173 | self.model = nn.Sequential(nonlinearity, nn.ReflectionPad2d(int(kernel_size/2)), self.conv1) 174 | else: 175 | self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, nn.ReflectionPad2d(int(kernel_size / 2)), self.conv1) 176 | 177 | def forward(self, x): 178 | out = self.model(x) 179 | return out --------------------------------------------------------------------------------
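A minimal sketch of how apply_offset() from util_classes.py can be paired with torch.nn.functional.grid_sample() to warp a feature map. The pairing, the shapes, and the grid_sample arguments are illustrative assumptions, not code taken from this repository:

    import torch
    import torch.nn.functional as F
    from util_classes import apply_offset

    feat = torch.randn(2, 16, 32, 32)       # [N, C, H, W] feature map to warp
    offset = torch.zeros(2, 2, 32, 32)       # [N, 2, H, W] predicted offset field (zeros = identity warp)
    grid = apply_offset(offset)               # [N, H, W, 2] sampling grid in normalized [-1, 1] coordinates
    warped = F.grid_sample(feat, grid, mode='bilinear', padding_mode='border', align_corners=True)
    assert warped.shape == feat.shape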