├── .gitignore ├── .style.yapf ├── LICENSE ├── README.md ├── data ├── ILSVRC2012_val_00000525.JPEG ├── ILSVRC2012_val_00001172.JPEG ├── ILSVRC2012_val_00001970.JPEG ├── ILSVRC2012_val_00003004.JPEG ├── ILSVRC2012_val_00004291.JPEG ├── ILSVRC2012_val_00008229.JPEG ├── ILSVRC2012_val_00020814.JPEG ├── ILSVRC2012_val_00022130.JPEG ├── ILSVRC2012_val_00042095.JPEG ├── ILSVRC2012_val_00044065.JPEG ├── ILSVRC2012_val_00044640.JPEG ├── manipulation.gif ├── others │ ├── library1.jpg │ ├── library2.jpeg │ ├── list.txt │ ├── stones.jpg │ ├── window1.jpg │ ├── window2.jpg │ └── window3.jpeg ├── restoration.gif └── windows.png ├── dataset.py ├── example.py ├── experiments ├── examples │ ├── run_SR.sh │ ├── run_category_transfer.sh │ ├── run_colorization.sh │ ├── run_inpainting.sh │ ├── run_inpainting_list.sh │ ├── run_jitter.sh │ └── run_morphing.sh └── imagenet1k_128 │ ├── SR │ ├── D_biased │ │ ├── train.sh │ │ └── train_slurm.sh │ └── MSE_biased │ │ ├── train.sh │ │ └── train_slurm.sh │ ├── colorization │ ├── train.sh │ └── train_slurm.sh │ └── inpainting │ ├── train.sh │ └── train_slurm.sh ├── main.py ├── models ├── __init__.py ├── biggan.py ├── dgp.py ├── downsampler.py ├── layers.py └── nethook.py ├── pretrained └── .gitignore ├── requirements.txt ├── scripts └── imagenet_val_1k.txt ├── trainer.py └── utils ├── __init__.py ├── biggan_utils.py ├── common_utils.py ├── distributed_utils.py └── losses.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | __pycache__ 4 | *.tar 5 | checkpoint* 6 | images* 7 | logs* 8 | event* 9 | *.npy 10 | *.pth 11 | -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | BASED_ON_STYLE = pep8 3 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true 4 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true 5 | COLUMN_LIMIT = 100 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Xingang Pan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Deep Generative Prior (DGP) 2 | 3 | ### Paper 4 | 5 | Xingang Pan, Xiaohang Zhan, Bo Dai, Dahua Lin, Chen Change Loy, Ping Luo, "[Exploiting Deep Generative Prior for Versatile Image Restoration and Manipulation](https://arxiv.org/abs/2003.13659)", ECCV2020 (**Oral**) 6 | 7 | Video: https://youtu.be/p7ToqtwfVko
8 | 9 | ### Demos 10 | 11 | DGP exploits the image prior of an off-the-shelf GAN for various image restoration and manipulation. 12 | 13 | **Image restoration**: 14 | 15 |

<img src="data/restoration.gif">

18 | 19 | **Image manipulation**: 20 | 21 |

<img src="data/manipulation.gif">

24 | 25 | A **learned prior** helps **internal learning**: 26 | 27 |

<img src="data/windows.png">

30 | 31 | ### Requirements 32 | 33 | * python>=3.6 34 | * pytorch>=1.0.1 35 | * others 36 | 37 | ```sh 38 | pip install -r requirements.txt 39 | ``` 40 | 41 | ### Get Started 42 | 43 | Before you start, please download the pretrained BigGAN weights from [Google Drive](https://drive.google.com/drive/folders/1buQ2BtbnUhkh4PEPXOgdPuVo2iRK7gvI?usp=sharing) or [Baidu cloud](https://pan.baidu.com/s/10GKkWt7kSClvhnEGQU4ckA) (password: uqtw), and put them in the `pretrained` folder. 44 | 45 | Example 1: run the image colorization example: 46 | 47 | sh experiments/examples/run_colorization.sh 48 | 49 | The results will be saved in `experiments/examples/images` and `experiments/examples/images_sheet`. 50 | 51 | Example 2: process images specified in an image list: 52 | 53 | sh experiments/examples/run_inpainting_list.sh 54 | 55 | Example 3: evaluate on 1k ImageNet validation images via distributed training based on [slurm](https://slurm.schedmd.com/): 56 | 57 | # need to specify the root path of the ImageNet validation set in --root_dir 58 | sh experiments/imagenet1k_128/colorization/train_slurm.sh 59 | 60 | Note: 61 | \- BigGAN needs a class condition as input. If no class condition is provided, it will be chosen from a set of random samples. 62 | \- The hyperparameters provided may not be optimal; feel free to tune them. 63 | 64 | ### Acknowledgement 65 | 66 | The code of BigGAN is borrowed from [https://github.com/ajbrock/BigGAN-PyTorch](https://github.com/ajbrock/BigGAN-PyTorch). 67 | 68 | ### Citation 69 | 70 | ``` 71 | @inproceedings{pan2020dgp, 72 | author = {Pan, Xingang and Zhan, Xiaohang and Dai, Bo and Lin, Dahua and Loy, Chen Change and Luo, Ping}, 73 | title = {Exploiting Deep Generative Prior for Versatile Image Restoration and Manipulation}, 74 | booktitle = {European Conference on Computer Vision (ECCV)}, 75 | year = {2020} 76 | } 77 | 78 | @ARTICLE{pan2020dgp_pami, 79 | author={Pan, Xingang and Zhan, Xiaohang and Dai, Bo and Lin, Dahua and Loy, Chen Change and Luo, Ping}, 80 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 81 | title={Exploiting Deep Generative Prior for Versatile Image Restoration and Manipulation}, 82 | year={2021}, 83 | volume={}, 84 | number={}, 85 | pages={1-1}, 86 | doi={10.1109/TPAMI.2021.3115428} 87 | } 88 | ``` 89 | -------------------------------------------------------------------------------- /data/ILSVRC2012_val_00000525.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00000525.JPEG -------------------------------------------------------------------------------- /data/ILSVRC2012_val_00001172.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00001172.JPEG -------------------------------------------------------------------------------- /data/ILSVRC2012_val_00001970.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00001970.JPEG -------------------------------------------------------------------------------- /data/ILSVRC2012_val_00003004.JPEG: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00003004.JPEG -------------------------------------------------------------------------------- /data/ILSVRC2012_val_00004291.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00004291.JPEG -------------------------------------------------------------------------------- /data/ILSVRC2012_val_00008229.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00008229.JPEG -------------------------------------------------------------------------------- /data/ILSVRC2012_val_00020814.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00020814.JPEG -------------------------------------------------------------------------------- /data/ILSVRC2012_val_00022130.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00022130.JPEG -------------------------------------------------------------------------------- /data/ILSVRC2012_val_00042095.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00042095.JPEG -------------------------------------------------------------------------------- /data/ILSVRC2012_val_00044065.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00044065.JPEG -------------------------------------------------------------------------------- /data/ILSVRC2012_val_00044640.JPEG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00044640.JPEG -------------------------------------------------------------------------------- /data/manipulation.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/manipulation.gif -------------------------------------------------------------------------------- /data/others/library1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/library1.jpg -------------------------------------------------------------------------------- /data/others/library2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/library2.jpeg 
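A usage note for the Get Started examples above: the single-image scripts under `experiments/examples/` (e.g. `run_SR.sh`, `run_colorization.sh`) read an optional image path and class index from their positional arguments, falling back to the defaults hard-coded in each script when none are given. A minimal sketch, run from the repository root, using the image/class pair listed as a commented-out alternative inside `run_SR.sh`:

```sh
# Super-resolve a bundled ImageNet validation image of class 863.
# Pass -1 as the class index if the class is unknown; DGP then selects
# the class embedding from random candidates (see the Note above).
sh experiments/examples/run_SR.sh data/ILSVRC2012_val_00000525.JPEG 863
```

This mirrors how `example.py` handles an unknown class: it calls `select_z(select_y=True)` whenever `--class` is negative.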
-------------------------------------------------------------------------------- /data/others/list.txt: -------------------------------------------------------------------------------- 1 | window1.jpg 2 | window2.jpg 3 | window3.jpeg 4 | library1.jpg 5 | library2.jpeg 6 | stones.jpg 7 | -------------------------------------------------------------------------------- /data/others/stones.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/stones.jpg -------------------------------------------------------------------------------- /data/others/window1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/window1.jpg -------------------------------------------------------------------------------- /data/others/window2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/window2.jpg -------------------------------------------------------------------------------- /data/others/window3.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/window3.jpeg -------------------------------------------------------------------------------- /data/restoration.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/restoration.gif -------------------------------------------------------------------------------- /data/windows.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/windows.png -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import torchvision.transforms as transforms 3 | from PIL import Image 4 | 5 | import utils 6 | 7 | 8 | def pil_loader(path): 9 | # open path as file to avoid ResourceWarning 10 | # (https://github.com/python-pillow/Pillow/issues/835) 11 | with open(path, 'rb') as f: 12 | img = Image.open(f) 13 | return img.convert('RGB') 14 | 15 | 16 | def accimage_loader(path): 17 | import accimage 18 | try: 19 | return accimage.Image(path) 20 | except IOError: 21 | # Potentially a decoding problem, fall back to PIL.Image 22 | return pil_loader(path) 23 | 24 | 25 | def default_loader(path): 26 | from torchvision import get_image_backend 27 | if get_image_backend() == 'accimage': 28 | return accimage_loader(path) 29 | else: 30 | return pil_loader(path) 31 | 32 | 33 | class ImageDataset(data.Dataset): 34 | 35 | def __init__(self, 36 | root_dir, 37 | meta_file, 38 | transform=None, 39 | image_size=128, 40 | normalize=True): 41 | self.root_dir = root_dir 42 | if transform is not None: 43 | self.transform = transform 44 | else: 45 | norm_mean = [0.5, 0.5, 0.5] 46 | norm_std = [0.5, 0.5, 0.5] 47 | if normalize: 48 | self.transform = transforms.Compose([ 
49 | utils.CenterCropLongEdge(), 50 | transforms.Resize(image_size), 51 | transforms.ToTensor(), 52 | transforms.Normalize(norm_mean, norm_std) 53 | ]) 54 | else: 55 | self.transform = transforms.Compose([ 56 | utils.CenterCropLongEdge(), 57 | transforms.Resize(image_size), 58 | transforms.ToTensor() 59 | ]) 60 | with open(meta_file) as f: 61 | lines = f.readlines() 62 | print("building dataset from %s" % meta_file) 63 | self.num = len(lines) 64 | self.metas = [] 65 | self.classifier = None 66 | for line in lines: 67 | line_split = line.rstrip().split() 68 | if len(line_split) == 2: 69 | self.metas.append((line_split[0], int(line_split[1]))) 70 | else: 71 | self.metas.append((line_split[0], -1)) 72 | print("read meta done") 73 | 74 | def __len__(self): 75 | return self.num 76 | 77 | def __getitem__(self, idx): 78 | filename = self.root_dir + '/' + self.metas[idx][0] 79 | cls = self.metas[idx][1] 80 | img = default_loader(filename) 81 | 82 | # transform 83 | if self.transform is not None: 84 | img = self.transform(img) 85 | 86 | return img, cls, self.metas[idx][0] 87 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0' 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torchvision.utils as vutils 8 | 9 | import utils 10 | from models import DGP 11 | 12 | sys.path.append("./") 13 | 14 | 15 | # Arguments for demo 16 | def add_example_parser(parser): 17 | parser.add_argument( 18 | '--image_path', type=str, default='', 19 | help='Path of the image to be processed (default: %(default)s)') 20 | parser.add_argument( 21 | '--class', type=int, default=-1, 22 | help='class index of the image (default: %(default)s)') 23 | parser.add_argument( 24 | '--image_path2', type=str, default='', 25 | help='Path of the 2nd image to be processed, used in "morphing" mode (default: %(default)s)') 26 | parser.add_argument( 27 | '--class2', type=int, default=-1, 28 | help='class index of the 2nd image, used in "morphing" mode (default: %(default)s)') 29 | return parser 30 | 31 | 32 | # prepare arguments and save in config 33 | parser = utils.prepare_parser() 34 | parser = utils.add_dgp_parser(parser) 35 | parser = add_example_parser(parser) 36 | config = vars(parser.parse_args()) 37 | utils.dgp_update_config(config) 38 | 39 | # set random seed 40 | utils.seed_rng(config['seed']) 41 | 42 | if not os.path.exists('{}/images'.format(config['exp_path'])): 43 | os.makedirs('{}/images'.format(config['exp_path'])) 44 | if not os.path.exists('{}/images_sheet'.format(config['exp_path'])): 45 | os.makedirs('{}/images_sheet'.format(config['exp_path'])) 46 | 47 | # initialize DGP model 48 | dgp = DGP(config) 49 | 50 | # prepare the target image 51 | img = utils.get_img(config['image_path'], config['resolution']).cuda() 52 | category = torch.Tensor([config['class']]).long().cuda() 53 | dgp.set_target(img, category, config['image_path']) 54 | 55 | # prepare initial latent vector 56 | dgp.select_z(select_y=True if config['class'] < 0 else False) 57 | # start reconstruction 58 | loss_dict = dgp.run() 59 | 60 | if config['dgp_mode'] == 'category_transfer': 61 | save_imgs = img.clone().cpu() 62 | for i in range(151, 294): # dog & cat 63 | # for i in range(7, 25): # bird 64 | with torch.no_grad(): 65 | x = dgp.G(dgp.z, dgp.G.shared(dgp.y.fill_(i))) 66 | utils.save_img( 67 | x[0], 68 | '%s/images/%s_class%d.jpg' % 
(config['exp_path'], dgp.img_name, i)) 69 | save_imgs = torch.cat((save_imgs, x.cpu()), dim=0) 70 | vutils.save_image( 71 | save_imgs, 72 | '%s/images_sheet/%s_categories.jpg' % (config['exp_path'], dgp.img_name), 73 | nrow=int(save_imgs.size(0)**0.5), 74 | normalize=True) 75 | 76 | elif config['dgp_mode'] == 'morphing': 77 | dgp2 = DGP(config) 78 | dgp_interp = DGP(config) 79 | 80 | img2 = utils.get_img(config['image_path2'], config['resolution']).cuda() 81 | category2 = torch.Tensor([config['class2']]).long().cuda() 82 | 83 | dgp2.set_target(img2, category2, config['image_path2']) 84 | dgp2.select_z(select_y=True if config['class2'] < 0 else False) 85 | loss_dict = dgp2.run() 86 | 87 | weight1 = dgp.G.state_dict() 88 | weight2 = dgp2.G.state_dict() 89 | weight_interp = OrderedDict() 90 | save_imgs = [] 91 | with torch.no_grad(): 92 | for i in range(11): 93 | alpha = i / 10 94 | # interpolate between both latent vector and generator weight 95 | z_interp = alpha * dgp.z + (1 - alpha) * dgp2.z 96 | y_interp = alpha * dgp.G.shared(dgp.y) + (1 - alpha) * dgp2.G.shared(dgp2.y) 97 | for k, w1 in weight1.items(): 98 | w2 = weight2[k] 99 | weight_interp[k] = alpha * w1 + (1 - alpha) * w2 100 | dgp_interp.G.load_state_dict(weight_interp) 101 | x_interp = dgp_interp.G(z_interp, y_interp) 102 | save_imgs.append(x_interp.cpu()) 103 | # save images 104 | save_path = '%s/images/%s_%s' % (config['exp_path'], dgp.img_name, dgp2.img_name) 105 | if not os.path.exists(save_path): 106 | os.makedirs(save_path) 107 | utils.save_img(x_interp[0], '%s/%03d.jpg' % (save_path, i + 1)) 108 | save_imgs = torch.cat(save_imgs, 0) 109 | vutils.save_image( 110 | save_imgs, 111 | '%s/images_sheet/morphing_%s_%s.jpg' % (config['exp_path'], dgp.img_name, dgp2.img_name), 112 | nrow=int(save_imgs.size(0)**0.5), 113 | normalize=True) 114 | -------------------------------------------------------------------------------- /experiments/examples/run_SR.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | WORK_PATH=$(dirname $0) 4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00042095.JPEG} 5 | CLASS=${2:-260} 6 | #IMAGE_PATH=data/ILSVRC2012_val_00000525.JPEG 7 | #CLASS=863 8 | 9 | python -u -W ignore example.py \ 10 | --exp_path $WORK_PATH \ 11 | --image_path $IMAGE_PATH \ 12 | --class $CLASS \ 13 | --seed 0 \ 14 | --dgp_mode SR \ 15 | --update_G \ 16 | --ftr_num 8 8 8 8 8 \ 17 | --ft_num 2 3 4 5 7 \ 18 | --lr_ratio 1.0 1.0 1.0 1.0 1.0 \ 19 | --w_D_loss 1 1 1 1 1 \ 20 | --w_nll 0.02 \ 21 | --w_mse 1 1 1 1 1 \ 22 | --select_num 500 \ 23 | --sample_std 0.3 \ 24 | --iterations 200 200 200 200 200 \ 25 | --G_lrs 5e-5 5e-5 2e-5 1e-5 1e-5 \ 26 | --z_lrs 2e-3 1e-3 2e-5 1e-5 1e-5 \ 27 | --use_in False False False False False \ 28 | --resolution 256 \ 29 | --weights_root pretrained \ 30 | --load_weights 256 \ 31 | --G_ch 96 --D_ch 96 \ 32 | --G_shared \ 33 | --hier --dim_z 120 --shared_dim 128 \ 34 | --skip_init --use_ema 35 | -------------------------------------------------------------------------------- /experiments/examples/run_category_transfer.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | WORK_PATH=$(dirname $0) 4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00008229.JPEG} 5 | CLASS=${2:-174} 6 | 7 | python -u -W ignore example.py \ 8 | --exp_path $WORK_PATH \ 9 | --image_path $IMAGE_PATH \ 10 | --class $CLASS \ 11 | --seed 4 \ 12 | --dgp_mode category_transfer \ 13 | --update_G \ 14 | --ftr_num 8 8 8 \ 15 | --ft_num 7 7 
7 \ 16 | --lr_ratio 1 1 1 \ 17 | --w_D_loss 1 1 1 \ 18 | --w_nll 0.2 \ 19 | --w_mse 0 0 0 \ 20 | --select_num 500 \ 21 | --sample_std 0.5 \ 22 | --iterations 125 125 100 \ 23 | --G_lrs 2e-7 2e-5 2e-6 \ 24 | --z_lrs 1e-1 1e-2 2e-4 \ 25 | --use_in False False False \ 26 | --resolution 256 \ 27 | --weights_root pretrained \ 28 | --load_weights 256 \ 29 | --G_ch 96 --D_ch 96 \ 30 | --G_shared \ 31 | --hier --dim_z 120 --shared_dim 128 \ 32 | --skip_init --use_ema 33 | -------------------------------------------------------------------------------- /experiments/examples/run_colorization.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | WORK_PATH=$(dirname $0) 4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00003004.JPEG} 5 | CLASS=${2:-693} # set CLASS=-1 if you don't know the class 6 | #IMAGE_PATH=data/ILSVRC2012_val_00004291.JPEG 7 | #CLASS=442 8 | 9 | python -u -W ignore example.py \ 10 | --exp_path $WORK_PATH \ 11 | --image_path $IMAGE_PATH \ 12 | --class $CLASS \ 13 | --seed 0 \ 14 | --dgp_mode colorization \ 15 | --update_G \ 16 | --ftr_num 7 7 7 7 7 \ 17 | --ft_num 2 3 4 5 6 \ 18 | --lr_ratio 0.7 0.7 0.8 0.9 1.0 \ 19 | --w_D_loss 1 1 1 1 1 \ 20 | --w_nll 0.02 \ 21 | --w_mse 0 0 0 0 0 \ 22 | --select_num 500 \ 23 | --sample_std 0.5 \ 24 | --iterations 200 200 300 400 300 \ 25 | --G_lrs 5e-5 5e-5 5e-5 5e-5 2e-5 \ 26 | --z_lrs 2e-3 1e-3 5e-4 5e-5 2e-5 \ 27 | --use_in False False False False False \ 28 | --resolution 256 \ 29 | --weights_root pretrained \ 30 | --load_weights 256 \ 31 | --G_ch 96 --D_ch 96 \ 32 | --G_shared \ 33 | --hier --dim_z 120 --shared_dim 128 \ 34 | --skip_init --use_ema 35 | -------------------------------------------------------------------------------- /experiments/examples/run_inpainting.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | WORK_PATH=$(dirname $0) 4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00001970.JPEG} 5 | CLASS=${2:--1} # set CLASS=-1 if you don't know the class 6 | #IMAGE_PATH=data/ILSVRC2012_val_00022130.JPEG 7 | #CLASS=511 8 | 9 | python -u -W ignore example.py \ 10 | --exp_path $WORK_PATH \ 11 | --image_path $IMAGE_PATH \ 12 | --class $CLASS \ 13 | --seed 4 \ 14 | --dgp_mode inpainting \ 15 | --update_G \ 16 | --update_embed \ 17 | --ftr_num 8 8 8 8 8 \ 18 | --ft_num 7 7 7 7 7 \ 19 | --lr_ratio 1.0 1.0 1.0 1.0 1.0 \ 20 | --w_D_loss 1 1 1 1 0.5 \ 21 | --w_nll 0.02 \ 22 | --w_mse 1 1 1 1 10 \ 23 | --select_num 500 \ 24 | --sample_std 0.3 \ 25 | --iterations 200 200 200 200 200 \ 26 | --G_lrs 5e-5 5e-5 2e-5 2e-5 1e-5 \ 27 | --z_lrs 2e-3 1e-3 2e-5 2e-5 1e-5 \ 28 | --use_in False False False False False \ 29 | --resolution 256 \ 30 | --weights_root pretrained \ 31 | --load_weights 256 \ 32 | --G_ch 96 --D_ch 96 \ 33 | --G_shared \ 34 | --hier --dim_z 120 --shared_dim 128 \ 35 | --skip_init --use_ema 36 | -------------------------------------------------------------------------------- /experiments/examples/run_inpainting_list.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | WORK_PATH=$(dirname $0) 4 | 5 | python -u -W ignore main.py \ 6 | --exp_path $WORK_PATH \ 7 | --root_dir data/others \ 8 | --list_file data/others/list.txt \ 9 | --seed 2 \ 10 | --dgp_mode inpainting \ 11 | --update_G \ 12 | --update_embed \ 13 | --ftr_num 8 8 8 8 8 \ 14 | --ft_num 7 7 7 7 7 \ 15 | --lr_ratio 1.0 1.0 1.0 1.0 1.0 \ 16 | --w_D_loss 1 1 1 1 0.5 \ 17 | --w_nll 0.02 \ 18 | --w_mse 1 1 1 1 
10 \ 19 | --select_num 1000 \ 20 | --sample_std 0.3 \ 21 | --iterations 200 200 200 200 200 \ 22 | --G_lrs 5e-5 5e-5 2e-5 2e-5 1e-5 \ 23 | --z_lrs 2e-3 1e-3 2e-5 2e-5 1e-5 \ 24 | --use_in False False False False False \ 25 | --resolution 256 \ 26 | --weights_root pretrained \ 27 | --load_weights 256 \ 28 | --G_ch 96 --D_ch 96 \ 29 | --G_shared \ 30 | --hier --dim_z 120 --shared_dim 128 \ 31 | --skip_init --use_ema --no_tb 32 | -------------------------------------------------------------------------------- /experiments/examples/run_jitter.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | WORK_PATH=$(dirname $0) 4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00001172.JPEG} 5 | CLASS=${2:-269} 6 | 7 | python -u -W ignore example.py \ 8 | --exp_path $WORK_PATH \ 9 | --image_path $IMAGE_PATH \ 10 | --class $CLASS \ 11 | --seed 0 \ 12 | --dgp_mode jitter \ 13 | --update_G \ 14 | --ftr_num 8 8 8 \ 15 | --ft_num 8 8 8 \ 16 | --lr_ratio 1 1 1 \ 17 | --w_D_loss 1 1 1 \ 18 | --w_nll 0.2 \ 19 | --w_mse 0 0 0 \ 20 | --select_num 500 \ 21 | --sample_std 0.5 \ 22 | --iterations 125 125 100 \ 23 | --G_lrs 2e-7 2e-5 2e-6 \ 24 | --z_lrs 1e-1 1e-2 2e-4 \ 25 | --use_in False False False \ 26 | --resolution 256 \ 27 | --weights_root pretrained \ 28 | --load_weights 256 \ 29 | --G_ch 96 --D_ch 96 \ 30 | --G_shared \ 31 | --hier --dim_z 120 --shared_dim 128 \ 32 | --skip_init --use_ema 33 | -------------------------------------------------------------------------------- /experiments/examples/run_morphing.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | WORK_PATH=$(dirname $0) 4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00001172.JPEG} 5 | CLASS=${2:-269} 6 | IMAGE_PATH2=${3:-data/ILSVRC2012_val_00020814.JPEG} 7 | CLASS2=${4:-185} 8 | #IMAGE_PATH=data/ILSVRC2012_val_00044065.JPEG 9 | #CLASS=425 10 | #IMAGE_PATH2=data/ILSVRC2012_val_00044640.JPEG 11 | #CLASS2=425 12 | 13 | python -u -W ignore example.py \ 14 | --exp_path $WORK_PATH \ 15 | --seed 0 \ 16 | --image_path $IMAGE_PATH \ 17 | --class $CLASS \ 18 | --image_path2 $IMAGE_PATH2 \ 19 | --class2 $CLASS2 \ 20 | --dgp_mode morphing \ 21 | --update_G \ 22 | --update_embed \ 23 | --ftr_num 8 8 8 \ 24 | --ft_num 8 8 8 \ 25 | --lr_ratio 1 1 1 \ 26 | --w_D_loss 1 1 1 \ 27 | --w_nll 0.2 \ 28 | --w_mse 0 0 0 \ 29 | --select_num 500 \ 30 | --sample_std 0.5 \ 31 | --iterations 125 125 100 \ 32 | --G_lrs 2e-7 2e-5 2e-6 \ 33 | --z_lrs 1e-1 1e-2 2e-4 \ 34 | --use_in False False False \ 35 | --resolution 256 \ 36 | --weights_root pretrained \ 37 | --load_weights ch64_256 \ 38 | --G_ch 64 --D_ch 64 \ 39 | --G_shared \ 40 | --hier --dim_z 120 --shared_dim 128 \ 41 | --skip_init --use_ema 42 | -------------------------------------------------------------------------------- /experiments/imagenet1k_128/SR/D_biased/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | WORK_PATH=$(dirname $0) 4 | 5 | python -u -W ignore main.py \ 6 | --seed 0 \ 7 | --exp_path $WORK_PATH \ 8 | --root_dir /path_to_your_ImageNet/val \ 9 | --list_file scripts/imagenet_val_1k.txt \ 10 | --dgp_mode SR \ 11 | --update_G \ 12 | --ftr_num 8 8 8 8 8 \ 13 | --ft_num 2 3 4 5 6 \ 14 | --lr_ratio 1 1 1 1 1 \ 15 | --w_D_loss 1 1 1 1 1 \ 16 | --w_nll 0.02 \ 17 | --w_mse 1 1 1 1 1 \ 18 | --select_num 500 \ 19 | --sample_std 0.3 \ 20 | --iterations 200 200 200 200 200 \ 21 | --G_lrs 5e-5 5e-5 2e-5 1e-5 1e-5 \ 22 | --z_lrs 2e-3 1e-3 
2e-5 1e-5 1e-5 \ 23 | --use_in False False False False False \ 24 | --resolution 128 \ 25 | --weights_root pretrained \ 26 | --load_weights 128 \ 27 | --G_ch 96 --D_ch 96 \ 28 | --G_shared \ 29 | --hier --dim_z 120 --shared_dim 128 \ 30 | --skip_init --use_ema \ 31 | 2>&1 | tee $WORK_PATH/log.txt 32 | -------------------------------------------------------------------------------- /experiments/imagenet1k_128/SR/D_biased/train_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | PARTITION=$1 # input your partition on lustre 4 | WORK_PATH=$(dirname $0) 5 | 6 | # make sure that the image list length could be divided by 7 | # the number of threads, e.g., 1000 % 4 = 0 8 | srun -p $PARTITION -n4 --gres=gpu:4 --ntasks-per-node=4 \ 9 | python -u -W ignore main.py \ 10 | --dist \ 11 | --port 12348 \ 12 | --seed 0 \ 13 | --exp_path $WORK_PATH \ 14 | --root_dir /path_to_your_ImageNet/val \ 15 | --list_file scripts/imagenet_val_1k.txt \ 16 | --dgp_mode SR \ 17 | --update_G \ 18 | --ftr_num 8 8 8 8 8 \ 19 | --ft_num 2 3 4 5 6 \ 20 | --lr_ratio 1 1 1 1 1 \ 21 | --w_D_loss 1 1 1 1 1 \ 22 | --w_nll 0.02 \ 23 | --w_mse 1 1 1 1 1 \ 24 | --select_num 500 \ 25 | --sample_std 0.3 \ 26 | --iterations 200 200 200 200 200 \ 27 | --G_lrs 5e-5 5e-5 2e-5 1e-5 1e-5 \ 28 | --z_lrs 2e-3 1e-3 2e-5 1e-5 1e-5 \ 29 | --use_in False False False False False \ 30 | --resolution 128 \ 31 | --weights_root pretrained \ 32 | --load_weights 128 \ 33 | --G_ch 96 --D_ch 96 \ 34 | --G_shared \ 35 | --hier --dim_z 120 --shared_dim 128 \ 36 | --skip_init --use_ema \ 37 | 2>&1 | tee $WORK_PATH/log.txt 38 | -------------------------------------------------------------------------------- /experiments/imagenet1k_128/SR/MSE_biased/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | WORK_PATH=$(dirname $0) 4 | 5 | python -u -W ignore main.py \ 6 | --seed 0 \ 7 | --exp_path $WORK_PATH \ 8 | --root_dir /path_to_your_ImageNet/val \ 9 | --list_file scripts/imagenet_val_1k.txt \ 10 | --dgp_mode SR \ 11 | --update_G \ 12 | --ftr_num 8 8 8 8 8 \ 13 | --ft_num 2 3 4 5 6 \ 14 | --lr_ratio 1 1 1 1 1 \ 15 | --w_D_loss 1 1 1 0.5 0.1 \ 16 | --w_nll 0.02 \ 17 | --w_mse 10 10 10 50 100 \ 18 | --select_num 500 \ 19 | --sample_std 0.3 \ 20 | --iterations 200 200 200 200 200 \ 21 | --G_lrs 2e-4 2e-4 1e-4 1e-4 1e-5 \ 22 | --z_lrs 1e-3 1e-3 1e-4 1e-4 1e-5 \ 23 | --use_in True True True True True \ 24 | --resolution 128 \ 25 | --weights_root pretrained \ 26 | --load_weights 128 \ 27 | --G_ch 96 --D_ch 96 \ 28 | --G_shared \ 29 | --hier --dim_z 120 --shared_dim 128 \ 30 | --skip_init --use_ema \ 31 | 2>&1 | tee $WORK_PATH/log.txt 32 | -------------------------------------------------------------------------------- /experiments/imagenet1k_128/SR/MSE_biased/train_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | PARTITION=$1 # input your partition on lustre 4 | WORK_PATH=$(dirname $0) 5 | 6 | # make sure that the image list length could be divided by 7 | # the number of threads, e.g., 1000 % 4 = 0 8 | srun -p $PARTITION -n4 --gres=gpu:4 --ntasks-per-node=4 \ 9 | python -u -W ignore main.py \ 10 | --dist \ 11 | --port 12347 \ 12 | --seed 0 \ 13 | --exp_path $WORK_PATH \ 14 | --root_dir /path_to_your_ImageNet/val \ 15 | --list_file scripts/imagenet_val_1k.txt \ 16 | --dgp_mode SR \ 17 | --update_G \ 18 | --ftr_num 8 8 8 8 8 \ 19 | --ft_num 2 3 4 5 6 \ 20 | 
--lr_ratio 1 1 1 1 1 \ 21 | --w_D_loss 1 1 1 0.5 0.1 \ 22 | --w_nll 0.02 \ 23 | --w_mse 10 10 10 50 100 \ 24 | --select_num 500 \ 25 | --sample_std 0.3 \ 26 | --iterations 200 200 200 200 200 \ 27 | --G_lrs 2e-4 2e-4 1e-4 1e-4 1e-5 \ 28 | --z_lrs 1e-3 1e-3 1e-4 1e-4 1e-5 \ 29 | --use_in True True True True True \ 30 | --resolution 128 \ 31 | --weights_root pretrained \ 32 | --load_weights 128 \ 33 | --G_ch 96 --D_ch 96 \ 34 | --G_shared \ 35 | --hier --dim_z 120 --shared_dim 128 \ 36 | --skip_init --use_ema \ 37 | 2>&1 | tee $WORK_PATH/log.txt 38 | -------------------------------------------------------------------------------- /experiments/imagenet1k_128/colorization/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | WORK_PATH=$(dirname $0) 4 | 5 | python -u -W ignore main.py \ 6 | --seed 0 \ 7 | --exp_path $WORK_PATH \ 8 | --root_dir /path_to_your_ImageNet/val \ 9 | --list_file scripts/imagenet_val_1k.txt \ 10 | --dgp_mode colorization \ 11 | --update_G \ 12 | --ftr_num 7 7 7 7 7 \ 13 | --ft_num 2 3 4 5 6 \ 14 | --lr_ratio 0.7 0.7 0.8 0.9 1 \ 15 | --w_D_loss 1 1 1 1 1 \ 16 | --w_nll 0.02 \ 17 | --w_mse 0 0 0 0 0 \ 18 | --select_num 500 \ 19 | --sample_std 0.5 \ 20 | --iterations 200 200 300 400 300 \ 21 | --G_lrs 5e-5 5e-5 5e-5 5e-5 2e-5 \ 22 | --z_lrs 2e-3 1e-3 5e-4 5e-5 2e-5 \ 23 | --use_in False False False False False \ 24 | --resolution 128 \ 25 | --weights_root pretrained \ 26 | --load_weights 128 \ 27 | --G_ch 96 --D_ch 96 \ 28 | --G_shared \ 29 | --hier --dim_z 120 --shared_dim 128 \ 30 | --skip_init --use_ema \ 31 | 2>&1 | tee $WORK_PATH/log.txt 32 | -------------------------------------------------------------------------------- /experiments/imagenet1k_128/colorization/train_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | PARTITION=$1 # input your partition on lustre 4 | WORK_PATH=$(dirname $0) 5 | 6 | # make sure that the image list length could be divided by 7 | # the number of threads, e.g., 1000 % 4 = 0 8 | srun -p $PARTITION -n4 --gres=gpu:4 --ntasks-per-node=4 \ 9 | python -u -W ignore main.py \ 10 | --dist \ 11 | --port 12345 \ 12 | --seed 0 \ 13 | --exp_path $WORK_PATH \ 14 | --root_dir /path_to_your_ImageNet/val \ 15 | --list_file scripts/imagenet_val_1k.txt \ 16 | --dgp_mode colorization \ 17 | --update_G \ 18 | --ftr_num 7 7 7 7 7 \ 19 | --ft_num 2 3 4 5 6 \ 20 | --lr_ratio 0.7 0.7 0.8 0.9 1 \ 21 | --w_D_loss 1 1 1 1 1 \ 22 | --w_nll 0.02 \ 23 | --w_mse 0 0 0 0 0 \ 24 | --select_num 500 \ 25 | --sample_std 0.5 \ 26 | --iterations 200 200 300 400 300 \ 27 | --G_lrs 5e-5 5e-5 5e-5 5e-5 2e-5 \ 28 | --z_lrs 2e-3 1e-3 5e-4 5e-5 2e-5 \ 29 | --use_in False False False False False \ 30 | --resolution 128 \ 31 | --weights_root pretrained \ 32 | --load_weights 128 \ 33 | --G_ch 96 --D_ch 96 \ 34 | --G_shared \ 35 | --hier --dim_z 120 --shared_dim 128 \ 36 | --skip_init --use_ema \ 37 | 2>&1 | tee $WORK_PATH/log.txt 38 | -------------------------------------------------------------------------------- /experiments/imagenet1k_128/inpainting/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | WORK_PATH=$(dirname $0) 4 | 5 | python -u -W ignore main.py \ 6 | --seed 0 \ 7 | --exp_path $WORK_PATH \ 8 | --root_dir /path_to_your_ImageNet/val \ 9 | --list_file scripts/imagenet_val_1k.txt \ 10 | --dgp_mode inpainting \ 11 | --update_G \ 12 | --update_embed \ 13 | --ftr_num 8 8 8 
8 8 \ 14 | --ft_num 6 6 6 6 6 \ 15 | --lr_ratio 1 1 1 1 1 \ 16 | --w_D_loss 1 1 1 0.1 0.1 \ 17 | --w_nll 0.02 \ 18 | --w_mse 10 10 10 100 100 \ 19 | --select_num 500 \ 20 | --sample_std 0.3 \ 21 | --iterations 200 200 200 200 200 \ 22 | --G_lrs 2e-4 2e-4 1e-4 1e-4 1e-5 \ 23 | --z_lrs 1e-3 1e-3 1e-4 1e-4 1e-5 \ 24 | --use_in True True True True True \ 25 | --resolution 128 \ 26 | --weights_root pretrained \ 27 | --load_weights 128 \ 28 | --G_ch 96 --D_ch 96 \ 29 | --G_shared \ 30 | --hier --dim_z 120 --shared_dim 128 \ 31 | --skip_init --use_ema \ 32 | 2>&1 | tee $WORK_PATH/log.txt 33 | -------------------------------------------------------------------------------- /experiments/imagenet1k_128/inpainting/train_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env sh 2 | 3 | PARTITION=$1 # input your partition on lustre 4 | WORK_PATH=$(dirname $0) 5 | 6 | # make sure that the image list length could be divided by 7 | # the number of threads, e.g., 1000 % 4 = 0 8 | srun -p $PARTITION -n4 --gres=gpu:4 --ntasks-per-node=4 \ 9 | python -u -W ignore main.py \ 10 | --dist \ 11 | --port 12346 \ 12 | --seed 0 \ 13 | --exp_path $WORK_PATH \ 14 | --root_dir /path_to_your_ImageNet/val \ 15 | --list_file scripts/imagenet_val_1k.txt \ 16 | --dgp_mode inpainting \ 17 | --update_G \ 18 | --update_embed \ 19 | --ftr_num 8 8 8 8 8 \ 20 | --ft_num 6 6 6 6 6 \ 21 | --lr_ratio 1 1 1 1 1 \ 22 | --w_D_loss 1 1 1 0.1 0.1 \ 23 | --w_nll 0.02 \ 24 | --w_mse 10 10 10 100 100 \ 25 | --select_num 500 \ 26 | --sample_std 0.3 \ 27 | --iterations 200 200 200 200 200 \ 28 | --G_lrs 2e-4 2e-4 1e-4 1e-4 1e-5 \ 29 | --z_lrs 1e-3 1e-3 1e-4 1e-4 1e-5 \ 30 | --use_in True True True True True \ 31 | --resolution 128 \ 32 | --weights_root pretrained \ 33 | --load_weights 128 \ 34 | --G_ch 96 --D_ch 96 \ 35 | --G_shared \ 36 | --hier --dim_z 120 --shared_dim 128 \ 37 | --skip_init --use_ema \ 38 | 2>&1 | tee $WORK_PATH/log.txt 39 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.multiprocessing as mp 3 | 4 | import utils 5 | from trainer import Trainer 6 | from utils import dist_init 7 | 8 | 9 | def main(): 10 | parser = utils.prepare_parser() 11 | parser = utils.add_dgp_parser(parser) 12 | config = vars(parser.parse_args()) 13 | utils.dgp_update_config(config) 14 | print(config) 15 | 16 | rank = 0 17 | if mp.get_start_method(allow_none=True) != 'spawn': 18 | mp.set_start_method('spawn', force=True) 19 | if config['dist']: 20 | rank, world_size = dist_init(config['port']) 21 | 22 | # Seed RNG 23 | utils.seed_rng(rank + config['seed']) 24 | 25 | # Setup cudnn.benchmark for free speed 26 | torch.backends.cudnn.benchmark = True 27 | 28 | # train 29 | trainer = Trainer(config) 30 | trainer.run() 31 | 32 | 33 | if __name__ == '__main__': 34 | main() 35 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .biggan import * 2 | from .dgp import * 3 | from .nethook import * 4 | -------------------------------------------------------------------------------- /models/biggan.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | 
from torch.nn import init 8 | 9 | from . import layers 10 | 11 | 12 | # Architectures for G 13 | # Attention is passed in in the format '32_64' to mean applying an attention 14 | # block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64. 15 | def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'): 16 | arch = {} 17 | arch[512] = { 18 | 'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]], 19 | 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]], 20 | 'upsample': [True] * 7, 21 | 'resolution': [8, 16, 32, 64, 128, 256, 512], 22 | 'attention': 23 | {2**i: (2**i in [int(item) for item in attention.split('_')]) 24 | for i in range(3, 10)} 25 | } 26 | arch[256] = { 27 | 'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]], 28 | 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]], 29 | 'upsample': [True] * 6, 30 | 'resolution': [8, 16, 32, 64, 128, 256], 31 | 'attention': 32 | {2**i: (2**i in [int(item) for item in attention.split('_')]) 33 | for i in range(3, 9)} 34 | } 35 | arch[128] = { 36 | 'in_channels': [ch * item for item in [16, 16, 8, 4, 2]], 37 | 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]], 38 | 'upsample': [True] * 5, 39 | 'resolution': [8, 16, 32, 64, 128], 40 | 'attention': 41 | {2**i: (2**i in [int(item) for item in attention.split('_')]) 42 | for i in range(3, 8)} 43 | } 44 | arch[64] = { 45 | 'in_channels': [ch * item for item in [16, 16, 8, 4]], 46 | 'out_channels': [ch * item for item in [16, 8, 4, 2]], 47 | 'upsample': [True] * 4, 48 | 'resolution': [8, 16, 32, 64], 49 | 'attention': 50 | {2**i: (2**i in [int(item) for item in attention.split('_')]) 51 | for i in range(3, 7)} 52 | } 53 | arch[32] = { 54 | 'in_channels': [ch * item for item in [4, 4, 4]], 55 | 'out_channels': [ch * item for item in [4, 4, 4]], 56 | 'upsample': [True] * 3, 57 | 'resolution': [8, 16, 32], 58 | 'attention': 59 | {2**i: (2**i in [int(item) for item in attention.split('_')]) 60 | for i in range(3, 6)} 61 | } 62 | 63 | return arch 64 | 65 | 66 | class Generator(nn.Module): 67 | 68 | def __init__(self, 69 | G_ch=64, 70 | dim_z=128, 71 | bottom_width=4, 72 | resolution=128, 73 | G_kernel_size=3, 74 | G_attn='64', 75 | n_classes=1000, 76 | num_G_SVs=1, 77 | num_G_SV_itrs=1, 78 | G_shared=True, 79 | shared_dim=0, 80 | hier=False, 81 | cross_replica=False, 82 | mybn=False, 83 | G_activation=nn.ReLU(inplace=True), 84 | optimizer='Adam', 85 | G_lr=5e-5, 86 | G_B1=0.0, 87 | G_B2=0.999, 88 | adam_eps=1e-8, 89 | BN_eps=1e-5, 90 | SN_eps=1e-12, 91 | G_mixed_precision=False, 92 | G_fp16=False, 93 | G_init='ortho', 94 | skip_init=False, 95 | no_optim=False, 96 | G_param='SN', 97 | norm_style='bn', 98 | **kwargs): 99 | super(Generator, self).__init__() 100 | # Channel width mulitplier 101 | self.ch = G_ch 102 | # Dimensionality of the latent space 103 | self.dim_z = dim_z 104 | # The initial spatial dimensions 105 | self.bottom_width = bottom_width 106 | # Resolution of the output 107 | self.resolution = resolution 108 | # Kernel size? 109 | self.kernel_size = G_kernel_size 110 | # Attention? 111 | self.attention = G_attn 112 | # number of classes, for use in categorical conditional generation 113 | self.n_classes = n_classes 114 | # Use shared embeddings? 115 | self.G_shared = G_shared 116 | # Dimensionality of the shared embedding? Unused if not using G_shared 117 | self.shared_dim = shared_dim if shared_dim > 0 else dim_z 118 | # Hierarchical latent space? 119 | self.hier = hier 120 | # Cross replica batchnorm? 
121 | self.cross_replica = cross_replica 122 | # Use my batchnorm? 123 | self.mybn = mybn 124 | # nonlinearity for residual blocks 125 | self.activation = G_activation 126 | # Initialization style 127 | self.init = G_init 128 | # Parameterization style 129 | self.G_param = G_param 130 | # Normalization style 131 | self.norm_style = norm_style 132 | # Epsilon for BatchNorm? 133 | self.BN_eps = BN_eps 134 | # Epsilon for Spectral Norm? 135 | self.SN_eps = SN_eps 136 | # fp16? 137 | self.fp16 = G_fp16 138 | # Architecture dict 139 | self.arch = G_arch(self.ch, self.attention)[resolution] 140 | 141 | # If using hierarchical latents, adjust z 142 | if self.hier: 143 | # Number of places z slots into 144 | self.num_slots = len(self.arch['in_channels']) + 1 145 | self.z_chunk_size = (self.dim_z // self.num_slots) 146 | # Recalculate latent dimensionality for even splitting into chunks 147 | self.dim_z = self.z_chunk_size * self.num_slots 148 | else: 149 | self.num_slots = 1 150 | self.z_chunk_size = 0 151 | 152 | # Which convs, batchnorms, and linear layers to use 153 | if self.G_param == 'SN': 154 | self.which_conv = functools.partial( 155 | layers.SNConv2d, 156 | kernel_size=3, 157 | padding=1, 158 | num_svs=num_G_SVs, 159 | num_itrs=num_G_SV_itrs, 160 | eps=self.SN_eps) 161 | self.which_linear = functools.partial( 162 | layers.SNLinear, num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, eps=self.SN_eps) 163 | else: 164 | self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) 165 | self.which_linear = nn.Linear 166 | 167 | # We use a non-spectral-normed embedding here regardless; 168 | # For some reason applying SN to G's embedding seems to randomly cripple G 169 | self.which_embedding = nn.Embedding 170 | bn_linear = ( 171 | functools.partial(self.which_linear, bias=False) 172 | if self.G_shared else self.which_embedding) 173 | self.which_bn = functools.partial( 174 | layers.ccbn, 175 | which_linear=bn_linear, 176 | cross_replica=self.cross_replica, 177 | mybn=self.mybn, 178 | input_size=(self.shared_dim + self.z_chunk_size if self.G_shared else self.n_classes), 179 | norm_style=self.norm_style, 180 | eps=self.BN_eps) 181 | 182 | # Prepare model 183 | # If not using shared embeddings, self.shared is just a passthrough 184 | self.shared = ( 185 | self.which_embedding(n_classes, self.shared_dim) if G_shared else layers.identity()) 186 | 187 | # First linear layer 188 | self.linear = self.which_linear(self.dim_z // self.num_slots, 189 | self.arch['in_channels'][0] * (self.bottom_width**2)) 190 | 191 | # self.blocks is a doubly-nested list of modules, the outer loop intended 192 | # to be over blocks at a given resolution (resblocks and/or self-attention) 193 | # while the inner loop is over a given block 194 | self.blocks = [] 195 | for index in range(len(self.arch['out_channels'])): 196 | self.blocks += [[ 197 | layers.GBlock( 198 | in_channels=self.arch['in_channels'][index], 199 | out_channels=self.arch['out_channels'][index], 200 | which_conv=self.which_conv, 201 | which_bn=self.which_bn, 202 | activation=self.activation, 203 | upsample=(functools.partial(F.interpolate, scale_factor=2) 204 | if self.arch['upsample'][index] else None)) 205 | ]] 206 | 207 | # If attention on this block, attach it to the end 208 | if self.arch['attention'][self.arch['resolution'][index]]: 209 | print('Adding attention layer in G at resolution %d' % 210 | self.arch['resolution'][index]) 211 | self.blocks[-1] += [ 212 | layers.Attention(self.arch['out_channels'][index], self.which_conv) 213 | ] 214 | 215 
| # Turn self.blocks into a ModuleList so that it's all properly registered. 216 | self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) 217 | 218 | # output layer: batchnorm-relu-conv. 219 | # Consider using a non-spectral conv here 220 | self.output_layer = nn.Sequential( 221 | layers.bn( 222 | self.arch['out_channels'][-1], cross_replica=self.cross_replica, mybn=self.mybn), 223 | self.activation, self.which_conv(self.arch['out_channels'][-1], 3)) 224 | 225 | # Initialize weights. Optionally skip init for testing. 226 | if not skip_init: 227 | self.init_weights() 228 | 229 | # Set up optimizer 230 | # If this is an EMA copy, no need for an optim, so just return now 231 | if no_optim: 232 | return 233 | self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps 234 | if G_mixed_precision: 235 | print('Using fp16 adam in G...') 236 | import utils 237 | self.optim = utils.Adam16( 238 | params=self.parameters(), 239 | lr=self.lr, 240 | betas=(self.B1, self.B2), 241 | weight_decay=0, 242 | eps=self.adam_eps) 243 | else: 244 | if optimizer == 'Adam': 245 | self.optim = optim.Adam( 246 | params=self.parameters(), 247 | lr=self.lr, 248 | betas=(self.B1, self.B2), 249 | weight_decay=0, 250 | eps=self.adam_eps) 251 | elif optimizer == 'SGD': 252 | self.optim = optim.SGD( 253 | params=self.parameters(), lr=self.lr, momentum=0.9, weight_decay=0) 254 | else: 255 | raise ValueError("optim has to be Adam or SGD, " "but got {}".format(optimizer)) 256 | 257 | # LR scheduling, left here for forward compatibility 258 | # self.lr_sched = {'itr' : 0}# if self.progressive else {} 259 | # self.j = 0 260 | 261 | # Initialize 262 | def init_weights(self): 263 | self.param_count = 0 264 | for module in self.modules(): 265 | if (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) 266 | or isinstance(module, nn.Embedding)): 267 | if self.init == 'ortho': 268 | init.orthogonal_(module.weight) 269 | elif self.init == 'N02': 270 | init.normal_(module.weight, 0, 0.02) 271 | elif self.init in ['glorot', 'xavier']: 272 | init.xavier_uniform_(module.weight) 273 | else: 274 | print('Init style not recognized...') 275 | self.param_count += sum([p.data.nelement() for p in module.parameters()]) 276 | print('Param count for G' 's initialized parameters: %d' % self.param_count) 277 | 278 | def reset_in_init(self): 279 | for index, blocklist in enumerate(self.blocks): 280 | for block in blocklist: 281 | if isinstance(block, layers.GBlock): 282 | block.reset_in_init() 283 | 284 | def get_params(self, index=0, update_embed=False): 285 | if index == 0: 286 | for param in self.linear.parameters(): 287 | yield param 288 | if update_embed: 289 | for param in self.shared.parameters(): 290 | yield param 291 | elif index < len(self.blocks) + 1: 292 | for param in self.blocks[index - 1].parameters(): 293 | yield param 294 | elif index == len(self.blocks) + 1: 295 | for param in self.output_layer.parameters(): 296 | yield param 297 | else: 298 | raise ValueError('Index out of range') 299 | 300 | # Note on this forward function: we pass in a y vector which has 301 | # already been passed through G.shared to enable easy class-wise 302 | # interpolation later. If we passed in the one-hot and then ran it through 303 | # G.shared in this forward function, it would be harder to handle. 
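    # A brief illustration of why this matters, mirroring example.py's "morphing"
    # mode (y1/y2 below are hypothetical class-index tensors for two images):
    #   y_interp = alpha * G.shared(y1) + (1 - alpha) * G.shared(y2)
    #   x_interp = G(z_interp, y_interp)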
304 | def forward(self, z, y, use_in=False): 305 | # If hierarchical, concatenate zs and ys 306 | if self.hier: 307 | zs = torch.split(z, self.z_chunk_size, 1) 308 | z = zs[0] 309 | ys = [torch.cat([y, item], 1) for item in zs[1:]] 310 | else: 311 | ys = [y] * len(self.blocks) 312 | 313 | # First linear layer 314 | h = self.linear(z) 315 | # Reshape 316 | h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width) 317 | 318 | # Loop over blocks 319 | for index, blocklist in enumerate(self.blocks): 320 | # Second inner loop in case block has multiple layers 321 | for block in blocklist: 322 | h = block(h, ys[index], use_in) 323 | 324 | # Apply batchnorm-relu-conv-tanh at output 325 | return torch.tanh(self.output_layer(h)) 326 | 327 | 328 | # Discriminator architecture, same paradigm as G's above 329 | def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'): 330 | arch = {} 331 | arch[256] = { 332 | 'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 8, 16]], 333 | 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]], 334 | 'downsample': [True] * 6 + [False], 335 | 'resolution': [128, 64, 32, 16, 8, 4, 4], 336 | 'attention': 337 | {2**i: 2**i in [int(item) for item in attention.split('_')] 338 | for i in range(2, 8)} 339 | } 340 | arch[128] = { 341 | 'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 16]], 342 | 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]], 343 | 'downsample': [True] * 5 + [False], 344 | 'resolution': [64, 32, 16, 8, 4, 4], 345 | 'attention': 346 | {2**i: 2**i in [int(item) for item in attention.split('_')] 347 | for i in range(2, 8)} 348 | } 349 | arch[64] = { 350 | 'in_channels': [3] + [ch * item for item in [1, 2, 4, 8]], 351 | 'out_channels': [item * ch for item in [1, 2, 4, 8, 16]], 352 | 'downsample': [True] * 4 + [False], 353 | 'resolution': [32, 16, 8, 4, 4], 354 | 'attention': 355 | {2**i: 2**i in [int(item) for item in attention.split('_')] 356 | for i in range(2, 7)} 357 | } 358 | arch[32] = { 359 | 'in_channels': [3] + [item * ch for item in [4, 4, 4]], 360 | 'out_channels': [item * ch for item in [4, 4, 4, 4]], 361 | 'downsample': [True, True, False, False], 362 | 'resolution': [16, 16, 16, 16], 363 | 'attention': 364 | {2**i: 2**i in [int(item) for item in attention.split('_')] 365 | for i in range(2, 6)} 366 | } 367 | return arch 368 | 369 | 370 | class Discriminator(nn.Module): 371 | 372 | def __init__(self, 373 | D_ch=64, 374 | D_wide=True, 375 | resolution=128, 376 | D_kernel_size=3, 377 | D_attn='64', 378 | n_classes=1000, 379 | num_D_SVs=1, 380 | num_D_SV_itrs=1, 381 | D_activation=nn.ReLU(inplace=False), 382 | D_lr=2e-4, 383 | D_B1=0.0, 384 | D_B2=0.999, 385 | adam_eps=1e-8, 386 | SN_eps=1e-12, 387 | output_dim=1, 388 | D_mixed_precision=False, 389 | D_fp16=False, 390 | D_init='ortho', 391 | skip_init=False, 392 | D_param='SN', 393 | **kwargs): 394 | super(Discriminator, self).__init__() 395 | # Width multiplier 396 | self.ch = D_ch 397 | # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN? 398 | self.D_wide = D_wide 399 | # Resolution 400 | self.resolution = resolution 401 | # Kernel size 402 | self.kernel_size = D_kernel_size 403 | # Attention? 404 | self.attention = D_attn 405 | # Number of classes 406 | self.n_classes = n_classes 407 | # Activation 408 | self.activation = D_activation 409 | # Initialization style 410 | self.init = D_init 411 | # Parameterization style 412 | self.D_param = D_param 413 | # Epsilon for Spectral Norm? 414 | self.SN_eps = SN_eps 415 | # Fp16? 
416 | self.fp16 = D_fp16 417 | # Architecture 418 | self.arch = D_arch(self.ch, self.attention)[resolution] 419 | 420 | # Which convs, batchnorms, and linear layers to use 421 | # No option to turn off SN in D right now 422 | if self.D_param == 'SN': 423 | self.which_conv = functools.partial( 424 | layers.SNConv2d, 425 | kernel_size=3, 426 | padding=1, 427 | num_svs=num_D_SVs, 428 | num_itrs=num_D_SV_itrs, 429 | eps=self.SN_eps) 430 | self.which_linear = functools.partial( 431 | layers.SNLinear, num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps) 432 | self.which_embedding = functools.partial( 433 | layers.SNEmbedding, num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps) 434 | else: 435 | self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1) 436 | self.which_linear = nn.Linear 437 | self.which_embedding = nn.Embedding 438 | 439 | # Prepare model 440 | # self.blocks is a doubly-nested list of modules, the outer loop intended 441 | # to be over blocks at a given resolution (resblocks and/or self-attention) 442 | self.blocks = [] 443 | for index in range(len(self.arch['out_channels'])): 444 | self.blocks += [[ 445 | layers.DBlock( 446 | in_channels=self.arch['in_channels'][index], 447 | out_channels=self.arch['out_channels'][index], 448 | which_conv=self.which_conv, 449 | wide=self.D_wide, 450 | activation=self.activation, 451 | preactivation=(index > 0), 452 | downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None)) 453 | ]] 454 | # If attention on this block, attach it to the end 455 | if self.arch['attention'][self.arch['resolution'][index]]: 456 | print('Adding attention layer in D at resolution %d' % 457 | self.arch['resolution'][index]) 458 | self.blocks[-1] += [ 459 | layers.Attention(self.arch['out_channels'][index], self.which_conv) 460 | ] 461 | # Turn self.blocks into a ModuleList so that it's all properly registered. 462 | self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks]) 463 | # Linear output layer. The output dimension is typically 1, but may be 464 | # larger if we're e.g. 
turning this into a VAE with an inference output 465 | self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim) 466 | # Embedding for projection discrimination 467 | self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1]) 468 | 469 | # Initialize weights 470 | if not skip_init: 471 | self.init_weights() 472 | 473 | # Set up optimizer 474 | self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps 475 | if D_mixed_precision: 476 | print('Using fp16 adam in D...') 477 | import utils 478 | self.optim = utils.Adam16( 479 | params=self.parameters(), 480 | lr=self.lr, 481 | betas=(self.B1, self.B2), 482 | weight_decay=0, 483 | eps=self.adam_eps) 484 | else: 485 | self.optim = optim.Adam( 486 | params=self.parameters(), 487 | lr=self.lr, 488 | betas=(self.B1, self.B2), 489 | weight_decay=0, 490 | eps=self.adam_eps) 491 | # LR scheduling, left here for forward compatibility 492 | # self.lr_sched = {'itr' : 0}# if self.progressive else {} 493 | # self.j = 0 494 | 495 | # Initialize 496 | def init_weights(self): 497 | self.param_count = 0 498 | for module in self.modules(): 499 | if (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear) 500 | or isinstance(module, nn.Embedding)): 501 | if self.init == 'ortho': 502 | init.orthogonal_(module.weight) 503 | elif self.init == 'N02': 504 | init.normal_(module.weight, 0, 0.02) 505 | elif self.init in ['glorot', 'xavier']: 506 | init.xavier_uniform_(module.weight) 507 | else: 508 | print('Init style not recognized...') 509 | self.param_count += sum([p.data.nelement() for p in module.parameters()]) 510 | print('Param count for D' 's initialized parameters: %d' % self.param_count) 511 | 512 | def forward(self, x, y=None): 513 | # Stick x into h for cleaner for loops without flow control 514 | h = x 515 | h_list = [] 516 | # Loop over blocks 517 | for index, blocklist in enumerate(self.blocks): 518 | for block in blocklist: 519 | h = block(h) 520 | h_list.append(h) 521 | # Apply global sum pooling as in SN-GAN 522 | h = torch.sum(self.activation(h), [2, 3]) 523 | h_list.append(h) 524 | # Get initial class-unconditional output 525 | out = self.linear(h) 526 | if y is not None: 527 | # Get projection of final featureset onto class vectors and add to evidence 528 | out = out + torch.sum(self.embed(y) * h, 1, keepdim=True) 529 | 530 | return out, h_list 531 | -------------------------------------------------------------------------------- /models/dgp.py: -------------------------------------------------------------------------------- 1 | import os 2 | from copy import deepcopy 3 | 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | import torch.nn.functional as F 8 | import torchvision 9 | from PIL import Image 10 | from skimage import color 11 | from skimage.measure import compare_psnr, compare_ssim 12 | from torch.autograd import Variable 13 | 14 | import models 15 | import utils 16 | from models.downsampler import Downsampler 17 | 18 | 19 | class DGP(object): 20 | 21 | def __init__(self, config): 22 | self.rank, self.world_size = 0, 1 23 | if config['dist']: 24 | self.rank = dist.get_rank() 25 | self.world_size = dist.get_world_size() 26 | self.config = config 27 | self.mode = config['dgp_mode'] 28 | self.update_G = config['update_G'] 29 | self.update_embed = config['update_embed'] 30 | self.iterations = config['iterations'] 31 | self.ftr_num = config['ftr_num'] 32 | self.ft_num = config['ft_num'] 33 | self.lr_ratio = config['lr_ratio'] 34 | self.G_lrs = config['G_lrs'] 
35 | self.z_lrs = config['z_lrs'] 36 | self.use_in = config['use_in'] 37 | self.select_num = config['select_num'] 38 | self.factor = 2 if self.mode == 'hybrid' else 4 # Downsample factor 39 | 40 | # create model 41 | self.G = models.Generator(**config).cuda() 42 | self.D = models.Discriminator( 43 | **config).cuda() if config['ftr_type'] == 'Discriminator' else None 44 | self.G.optim = torch.optim.Adam( 45 | [{'params': self.G.get_params(i, self.update_embed)} 46 | for i in range(len(self.G.blocks) + 1)], 47 | lr=config['G_lr'], 48 | betas=(config['G_B1'], config['G_B2']), 49 | weight_decay=0, 50 | eps=1e-8) 51 | 52 | # load weights 53 | if config['random_G']: 54 | self.random_G() 55 | else: 56 | utils.load_weights( 57 | self.G if not (config['use_ema']) else None, 58 | self.D, 59 | config['weights_root'], 60 | name_suffix=config['load_weights'], 61 | G_ema=self.G if config['use_ema'] else None, 62 | strict=False) 63 | 64 | self.G.eval() 65 | if self.D is not None: 66 | self.D.eval() 67 | self.G_weight = deepcopy(self.G.state_dict()) 68 | 69 | # prepare latent variable and optimizer 70 | self._prepare_latent() 71 | # prepare learning rate scheduler 72 | self.G_scheduler = utils.LRScheduler(self.G.optim, config['warm_up']) 73 | self.z_scheduler = utils.LRScheduler(self.z_optim, config['warm_up']) 74 | 75 | # loss functions 76 | self.mse = torch.nn.MSELoss() 77 | if config['ftr_type'] == 'Discriminator': 78 | self.ftr_net = self.D 79 | self.criterion = utils.DiscriminatorLoss(ftr_num=config['ftr_num'][0]) 80 | else: 81 | vgg = torchvision.models.vgg16(pretrained=True).cuda().eval() 82 | self.ftr_net = models.subsequence(vgg.features, last_layer='20') 83 | self.criterion = utils.PerceptLoss() 84 | 85 | # Downsampler for producing low-resolution image 86 | self.downsampler = Downsampler( 87 | n_planes=3, 88 | factor=self.factor, 89 | kernel_type='lanczos2', 90 | phase=0.5, 91 | preserve_size=True).type(torch.cuda.FloatTensor) 92 | 93 | def _prepare_latent(self): 94 | self.z = torch.zeros((1, self.G.dim_z)).normal_().cuda() 95 | self.z = Variable(self.z, requires_grad=True) 96 | self.z_optim = torch.optim.Adam( 97 | [{'params': self.z, 'lr': self.z_lrs[0]}], 98 | betas=(self.config['G_B1'], self.config['G_B2']), 99 | weight_decay=0, 100 | eps=1e-8 101 | ) 102 | self.y = torch.zeros(1).long().cuda() 103 | 104 | def reset_G(self): 105 | self.G.load_state_dict(self.G_weight, strict=False) 106 | self.G.reset_in_init() 107 | if self.config['random_G']: 108 | self.G.train() 109 | else: 110 | self.G.eval() 111 | 112 | def random_G(self): 113 | self.G.init_weights() 114 | 115 | def set_target(self, target, category, img_path): 116 | self.target_origin = target 117 | # apply degradation transform to the original image 118 | self.target = self.pre_process(target, True) 119 | self.y.fill_(category.item()) 120 | self.img_name = img_path[img_path.rfind('/') + 1:img_path.rfind('.')] 121 | 122 | def run(self, save_interval=None): 123 | save_imgs = self.target.clone() 124 | save_imgs2 = save_imgs.cpu().clone() 125 | loss_dict = {} 126 | curr_step = 0 127 | count = 0 128 | for stage, iteration in enumerate(self.iterations): 129 | # setup the number of features to use in discriminator 130 | self.criterion.set_ftr_num(self.ftr_num[stage]) 131 | 132 | for i in range(iteration): 133 | curr_step += 1 134 | # setup learning rate 135 | self.G_scheduler.update(curr_step, self.G_lrs[stage], 136 | self.ft_num[stage], self.lr_ratio[stage]) 137 | self.z_scheduler.update(curr_step, self.z_lrs[stage]) 138 | 139 | 
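                # Editor's note (illustrative sketch, not part of the original file):
                # each inner iteration performs one joint update of the latent z and,
                # optionally, the generator weights.  Schematically:
                #     x     = G(z, embed(y))            # candidate image
                #     x_map = pre_process(x)            # same degradation as the target
                #     loss  = w_D_loss[stage] * ftr_loss   # feature distance under ftr_net (D or VGG)
                #           + w_mse[stage]    * mse_loss   # pixel MSE in the degraded space
                #           + w_nll           * nll        # nll = mean(z**2 / 2), a Gaussian prior on z
                # followed by loss.backward(), z_optim.step() and, if update_G, G.optim.step().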
self.z_optim.zero_grad() 140 | if self.update_G: 141 | self.G.optim.zero_grad() 142 | x = self.G(self.z, self.G.shared(self.y), use_in=self.use_in[stage]) 143 | # apply degradation transform 144 | x_map = self.pre_process(x, False) 145 | 146 | # calculate losses in the degradation space 147 | ftr_loss = self.criterion(self.ftr_net, x_map, self.target) 148 | mse_loss = self.mse(x_map, self.target) 149 | # nll corresponds to a negative log-likelihood loss 150 | nll = self.z**2 / 2 151 | nll = nll.mean() 152 | l1_loss = F.l1_loss(x_map, self.target) 153 | loss = ftr_loss * self.config['w_D_loss'][stage] + \ 154 | mse_loss * self.config['w_mse'][stage] + \ 155 | nll * self.config['w_nll'] 156 | loss.backward() 157 | 158 | self.z_optim.step() 159 | if self.update_G: 160 | self.G.optim.step() 161 | 162 | # These losses are calculated in the [-1,1] image scale 163 | # We record the rescaled MSE and L1 loss, corresponding to [0,1] image scale 164 | loss_dict = { 165 | 'ftr_loss': ftr_loss, 166 | 'nll': nll, 167 | 'mse_loss': mse_loss / 4, 168 | 'l1_loss': l1_loss / 2 169 | } 170 | 171 | # calculate losses in the non-degradation space 172 | if self.mode in ['reconstruct', 'colorization', 'SR', 'inpainting']: 173 | # x2 is to get the post-processed result in colorization 174 | metrics, x2 = self.get_metrics(x) 175 | loss_dict = {**loss_dict, **metrics} 176 | 177 | if i == 0 or (i + 1) % self.config['print_interval'] == 0: 178 | if self.rank == 0: 179 | print(', '.join( 180 | ['Stage: [{0}/{1}]'.format(stage + 1, len(self.iterations))] + 181 | ['Iter: [{0}/{1}]'.format(i + 1, iteration)] + 182 | ['%s : %+4.4f' % (key, loss_dict[key]) for key in loss_dict] 183 | )) 184 | # save image sheet of the reconstruction process 185 | save_imgs = torch.cat((save_imgs, x), dim=0) 186 | torchvision.utils.save_image( 187 | save_imgs.float(), 188 | '%s/images_sheet/%s_%s.jpg' % 189 | (self.config['exp_path'], self.img_name, self.mode), 190 | nrow=int(save_imgs.size(0)**0.5), 191 | normalize=True) 192 | if self.mode == 'colorization': 193 | save_imgs2 = torch.cat((save_imgs2, x2), dim=0) 194 | torchvision.utils.save_image( 195 | save_imgs2.float(), 196 | '%s/images_sheet/%s_%s_2.jpg' % 197 | (self.config['exp_path'], self.img_name, self.mode), 198 | nrow=int(save_imgs.size(0)**0.5), 199 | normalize=True) 200 | 201 | if save_interval is not None: 202 | if i == 0 or (i + 1) % save_interval[stage] == 0: 203 | count += 1 204 | save_path = '%s/images/%s' % (self.config['exp_path'], 205 | self.img_name) 206 | if not os.path.exists(save_path): 207 | os.makedirs(save_path) 208 | img_path = os.path.join( 209 | save_path, '%s_%03d.jpg' % (self.img_name, count)) 210 | utils.save_img(x[0], img_path) 211 | 212 | # stop the reconstruction if the loss reaches a threshold 213 | if mse_loss.item() < self.config['stop_mse'] or ftr_loss.item( 214 | ) < self.config['stop_ftr']: 215 | break 216 | 217 | # save images 218 | utils.save_img( 219 | self.target[0], '%s/images/%s_%s_target.png' % 220 | (self.config['exp_path'], self.img_name, self.mode)) 221 | utils.save_img( 222 | self.target_origin[0], 223 | '%s/images/%s_%s_target_origin.png' % 224 | (self.config['exp_path'], self.img_name, self.mode)) 225 | utils.save_img( 226 | x[0], '%s/images/%s_%s.png' % 227 | (self.config['exp_path'], self.img_name, self.mode)) 228 | if self.mode == 'colorization': 229 | utils.save_img( 230 | x2[0], '%s/images/%s_%s2.png' % 231 | (self.config['exp_path'], self.img_name, self.mode)) 232 | 233 | if self.mode == 'jitter': 234 | # conduct random jittering 
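            # Editor's note (not in the original source): jitter() below perturbs the
            # optimized latent, z' = z + std * noise with std in {0.3, 0.5, 0.7}, and
            # saves the generated variants for each perturbation.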
235 | self.jitter(x) 236 | if self.config['save_G']: 237 | torch.save( 238 | self.G.state_dict(), '%s/G_%s_%s.pth' % 239 | (self.config['exp_path'], self.img_name, self.mode)) 240 | torch.save( 241 | self.z, '%s/z_%s_%s.pth' % 242 | (self.config['exp_path'], self.img_name, self.mode)) 243 | return loss_dict 244 | 245 | def select_z(self, select_y=False): 246 | with torch.no_grad(): 247 | if self.select_num == 0: 248 | self.z.zero_() 249 | return 250 | elif self.select_num == 1: 251 | self.z.normal_() 252 | return 253 | z_all, y_all, loss_all = [], [], [] 254 | if self.rank == 0: 255 | print('Selecting z from {} samples'.format(self.select_num)) 256 | # only use last 3 discriminator features to compare 257 | self.criterion.set_ftr_num(3) 258 | for i in range(self.select_num): 259 | self.z.normal_(mean=0, std=self.config['sample_std']) 260 | z_all.append(self.z.cpu()) 261 | if select_y: 262 | self.y.random_(0, self.config['n_classes']) 263 | y_all.append(self.y.cpu()) 264 | x = self.G(self.z, self.G.shared(self.y)) 265 | x = self.pre_process(x) 266 | ftr_loss = self.criterion(self.ftr_net, x, self.target) 267 | loss_all.append(ftr_loss.view(1).cpu()) 268 | if self.rank == 0 and (i + 1) % 100 == 0: 269 | print('Generating {}th sample'.format(i + 1)) 270 | loss_all = torch.cat(loss_all) 271 | idx = torch.argmin(loss_all) 272 | self.z.copy_(z_all[idx]) 273 | if select_y: 274 | self.y.copy_(y_all[idx]) 275 | self.criterion.set_ftr_num(self.ftr_num[0]) 276 | 277 | def pre_process(self, image, target=True): 278 | if self.mode in ['SR', 'hybrid']: 279 | # apply downsampling, this part is the same as deep image prior 280 | if target: 281 | image_pil = utils.np_to_pil( 282 | utils.torch_to_np((image.cpu() + 1) / 2)) 283 | LR_size = [ 284 | image_pil.size[0] // self.factor, 285 | image_pil.size[1] // self.factor 286 | ] 287 | img_LR_pil = image_pil.resize(LR_size, Image.ANTIALIAS) 288 | image = utils.np_to_torch(utils.pil_to_np(img_LR_pil)).cuda() 289 | image = image * 2 - 1 290 | else: 291 | image = self.downsampler((image + 1) / 2) 292 | image = image * 2 - 1 293 | # interpolate to the orginal resolution via bilinear interpolation 294 | image = F.interpolate( 295 | image, scale_factor=self.factor, mode='bilinear') 296 | n, _, h, w = image.size() 297 | if self.mode in ['colorization', 'hybrid']: 298 | # transform the image to gray-scale 299 | r = image[:, 0, :, :] 300 | g = image[:, 1, :, :] 301 | b = image[:, 2, :, :] 302 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 303 | image = gray.view(n, 1, h, w).expand(n, 3, h, w) 304 | if self.mode in ['inpainting', 'hybrid']: 305 | # remove the center part of the image 306 | hole = min(h, w) // 3 307 | begin = (h - hole) // 2 308 | end = h - begin 309 | self.begin, self.end = begin, end 310 | mask = torch.ones(1, 1, h, w).cuda() 311 | mask[0, 0, begin:end, begin:end].zero_() 312 | image = image * mask 313 | return image 314 | 315 | def get_metrics(self, x): 316 | with torch.no_grad(): 317 | l1_loss_origin = F.l1_loss(x, self.target_origin) / 2 318 | mse_loss_origin = self.mse(x, self.target_origin) / 4 319 | metrics = { 320 | 'l1_loss_origin': l1_loss_origin, 321 | 'mse_loss_origin': mse_loss_origin 322 | } 323 | # transfer to numpy array and scale to [0, 1] 324 | target_np = (self.target_origin.detach().cpu().numpy()[0] + 1) / 2 325 | x_np = (x.detach().cpu().numpy()[0] + 1) / 2 326 | target_np = np.transpose(target_np, (1, 2, 0)) 327 | x_np = np.transpose(x_np, (1, 2, 0)) 328 | if self.mode == 'colorization': 329 | # combine the 'ab' dim of x with the 'L' 
dim of target image 330 | x_lab = color.rgb2lab(x_np) 331 | target_lab = color.rgb2lab(target_np) 332 | x_lab[:, :, 0] = target_lab[:, :, 0] 333 | x_np = color.lab2rgb(x_lab) 334 | x = torch.Tensor(np.transpose(x_np, (2, 0, 1))) * 2 - 1 335 | x = x.unsqueeze(0) 336 | elif self.mode == 'inpainting': 337 | # only use the inpainted area to calculate ssim and psnr 338 | x_np = x_np[self.begin:self.end, self.begin:self.end, :] 339 | target_np = target_np[self.begin:self.end, 340 | self.begin:self.end, :] 341 | ssim = compare_ssim(target_np, x_np, multichannel=True) 342 | psnr = compare_psnr(target_np, x_np) 343 | metrics['psnr'] = torch.Tensor([psnr]).cuda() 344 | metrics['ssim'] = torch.Tensor([ssim]).cuda() 345 | return metrics, x 346 | 347 | def jitter(self, x): 348 | save_imgs = x.clone().cpu() 349 | z_rand = self.z.clone() 350 | stds = [0.3, 0.5, 0.7] 351 | save_path = '%s/images/%s_jitter' % (self.config['exp_path'], 352 | self.img_name) 353 | if not os.path.exists(save_path): 354 | os.makedirs(save_path) 355 | with torch.no_grad(): 356 | for std in stds: 357 | for i in range(30): 358 | # add random noise to the latent vector 359 | z_rand.normal_() 360 | z = self.z + std * z_rand 361 | x_jitter = self.G(z, self.G.shared(self.y)) 362 | utils.save_img( 363 | x_jitter[0], '%s/std%.1f_%d.jpg' % (save_path, std, i)) 364 | save_imgs = torch.cat((save_imgs, x_jitter.cpu()), dim=0) 365 | 366 | torchvision.utils.save_image( 367 | save_imgs.float(), 368 | '%s/images_sheet/%s_jitters.jpg' % 369 | (self.config['exp_path'], self.img_name), 370 | nrow=int(save_imgs.size(0)**0.5), 371 | normalize=True) 372 | -------------------------------------------------------------------------------- /models/downsampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | class Downsampler(nn.Module): 6 | ''' 7 | http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf 8 | ''' 9 | def __init__(self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False): 10 | super(Downsampler, self).__init__() 11 | 12 | assert phase in [0, 0.5], 'phase should be 0 or 0.5' 13 | 14 | if kernel_type == 'lanczos2': 15 | support = 2 16 | kernel_width = 4 * factor + 1 17 | kernel_type_ = 'lanczos' 18 | 19 | elif kernel_type == 'lanczos3': 20 | support = 3 21 | kernel_width = 6 * factor + 1 22 | kernel_type_ = 'lanczos' 23 | 24 | elif kernel_type == 'gauss12': 25 | kernel_width = 7 26 | sigma = 1/2 27 | kernel_type_ = 'gauss' 28 | 29 | elif kernel_type == 'gauss1sq2': 30 | kernel_width = 9 31 | sigma = 1./np.sqrt(2) 32 | kernel_type_ = 'gauss' 33 | 34 | elif kernel_type in ['lanczos', 'gauss', 'box']: 35 | kernel_type_ = kernel_type 36 | 37 | else: 38 | assert False, 'wrong name kernel' 39 | 40 | 41 | # note that `kernel width` will be different to actual size for phase = 1/2 42 | self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma) 43 | 44 | downsampler = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0) 45 | downsampler.weight.data[:] = 0 46 | downsampler.bias.data[:] = 0 47 | 48 | kernel_torch = torch.from_numpy(self.kernel) 49 | for i in range(n_planes): 50 | downsampler.weight.data[i, i] = kernel_torch 51 | 52 | self.downsampler_ = downsampler 53 | 54 | if preserve_size: 55 | 56 | if self.kernel.shape[0] % 2 == 1: 57 | pad = int((self.kernel.shape[0] - 1) / 2.) 
58 | else: 59 | pad = int((self.kernel.shape[0] - factor) / 2.) 60 | 61 | self.padding = nn.ReplicationPad2d(pad) 62 | 63 | self.preserve_size = preserve_size 64 | 65 | def forward(self, input): 66 | if self.preserve_size: 67 | x = self.padding(input) 68 | else: 69 | x= input 70 | self.x = x 71 | return self.downsampler_(x) 72 | 73 | def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None): 74 | assert kernel_type in ['lanczos', 'gauss', 'box'] 75 | 76 | # factor = float(factor) 77 | if phase == 0.5 and kernel_type != 'box': 78 | kernel = np.zeros([kernel_width - 1, kernel_width - 1]) 79 | else: 80 | kernel = np.zeros([kernel_width, kernel_width]) 81 | 82 | 83 | if kernel_type == 'box': 84 | assert phase == 0.5, 'Box filter is always half-phased' 85 | kernel[:] = 1./(kernel_width * kernel_width) 86 | 87 | elif kernel_type == 'gauss': 88 | assert sigma, 'sigma is not specified' 89 | assert phase != 0.5, 'phase 1/2 for gauss not implemented' 90 | 91 | center = (kernel_width + 1.)/2. 92 | print(center, kernel_width) 93 | sigma_sq = sigma * sigma 94 | 95 | for i in range(1, kernel.shape[0] + 1): 96 | for j in range(1, kernel.shape[1] + 1): 97 | di = (i - center)/2. 98 | dj = (j - center)/2. 99 | kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj)/(2 * sigma_sq)) 100 | kernel[i - 1][j - 1] = kernel[i - 1][j - 1]/(2. * np.pi * sigma_sq) 101 | elif kernel_type == 'lanczos': 102 | assert support, 'support is not specified' 103 | center = (kernel_width + 1) / 2. 104 | 105 | for i in range(1, kernel.shape[0] + 1): 106 | for j in range(1, kernel.shape[1] + 1): 107 | 108 | if phase == 0.5: 109 | di = abs(i + 0.5 - center) / factor 110 | dj = abs(j + 0.5 - center) / factor 111 | else: 112 | di = abs(i - center) / factor 113 | dj = abs(j - center) / factor 114 | 115 | 116 | pi_sq = np.pi * np.pi 117 | 118 | val = 1 119 | if di != 0: 120 | val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support) 121 | val = val / (np.pi * np.pi * di * di) 122 | 123 | if dj != 0: 124 | val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support) 125 | val = val / (np.pi * np.pi * dj * dj) 126 | 127 | kernel[i - 1][j - 1] = val 128 | 129 | 130 | else: 131 | assert False, 'wrong method name' 132 | 133 | kernel /= kernel.sum() 134 | 135 | return kernel 136 | 137 | #a = Downsampler(n_planes=3, factor=2, kernel_type='lanczos2', phase='1', preserve_size=True) 138 | 139 | 140 | 141 | 142 | 143 | 144 | ################# 145 | # Learnable downsampler 146 | 147 | # KS = 32 148 | # dow = nn.Sequential(nn.ReplicationPad2d(int((KS - factor) / 2.)), nn.Conv2d(1,1,KS,factor)) 149 | 150 | # class Apply(nn.Module): 151 | # def __init__(self, what, dim, *args): 152 | # super(Apply, self).__init__() 153 | # self.dim = dim 154 | 155 | # self.what = what 156 | 157 | # def forward(self, input): 158 | # inputs = [] 159 | # for i in range(input.size(self.dim)): 160 | # inputs.append(self.what(input.narrow(self.dim, i, 1))) 161 | 162 | # return torch.cat(inputs, dim=self.dim) 163 | 164 | # def __len__(self): 165 | # return len(self._modules) 166 | 167 | # downs = Apply(dow, 1) 168 | # downs.type(dtype)(net_input.type(dtype)).size() 169 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | ''' Layers 2 | This file contains various layers for the BigGAN models. 
3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.nn import Parameter as P 8 | 9 | 10 | # Projection of x onto y 11 | def proj(x, y): 12 | return torch.mm(y, x.t()) * y / torch.mm(y, y.t()) 13 | 14 | 15 | # Orthogonalize x wrt list of vectors ys 16 | def gram_schmidt(x, ys): 17 | for y in ys: 18 | x = x - proj(x, y) 19 | return x 20 | 21 | 22 | # Apply num_itrs steps of the power method to estimate top N singular values. 23 | def power_iteration(W, u_, update=True, eps=1e-12): 24 | # Lists holding singular vectors and values 25 | us, vs, svs = [], [], [] 26 | for i, u in enumerate(u_): 27 | # Run one step of the power iteration 28 | with torch.no_grad(): 29 | v = torch.matmul(u, W) 30 | # Run Gram-Schmidt to subtract components of all other singular vectors 31 | v = F.normalize(gram_schmidt(v, vs), eps=eps) 32 | # Add to the list 33 | vs += [v] 34 | # Update the other singular vector 35 | u = torch.matmul(v, W.t()) 36 | # Run Gram-Schmidt to subtract components of all other singular vectors 37 | u = F.normalize(gram_schmidt(u, us), eps=eps) 38 | # Add to the list 39 | us += [u] 40 | if update: 41 | u_[i][:] = u 42 | # Compute this singular value and add it to the list 43 | svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))] 44 | # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)] 45 | return svs, us, vs 46 | 47 | 48 | # Convenience passthrough function 49 | class identity(nn.Module): 50 | 51 | def forward(self, input): 52 | return input 53 | 54 | 55 | # Spectral normalization base class 56 | class SN(object): 57 | 58 | def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12): 59 | # Number of power iterations per step 60 | self.num_itrs = num_itrs 61 | # Number of singular values 62 | self.num_svs = num_svs 63 | # Transposed? 64 | self.transpose = transpose 65 | # Epsilon value for avoiding divide-by-0 66 | self.eps = eps 67 | # Register a singular vector for each sv 68 | for i in range(self.num_svs): 69 | self.register_buffer('u%d' % i, torch.randn(1, num_outputs)) 70 | self.register_buffer('sv%d' % i, torch.ones(1)) 71 | 72 | # Singular vectors (u side) 73 | @property 74 | def u(self): 75 | return [getattr(self, 'u%d' % i) for i in range(self.num_svs)] 76 | 77 | # Singular values; 78 | # note that these buffers are just for logging and are not used in training. 79 | @property 80 | def sv(self): 81 | return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)] 82 | 83 | # Compute the spectrally-normalized weight 84 | def W_(self): 85 | W_mat = self.weight.view(self.weight.size(0), -1) 86 | if self.transpose: 87 | W_mat = W_mat.t() 88 | # Apply num_itrs power iterations 89 | for _ in range(self.num_itrs): 90 | svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps) 91 | # Update the svs 92 | if self.training: 93 | with torch.no_grad( 94 | ): # Make sure to do this in a no_grad() context or you'll get memory leaks! 
95 | for i, sv in enumerate(svs): 96 | self.sv[i][:] = sv 97 | return self.weight / svs[0] 98 | 99 | 100 | # 2D Conv layer with spectral norm 101 | class SNConv2d(nn.Conv2d, SN): 102 | 103 | def __init__(self, 104 | in_channels, 105 | out_channels, 106 | kernel_size, 107 | stride=1, 108 | padding=0, 109 | dilation=1, 110 | groups=1, 111 | bias=True, 112 | num_svs=1, 113 | num_itrs=1, 114 | eps=1e-12): 115 | nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, 116 | groups, bias) 117 | SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps) 118 | 119 | def forward(self, x): 120 | return F.conv2d(x, self.W_(), self.bias, self.stride, self.padding, self.dilation, 121 | self.groups) 122 | 123 | 124 | # Linear layer with spectral norm 125 | class SNLinear(nn.Linear, SN): 126 | 127 | def __init__(self, in_features, out_features, bias=True, num_svs=1, num_itrs=1, eps=1e-12): 128 | nn.Linear.__init__(self, in_features, out_features, bias) 129 | SN.__init__(self, num_svs, num_itrs, out_features, eps=eps) 130 | 131 | def forward(self, x): 132 | return F.linear(x, self.W_(), self.bias) 133 | 134 | 135 | # Embedding layer with spectral norm 136 | # We use num_embeddings as the dim instead of embedding_dim here 137 | # for convenience sake 138 | class SNEmbedding(nn.Embedding, SN): 139 | 140 | def __init__(self, 141 | num_embeddings, 142 | embedding_dim, 143 | padding_idx=None, 144 | max_norm=None, 145 | norm_type=2, 146 | scale_grad_by_freq=False, 147 | sparse=False, 148 | _weight=None, 149 | num_svs=1, 150 | num_itrs=1, 151 | eps=1e-12): 152 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, 153 | scale_grad_by_freq, sparse, _weight) 154 | SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps) 155 | 156 | def forward(self, x): 157 | return F.embedding(x, self.W_()) 158 | 159 | 160 | # A non-local block as used in SA-GAN 161 | # Note that the implementation as described in the paper is largely incorrect; 162 | # refer to the released code for the actual implementation. 
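# Editor's sketch of the computation in Attention.forward below
# (shapes assume an input x of size (B, ch, H, W)):
#     theta = conv1x1(x)                   -> (B, ch//8, H*W)
#     phi   = maxpool2x2(conv1x1(x))       -> (B, ch//8, H*W//4)
#     g     = maxpool2x2(conv1x1(x))       -> (B, ch//2, H*W//4)
#     beta  = softmax(theta^T @ phi)       -> (B, H*W, H*W//4) attention weights
#     o     = conv1x1((g @ beta^T) reshaped to (B, ch//2, H, W))
#     out   = gamma * o + x                # gamma is a learnable scalar, initialized to 0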
163 | class Attention(nn.Module): 164 | 165 | def __init__(self, ch, which_conv=SNConv2d, name='attention'): 166 | super(Attention, self).__init__() 167 | # Channel multiplier 168 | self.ch = ch 169 | self.which_conv = which_conv 170 | self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) 171 | self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False) 172 | self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False) 173 | self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False) 174 | # Learnable gain parameter 175 | self.gamma = P(torch.tensor(0.), requires_grad=True) 176 | 177 | def forward(self, x, y=None, use_in=False): 178 | # Apply convs 179 | theta = self.theta(x) 180 | phi = F.max_pool2d(self.phi(x), [2, 2]) 181 | g = F.max_pool2d(self.g(x), [2, 2]) 182 | # Perform reshapes 183 | theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3]) 184 | phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4) 185 | g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4) 186 | # Matmul and softmax to get attention maps 187 | beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1) 188 | # Attention map times g path 189 | o = self.o( 190 | torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3])) 191 | output = self.gamma * o + x 192 | return output 193 | 194 | 195 | # Fused batchnorm op 196 | def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5): 197 | # Apply scale and shift--if gain and bias are provided, fuse them here 198 | # Prepare scale 199 | scale = torch.rsqrt(var + eps) 200 | # If a gain is provided, use it 201 | if gain is not None: 202 | scale = scale * gain 203 | # Prepare shift 204 | shift = mean * scale 205 | # If bias is provided, use it 206 | if bias is not None: 207 | shift = shift - bias 208 | return x * scale - shift 209 | # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way. 210 | 211 | 212 | # Manual BN 213 | # Calculate means and variances using mean-of-squares minus mean-squared 214 | def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5): 215 | # Cast x to float32 if necessary 216 | float_x = x.float() 217 | # Calculate expected value of x (m) and expected value of x**2 (m2) 218 | # Mean of x 219 | m = torch.mean(float_x, [0, 2, 3], keepdim=True) 220 | # Mean of x squared 221 | m2 = torch.mean(float_x**2, [0, 2, 3], keepdim=True) 222 | # Calculate variance as mean of squared minus mean squared. 
223 | var = (m2 - m**2) 224 | # Cast back to float 16 if necessary 225 | var = var.type(x.type()) 226 | m = m.type(x.type()) 227 | # Return mean and variance for updating stored mean/var if requested 228 | if return_mean_var: 229 | return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze() 230 | else: 231 | return fused_bn(x, m, var, gain, bias, eps) 232 | 233 | 234 | # My batchnorm, supports standing stats 235 | class myBN(nn.Module): 236 | 237 | def __init__(self, num_channels, eps=1e-5, momentum=0.1): 238 | super(myBN, self).__init__() 239 | # momentum for updating running stats 240 | self.momentum = momentum 241 | # epsilon to avoid dividing by 0 242 | self.eps = eps 243 | # Momentum 244 | self.momentum = momentum 245 | # Register buffers 246 | self.register_buffer('stored_mean', torch.zeros(num_channels)) 247 | self.register_buffer('stored_var', torch.ones(num_channels)) 248 | self.register_buffer('accumulation_counter', torch.zeros(1)) 249 | # Accumulate running means and vars 250 | self.accumulate_standing = False 251 | 252 | # reset standing stats 253 | def reset_stats(self): 254 | self.stored_mean[:] = 0 255 | self.stored_var[:] = 0 256 | self.accumulation_counter[:] = 0 257 | 258 | def forward(self, x, gain, bias): 259 | if self.training: 260 | out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps) 261 | # If accumulating standing stats, increment them 262 | if self.accumulate_standing: 263 | self.stored_mean[:] = self.stored_mean + mean.data 264 | self.stored_var[:] = self.stored_var + var.data 265 | self.accumulation_counter += 1.0 266 | # If not accumulating standing stats, take running averages 267 | else: 268 | self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum 269 | self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum 270 | return out 271 | # If not in training mode, use the stored statistics 272 | else: 273 | mean = self.stored_mean.view(1, -1, 1, 1) 274 | var = self.stored_var.view(1, -1, 1, 1) 275 | # If using standing stats, divide them by the accumulation counter 276 | if self.accumulate_standing: 277 | mean = mean / self.accumulation_counter 278 | var = var / self.accumulation_counter 279 | return fused_bn(x, mean, var, gain, bias, self.eps) 280 | 281 | 282 | # Class-conditional bn 283 | # output size is the number of channels, input size is for the linear layers 284 | # Andy's Note: this class feels messy but I'm not really sure how to clean it up 285 | # Suggestions welcome! (By which I mean, refactor this and make a pull request 286 | # if you want to make this more readable/usable). 287 | class ccbn(nn.Module): 288 | 289 | def __init__( 290 | self, 291 | output_size, 292 | input_size, 293 | which_linear, 294 | eps=1e-5, 295 | momentum=0.1, 296 | cross_replica=False, 297 | mybn=False, 298 | norm_style='bn', 299 | ): 300 | super(ccbn, self).__init__() 301 | self.output_size, self.input_size = output_size, input_size 302 | # Prepare gain and bias layers 303 | self.gain = which_linear(input_size, output_size) 304 | self.bias = which_linear(input_size, output_size) 305 | # epsilon to avoid dividing by 0 306 | self.eps = eps 307 | # Momentum 308 | self.momentum = momentum 309 | # Use cross-replica batchnorm? 310 | self.cross_replica = cross_replica 311 | # Use my batchnorm? 312 | self.mybn = mybn 313 | # Norm style? 
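        # Editor's note (not in the original source): in forward(), the conditioning
        # vector y produces per-channel modulation, gain = 1 + self.gain(y) and
        # bias = self.bias(y), which scale and shift the normalized activations:
        #     out = norm(x) * gain + bias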
314 | self.norm_style = norm_style 315 | 316 | if self.cross_replica: 317 | raise NotImplementedError 318 | elif self.mybn: 319 | self.bn = myBN(output_size, self.eps, self.momentum) 320 | elif self.norm_style in ['bn', 'in']: 321 | self.register_buffer('stored_mean', torch.zeros(output_size)) 322 | self.register_buffer('stored_var', torch.ones(output_size)) 323 | 324 | def forward(self, x, y): 325 | # Calculate class-conditional gains and biases 326 | if y is not None: 327 | gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1) 328 | bias = self.bias(y).view(y.size(0), -1, 1, 1) 329 | # If using my batchnorm 330 | if self.mybn: 331 | return self.bn(x, gain=gain, bias=bias) 332 | elif self.cross_replica: 333 | return self.bn(x) * gain + bias 334 | # else: 335 | else: 336 | if self.norm_style == 'bn': 337 | out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None, self.training, 338 | 0.1, self.eps) 339 | elif self.norm_style == 'in': 340 | out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None, 341 | self.training, 0.1, self.eps) 342 | elif self.norm_style == 'nonorm': 343 | out = x 344 | if y is not None: 345 | out = out * gain + bias 346 | return out 347 | 348 | def extra_repr(self): 349 | s = 'out: {output_size}, in: {input_size},' 350 | s += ' cross_replica={cross_replica}' 351 | return s.format(**self.__dict__) 352 | 353 | 354 | # Normal, non-class-conditional BN 355 | class bn(nn.Module): 356 | 357 | def __init__(self, output_size, eps=1e-5, momentum=0.1, cross_replica=False, mybn=False): 358 | super(bn, self).__init__() 359 | self.output_size = output_size 360 | # Prepare gain and bias layers 361 | self.gain = P(torch.ones(output_size), requires_grad=True) 362 | self.bias = P(torch.zeros(output_size), requires_grad=True) 363 | # epsilon to avoid dividing by 0 364 | self.eps = eps 365 | # Momentum 366 | self.momentum = momentum 367 | # Use cross-replica batchnorm? 368 | self.cross_replica = cross_replica 369 | # Use my batchnorm? 370 | self.mybn = mybn 371 | 372 | if self.cross_replica: 373 | raise NotImplementedError 374 | elif mybn: 375 | self.bn = myBN(output_size, self.eps, self.momentum) 376 | # Register buffers if neither of the above 377 | else: 378 | self.register_buffer('stored_mean', torch.zeros(output_size)) 379 | self.register_buffer('stored_var', torch.ones(output_size)) 380 | 381 | def forward(self, x, y=None): 382 | if self.cross_replica or self.mybn: 383 | gain = self.gain.view(1, -1, 1, 1) 384 | bias = self.bias.view(1, -1, 1, 1) 385 | if self.mybn: 386 | return self.bn(x, gain=gain, bias=bias) 387 | elif self.cross_replica: 388 | return self.bn(x) * gain + bias 389 | else: 390 | return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain, self.bias, 391 | self.training, self.momentum, self.eps) 392 | 393 | 394 | # Generator blocks 395 | # Note that this class assumes the kernel size and padding (and any other 396 | # settings) have been selected in the main generator module and passed in 397 | # through the which_conv arg. 
Similar rules apply with which_bn (the input 398 | # size [which is actually the number of channels of the conditional info] must 399 | # be preselected) 400 | class GBlock(nn.Module): 401 | 402 | def __init__(self, 403 | in_channels, 404 | out_channels, 405 | which_conv=nn.Conv2d, 406 | which_bn=bn, 407 | activation=None, 408 | upsample=None): 409 | super(GBlock, self).__init__() 410 | 411 | self.in_channels, self.out_channels = in_channels, out_channels 412 | self.which_conv, self.which_bn = which_conv, which_bn 413 | self.activation = activation 414 | self.upsample = upsample 415 | # Conv layers 416 | self.conv1 = self.which_conv(self.in_channels, self.out_channels) 417 | self.conv2 = self.which_conv(self.out_channels, self.out_channels) 418 | self.learnable_sc = in_channels != out_channels or upsample 419 | if self.learnable_sc: 420 | self.conv_sc = self.which_conv(in_channels, out_channels, kernel_size=1, padding=0) 421 | # Batchnorm layers 422 | self.bn1 = self.which_bn(in_channels) 423 | self.bn2 = self.which_bn(out_channels) 424 | # upsample layers 425 | self.upsample = upsample 426 | 427 | # instance norm layers 428 | self.in_initialized = False 429 | self.in1 = nn.InstanceNorm2d(in_channels, affine=True) 430 | self.in2 = nn.InstanceNorm2d(out_channels, affine=True) 431 | self.in1.weight.requires_grad = False 432 | self.in1.bias.requires_grad = False 433 | self.in2.weight.requires_grad = False 434 | self.in2.bias.requires_grad = False 435 | 436 | def reset_in_init(self): 437 | self.in_initialized = False 438 | 439 | def init_in(self, which_bn, which_in, x, y): 440 | # carefully initialize IN's weights such that the output does not change 441 | with torch.no_grad(): 442 | h = which_bn(x, y) 443 | mean = torch.mean(h, (2, 3)).squeeze(0) 444 | std = torch.std(h.view(h.size(0), h.size(1), -1), 2).squeeze(0) 445 | which_in.weight.copy_(std) 446 | which_in.bias.copy_(mean) 447 | 448 | def forward(self, x, y, use_in): 449 | if use_in: 450 | if not self.in_initialized: 451 | self.init_in(self.bn1, self.in1, x, y) 452 | h = self.in1(x) 453 | else: 454 | h = self.bn1(x, y) 455 | self.in_initialized = False 456 | 457 | h = self.activation(h) 458 | if self.upsample: 459 | h = self.upsample(h) 460 | x = self.upsample(x) 461 | h = self.conv1(h) 462 | 463 | if use_in: 464 | if not self.in_initialized: 465 | self.init_in(self.bn2, self.in2, h, y) 466 | self.in_initialized = True 467 | h = self.in2(h) 468 | else: 469 | h = self.bn2(h, y) 470 | 471 | h = self.activation(h) 472 | h = self.conv2(h) 473 | if self.learnable_sc: 474 | x = self.conv_sc(x) 475 | return h + x 476 | 477 | 478 | # Residual block for the discriminator 479 | class DBlock(nn.Module): 480 | 481 | def __init__( 482 | self, 483 | in_channels, 484 | out_channels, 485 | which_conv=SNConv2d, 486 | wide=True, 487 | preactivation=False, 488 | activation=None, 489 | downsample=None, 490 | ): 491 | super(DBlock, self).__init__() 492 | self.in_channels, self.out_channels = in_channels, out_channels 493 | # If using wide D (as in SA-GAN and BigGAN), change the channel pattern 494 | self.hidden_channels = self.out_channels if wide else self.in_channels 495 | self.which_conv = which_conv 496 | self.preactivation = preactivation 497 | self.activation = activation 498 | self.downsample = downsample 499 | 500 | # Conv layers 501 | self.conv1 = self.which_conv(self.in_channels, self.hidden_channels) 502 | self.conv2 = self.which_conv(self.hidden_channels, self.out_channels) 503 | self.learnable_sc = True if (in_channels != out_channels) or 
downsample else False 504 | if self.learnable_sc: 505 | self.conv_sc = self.which_conv(in_channels, out_channels, kernel_size=1, padding=0) 506 | 507 | def shortcut(self, x): 508 | if self.preactivation: 509 | if self.learnable_sc: 510 | x = self.conv_sc(x) 511 | if self.downsample: 512 | x = self.downsample(x) 513 | else: 514 | if self.downsample: 515 | x = self.downsample(x) 516 | if self.learnable_sc: 517 | x = self.conv_sc(x) 518 | return x 519 | 520 | def forward(self, x): 521 | if self.preactivation: 522 | # h = self.activation(x) # NOT TODAY SATAN 523 | # Andy's note: This line *must* be an out-of-place ReLU or it 524 | # will negatively affect the shortcut connection. 525 | h = F.relu(x) 526 | else: 527 | h = x 528 | h = self.conv1(h) 529 | h = self.conv2(self.activation(h)) 530 | if self.downsample: 531 | h = self.downsample(h) 532 | 533 | return h + self.shortcut(x) 534 | -------------------------------------------------------------------------------- /models/nethook.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Utilities for instrumenting a torch model. 3 | 4 | InstrumentedModel will wrap a pytorch model and allow hooking 5 | arbitrary layers to monitor or modify their output directly. 6 | ''' 7 | 8 | import torch, numpy, types, copy 9 | from collections import OrderedDict, defaultdict 10 | 11 | 12 | class InstrumentedModel(torch.nn.Module): 13 | ''' 14 | A wrapper for hooking, probing and intervening in pytorch Modules. 15 | Example usage: 16 | 17 | ``` 18 | model = load_my_model() 19 | with inst as InstrumentedModel(model): 20 | inst.retain_layer(layername) 21 | inst.edit_layer(layername, ablation=0.5, replacement=target_features) 22 | inst(inputs) 23 | original_features = inst.retained_layer(layername) 24 | ``` 25 | ''' 26 | 27 | def __init__(self, model): 28 | super().__init__() 29 | self.model = model 30 | self._retained = OrderedDict() 31 | self._detach_retained = {} 32 | self._editargs = defaultdict(dict) 33 | self._editrule = {} 34 | self._hooked_layer = {} 35 | self._old_forward = {} 36 | if isinstance(model, torch.nn.Sequential): 37 | self._hook_sequential() 38 | 39 | def __enter__(self): 40 | return self 41 | 42 | def __exit__(self, type, value, traceback): 43 | self.close() 44 | 45 | def forward(self, *inputs, **kwargs): 46 | return self.model(*inputs, **kwargs) 47 | 48 | def retain_layer(self, layername, detach=True): 49 | ''' 50 | Pass a fully-qualified layer name (E.g., module.submodule.conv3) 51 | to hook that layer and retain its output each time the model is run. 52 | A pair (layername, aka) can be provided, and the aka will be used 53 | as the key for the retained value instead of the layername. 54 | ''' 55 | self.retain_layers([layername], detach=detach) 56 | 57 | def retain_layers(self, layernames, detach=True): 58 | ''' 59 | Retains a list of a layers at once. 60 | ''' 61 | self.add_hooks(layernames) 62 | for layername in layernames: 63 | aka = layername 64 | if not isinstance(aka, str): 65 | layername, aka = layername 66 | if aka not in self._retained: 67 | self._retained[aka] = None 68 | self._detach_retained[aka] = detach 69 | 70 | def stop_retaining_layers(self, layernames): 71 | ''' 72 | Removes a list of layers from the set retained. 
73 |         '''
74 |         self.add_hooks(layernames)
75 |         for layername in layernames:
76 |             aka = layername
77 |             if not isinstance(aka, str):
78 |                 layername, aka = layername
79 |             if aka in self._retained:
80 |                 del self._retained[aka]
81 |                 del self._detach_retained[aka]
82 | 
83 |     def retained_features(self, clear=False):
84 |         '''
85 |         Returns a dict of all currently retained features.
86 |         '''
87 |         result = OrderedDict(self._retained)
88 |         if clear:
89 |             for k in result:
90 |                 self._retained[k] = None
91 |         return result
92 | 
93 |     def retained_layer(self, aka=None, clear=False):
94 |         '''
95 |         Retrieve retained data that was previously hooked by retain_layer.
96 |         Call this after the model is run. If clear is set, then the
97 |         retained value is returned and also cleared.
98 |         '''
99 |         if aka is None:
100 |             # Default to the first retained layer.
101 |             aka = next(self._retained.keys().__iter__())
102 |         result = self._retained[aka]
103 |         if clear:
104 |             self._retained[aka] = None
105 |         return result
106 | 
107 |     def edit_layer(self, layername, rule=None, **kwargs):
108 |         '''
109 |         Pass a fully-qualified layer name (e.g., module.submodule.conv3)
110 |         to hook that layer and modify its output each time the model is run.
111 |         The output of the layer will be modified to be a convex combination
112 |         of the replacement and x interpolated according to the ablation, i.e.:
113 |         `output = x * (1 - a) + (r * a)`.
114 |         '''
115 |         if not isinstance(layername, str):
116 |             layername, aka = layername
117 |         else:
118 |             aka = layername
119 | 
120 |         # The default editing rule is apply_ablation_replacement
121 |         if rule is None:
122 |             rule = apply_ablation_replacement
123 | 
124 |         self.add_hooks([(layername, aka)])
125 |         self._editargs[aka].update(kwargs)
126 |         self._editrule[aka] = rule
127 | 
128 |     def remove_edits(self, layername=None):
129 |         '''
130 |         Removes edits at the specified layer, or removes edits at all layers
131 |         if no layer name is specified.
132 |         '''
133 |         if layername is None:
134 |             self._editargs.clear()
135 |             self._editrule.clear()
136 |             return
137 | 
138 |         if not isinstance(layername, str):
139 |             layername, aka = layername
140 |         else:
141 |             aka = layername
142 |         if aka in self._editargs:
143 |             del self._editargs[aka]
144 |         if aka in self._editrule:
145 |             del self._editrule[aka]
146 | 
147 |     def add_hooks(self, layernames):
148 |         '''
149 |         Sets up a set of layers to be hooked.
150 | 
151 |         Usually not called directly: use edit_layer or retain_layer instead.
152 |         '''
153 |         needed = set()
154 |         aka_map = {}
155 |         for name in layernames:
156 |             aka = name
157 |             if not isinstance(aka, str):
158 |                 name, aka = name
159 |             if self._hooked_layer.get(aka, None) != name:
160 |                 aka_map[name] = aka
161 |                 needed.add(name)
162 |         if not needed:
163 |             return
164 |         for name, layer in self.model.named_modules():
165 |             if name in aka_map:
166 |                 needed.remove(name)
167 |                 aka = aka_map[name]
168 |                 self._hook_layer(layer, name, aka)
169 |         for name in needed:
170 |             raise ValueError('Layer %s not found in model' % name)
171 | 
172 |     def _hook_layer(self, layer, layername, aka):
173 |         '''
174 |         Internal method to replace a forward method with a closure that
175 |         intercepts the call, and tracks the hook so that it can be reverted.
176 | ''' 177 | if aka in self._hooked_layer: 178 | raise ValueError('Layer %s already hooked' % aka) 179 | if layername in self._old_forward: 180 | raise ValueError('Layer %s already hooked' % layername) 181 | self._hooked_layer[aka] = layername 182 | self._old_forward[layername] = (layer, aka, layer.__dict__.get('forward', None)) 183 | editor = self 184 | original_forward = layer.forward 185 | 186 | def new_forward(self, *inputs, **kwargs): 187 | original_x = original_forward(*inputs, **kwargs) 188 | x = editor._postprocess_forward(original_x, aka) 189 | return x 190 | 191 | layer.forward = types.MethodType(new_forward, layer) 192 | 193 | def _unhook_layer(self, aka): 194 | ''' 195 | Internal method to remove a hook, restoring the original forward method. 196 | ''' 197 | if aka not in self._hooked_layer: 198 | return 199 | layername = self._hooked_layer[aka] 200 | # Remove any retained data and any edit rules 201 | if aka in self._retained: 202 | del self._retained[aka] 203 | del self._detach_retained[aka] 204 | self.remove_edits(aka) 205 | # Restore the unhooked method for the layer 206 | layer, check, old_forward = self._old_forward[layername] 207 | assert check == aka 208 | if old_forward is None: 209 | if 'forward' in layer.__dict__: 210 | del layer.__dict__['forward'] 211 | else: 212 | layer.forward = old_forward 213 | del self._old_forward[layername] 214 | del self._hooked_layer[aka] 215 | 216 | def _postprocess_forward(self, x, aka): 217 | ''' 218 | The internal method called by the hooked layers after they are run. 219 | ''' 220 | # Retain output before edits, if desired. 221 | if aka in self._retained: 222 | if self._detach_retained[aka]: 223 | self._retained[aka] = x.detach() 224 | else: 225 | self._retained[aka] = x 226 | # Apply any edits requested. 227 | rule = self._editrule.get(aka, None) 228 | if rule is not None: 229 | x = rule(x, self, **(self._editargs[aka])) 230 | return x 231 | 232 | def _hook_sequential(self): 233 | ''' 234 | Replaces 'forward' of sequential with a version that takes 235 | additional keyword arguments: layer allows a single layer to be run; 236 | first_layer and last_layer allow a subsequence of layers to be run. 237 | ''' 238 | model = self.model 239 | self._hooked_layer['.'] = '.' 240 | self._old_forward['.'] = (model, '.', model.__dict__.get('forward', None)) 241 | 242 | def new_forward(this, x, layer=None, first_layer=None, last_layer=None): 243 | assert layer is None or (first_layer is None and last_layer is None) 244 | first_layer, last_layer = [ 245 | str(layer) if layer is not None else str(d) if d is not None else None 246 | for d in [first_layer, last_layer] 247 | ] 248 | including_children = (first_layer is None) 249 | for name, layer in this._modules.items(): 250 | if name == first_layer: 251 | first_layer = None 252 | including_children = True 253 | if including_children: 254 | x = layer(x) 255 | if name == last_layer: 256 | last_layer = None 257 | including_children = False 258 | assert first_layer is None, '%s not found' % first_layer 259 | assert last_layer is None, '%s not found' % last_layer 260 | return x 261 | 262 | model.forward = types.MethodType(new_forward, model) 263 | 264 | def close(self): 265 | ''' 266 | Unhooks all hooked layers in the model. 267 | ''' 268 | for aka in list(self._old_forward.keys()): 269 | self._unhook_layer(aka) 270 | assert len(self._old_forward) == 0 271 | 272 | 273 | def apply_ablation_replacement(x, imodel, **buffers): 274 | if buffers is not None: 275 | # Apply any edits requested. 
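        # Editor's note (not in the original source): the default rule below applies
        #     x := x * (1 - ablation) + replacement * ablation
        # i.e. the convex combination documented in InstrumentedModel.edit_layer;
        # make_matching_tensor casts the buffers to x's dtype/device and unsqueezes
        # them so they broadcast over x.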
276 | a = make_matching_tensor(buffers, 'ablation', x) 277 | if a is not None: 278 | x = x * (1 - a) 279 | v = make_matching_tensor(buffers, 'replacement', x) 280 | if v is not None: 281 | x += (v * a) 282 | return x 283 | 284 | 285 | def make_matching_tensor(valuedict, name, data): 286 | ''' 287 | Converts `valuedict[name]` to be a tensor with the same dtype, device, 288 | and dimension count as `data`, and caches the converted tensor. 289 | ''' 290 | v = valuedict.get(name, None) 291 | if v is None: 292 | return None 293 | if not isinstance(v, torch.Tensor): 294 | # Accept non-torch data. 295 | v = torch.from_numpy(numpy.array(v)) 296 | valuedict[name] = v 297 | if not v.device == data.device or not v.dtype == data.dtype: 298 | # Ensure device and type matches. 299 | assert not v.requires_grad, '%s wrong device or type' % (name) 300 | v = v.to(device=data.device, dtype=data.dtype) 301 | valuedict[name] = v 302 | if len(v.shape) < len(data.shape): 303 | # Ensure dimensions are unsqueezed as needed. 304 | assert not v.requires_grad, '%s wrong dimensions' % (name) 305 | v = v.view((1, ) + tuple(v.shape) + (1, ) * (len(data.shape) - len(v.shape) - 1)) 306 | valuedict[name] = v 307 | return v 308 | 309 | 310 | def subsequence(sequential, 311 | first_layer=None, 312 | last_layer=None, 313 | upto_layer=None, 314 | single_layer=None, 315 | share_weights=False): 316 | ''' 317 | Creates a subsequence of a pytorch Sequential model, copying over 318 | modules together with parameters for the subsequence. Only 319 | modules from first_layer to last_layer (inclusive) are included. 320 | 321 | If share_weights is True, then references the original modules 322 | and their parameters without copying them. Otherwise, by default, 323 | makes a separate brand-new copy. 
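    Example (editor's addition, mirroring the call in models/dgp.py, where
    `vgg` is a torchvision VGG-16):
        ftr_net = subsequence(vgg.features, last_layer='20')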
324 | ''' 325 | assert ((single_layer is None) or (first_layer is last_layer is upto_layer is None)) 326 | assert (last_layer is None) or (upto_layer is None) 327 | if single_layer is not None: 328 | first_layer = single_layer 329 | last_layer = single_layer 330 | included_children = OrderedDict() 331 | including_children = (first_layer is None) 332 | for name, layer in sequential._modules.items(): 333 | if name == first_layer: 334 | first_layer = None 335 | including_children = True 336 | if name == upto_layer: 337 | upto_layer = None 338 | including_children = False 339 | if including_children: 340 | included_children[name] = layer if share_weights else (copy.deepcopy(layer)) 341 | if name == last_layer: 342 | last_layer = None 343 | including_children = False 344 | if first_layer is not None: 345 | raise ValueError('Layer %s not found' % first_layer) 346 | if last_layer is not None: 347 | raise ValueError('Layer %s not found' % last_layer) 348 | if upto_layer is not None: 349 | raise ValueError('Layer %s not found' % upto_layer) 350 | # if not len(included_children): 351 | # raise ValueError('Empty subsequence') 352 | return torch.nn.Sequential(OrderedDict(included_children)) 353 | 354 | 355 | def set_requires_grad(requires_grad, *models): 356 | for model in models: 357 | if isinstance(model, torch.nn.Module): 358 | for param in model.parameters(): 359 | param.requires_grad = requires_grad 360 | elif isinstance(model, (torch.nn.Parameter, torch.Tensor)): 361 | model.requires_grad = requires_grad 362 | else: 363 | assert False, 'unknown type %r' % type(model) 364 | -------------------------------------------------------------------------------- /pretrained/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-image 3 | scipy 4 | pillow<=6.2.2 5 | torchvision 6 | tensorboardX 7 | -------------------------------------------------------------------------------- /scripts/imagenet_val_1k.txt: -------------------------------------------------------------------------------- 1 | ILSVRC2012_val_00000293.JPEG 0 2 | ILSVRC2012_val_00000236.JPEG 1 3 | ILSVRC2012_val_00002338.JPEG 2 4 | ILSVRC2012_val_00002922.JPEG 3 5 | ILSVRC2012_val_00001676.JPEG 4 6 | ILSVRC2012_val_00000921.JPEG 5 7 | ILSVRC2012_val_00001935.JPEG 6 8 | ILSVRC2012_val_00000329.JPEG 7 9 | ILSVRC2012_val_00001114.JPEG 8 10 | ILSVRC2012_val_00001031.JPEG 9 11 | ILSVRC2012_val_00000651.JPEG 10 12 | ILSVRC2012_val_00000570.JPEG 11 13 | ILSVRC2012_val_00000873.JPEG 12 14 | ILSVRC2012_val_00000247.JPEG 13 15 | ILSVRC2012_val_00000414.JPEG 14 16 | ILSVRC2012_val_00001598.JPEG 15 17 | ILSVRC2012_val_00000198.JPEG 16 18 | ILSVRC2012_val_00000880.JPEG 17 19 | ILSVRC2012_val_00000476.JPEG 18 20 | ILSVRC2012_val_00000837.JPEG 19 21 | ILSVRC2012_val_00000962.JPEG 20 22 | ILSVRC2012_val_00000073.JPEG 21 23 | ILSVRC2012_val_00000128.JPEG 22 24 | ILSVRC2012_val_00000018.JPEG 23 25 | ILSVRC2012_val_00000318.JPEG 24 26 | ILSVRC2012_val_00000225.JPEG 25 27 | ILSVRC2012_val_00000498.JPEG 26 28 | ILSVRC2012_val_00000258.JPEG 27 29 | ILSVRC2012_val_00000088.JPEG 28 30 | ILSVRC2012_val_00000052.JPEG 29 31 | ILSVRC2012_val_00001652.JPEG 30 32 | ILSVRC2012_val_00000944.JPEG 31 33 | ILSVRC2012_val_00000037.JPEG 32 34 | 
ILSVRC2012_val_00000404.JPEG 33 35 | ILSVRC2012_val_00000791.JPEG 34 36 | ILSVRC2012_val_00000229.JPEG 35 37 | ILSVRC2012_val_00002832.JPEG 36 38 | ILSVRC2012_val_00000362.JPEG 37 39 | ILSVRC2012_val_00003904.JPEG 38 40 | ILSVRC2012_val_00000153.JPEG 39 41 | ILSVRC2012_val_00001931.JPEG 40 42 | ILSVRC2012_val_00000908.JPEG 41 43 | ILSVRC2012_val_00003821.JPEG 42 44 | ILSVRC2012_val_00000297.JPEG 43 45 | ILSVRC2012_val_00001022.JPEG 44 46 | ILSVRC2012_val_00001475.JPEG 45 47 | ILSVRC2012_val_00000064.JPEG 46 48 | ILSVRC2012_val_00001375.JPEG 47 49 | ILSVRC2012_val_00000541.JPEG 48 50 | ILSVRC2012_val_00001573.JPEG 49 51 | ILSVRC2012_val_00000090.JPEG 50 52 | ILSVRC2012_val_00002051.JPEG 51 53 | ILSVRC2012_val_00003569.JPEG 52 54 | ILSVRC2012_val_00001042.JPEG 53 55 | ILSVRC2012_val_00002568.JPEG 54 56 | ILSVRC2012_val_00000029.JPEG 55 57 | ILSVRC2012_val_00000512.JPEG 56 58 | ILSVRC2012_val_00000006.JPEG 57 59 | ILSVRC2012_val_00000084.JPEG 58 60 | ILSVRC2012_val_00001108.JPEG 59 61 | ILSVRC2012_val_00000337.JPEG 60 62 | ILSVRC2012_val_00000786.JPEG 61 63 | ILSVRC2012_val_00000688.JPEG 62 64 | ILSVRC2012_val_00003143.JPEG 63 65 | ILSVRC2012_val_00000298.JPEG 64 66 | ILSVRC2012_val_00000001.JPEG 65 67 | ILSVRC2012_val_00001582.JPEG 66 68 | ILSVRC2012_val_00000749.JPEG 67 69 | ILSVRC2012_val_00001706.JPEG 68 70 | ILSVRC2012_val_00000299.JPEG 69 71 | ILSVRC2012_val_00000107.JPEG 70 72 | ILSVRC2012_val_00001007.JPEG 71 73 | ILSVRC2012_val_00001037.JPEG 72 74 | ILSVRC2012_val_00002688.JPEG 73 75 | ILSVRC2012_val_00000040.JPEG 74 76 | ILSVRC2012_val_00001150.JPEG 75 77 | ILSVRC2012_val_00000396.JPEG 76 78 | ILSVRC2012_val_00000218.JPEG 77 79 | ILSVRC2012_val_00000679.JPEG 78 80 | ILSVRC2012_val_00003981.JPEG 79 81 | ILSVRC2012_val_00000075.JPEG 80 82 | ILSVRC2012_val_00000860.JPEG 81 83 | ILSVRC2012_val_00002539.JPEG 82 84 | ILSVRC2012_val_00001461.JPEG 83 85 | ILSVRC2012_val_00001146.JPEG 84 86 | ILSVRC2012_val_00003015.JPEG 85 87 | ILSVRC2012_val_00000576.JPEG 86 88 | ILSVRC2012_val_00001629.JPEG 87 89 | ILSVRC2012_val_00001163.JPEG 88 90 | ILSVRC2012_val_00001288.JPEG 89 91 | ILSVRC2012_val_00001749.JPEG 90 92 | ILSVRC2012_val_00000370.JPEG 91 93 | ILSVRC2012_val_00000051.JPEG 92 94 | ILSVRC2012_val_00002100.JPEG 93 95 | ILSVRC2012_val_00002974.JPEG 94 96 | ILSVRC2012_val_00002326.JPEG 95 97 | ILSVRC2012_val_00000952.JPEG 96 98 | ILSVRC2012_val_00000415.JPEG 97 99 | ILSVRC2012_val_00002309.JPEG 98 100 | ILSVRC2012_val_00000547.JPEG 99 101 | ILSVRC2012_val_00000466.JPEG 100 102 | ILSVRC2012_val_00000067.JPEG 101 103 | ILSVRC2012_val_00000319.JPEG 102 104 | ILSVRC2012_val_00000189.JPEG 103 105 | ILSVRC2012_val_00000655.JPEG 104 106 | ILSVRC2012_val_00000970.JPEG 105 107 | ILSVRC2012_val_00001933.JPEG 106 108 | ILSVRC2012_val_00000063.JPEG 107 109 | ILSVRC2012_val_00000017.JPEG 108 110 | ILSVRC2012_val_00000011.JPEG 109 111 | ILSVRC2012_val_00003054.JPEG 110 112 | ILSVRC2012_val_00000836.JPEG 111 113 | ILSVRC2012_val_00000781.JPEG 112 114 | ILSVRC2012_val_00001268.JPEG 113 115 | ILSVRC2012_val_00000546.JPEG 114 116 | ILSVRC2012_val_00000154.JPEG 115 117 | ILSVRC2012_val_00000689.JPEG 116 118 | ILSVRC2012_val_00001356.JPEG 117 119 | ILSVRC2012_val_00000195.JPEG 118 120 | ILSVRC2012_val_00002684.JPEG 119 121 | ILSVRC2012_val_00000818.JPEG 120 122 | ILSVRC2012_val_00000238.JPEG 121 123 | ILSVRC2012_val_00000111.JPEG 122 124 | ILSVRC2012_val_00002152.JPEG 123 125 | ILSVRC2012_val_00001381.JPEG 124 126 | ILSVRC2012_val_00000231.JPEG 125 127 | ILSVRC2012_val_00000669.JPEG 126 128 | 
ILSVRC2012_val_00000690.JPEG 127 129 | ILSVRC2012_val_00001068.JPEG 128 130 | ILSVRC2012_val_00000043.JPEG 129 131 | ILSVRC2012_val_00000356.JPEG 130 132 | ILSVRC2012_val_00000794.JPEG 131 133 | ILSVRC2012_val_00001067.JPEG 132 134 | ILSVRC2012_val_00000363.JPEG 133 135 | ILSVRC2012_val_00000389.JPEG 134 136 | ILSVRC2012_val_00002126.JPEG 135 137 | ILSVRC2012_val_00003399.JPEG 136 138 | ILSVRC2012_val_00000825.JPEG 137 139 | ILSVRC2012_val_00001073.JPEG 138 140 | ILSVRC2012_val_00001800.JPEG 139 141 | ILSVRC2012_val_00000177.JPEG 140 142 | ILSVRC2012_val_00000470.JPEG 141 143 | ILSVRC2012_val_00000098.JPEG 142 144 | ILSVRC2012_val_00000548.JPEG 143 145 | ILSVRC2012_val_00003147.JPEG 144 146 | ILSVRC2012_val_00001056.JPEG 145 147 | ILSVRC2012_val_00000207.JPEG 146 148 | ILSVRC2012_val_00000016.JPEG 147 149 | ILSVRC2012_val_00000545.JPEG 148 150 | ILSVRC2012_val_00000072.JPEG 149 151 | ILSVRC2012_val_00000033.JPEG 150 152 | ILSVRC2012_val_00001049.JPEG 151 153 | ILSVRC2012_val_00000462.JPEG 152 154 | ILSVRC2012_val_00000457.JPEG 153 155 | ILSVRC2012_val_00001069.JPEG 154 156 | ILSVRC2012_val_00000907.JPEG 155 157 | ILSVRC2012_val_00000613.JPEG 156 158 | ILSVRC2012_val_00001389.JPEG 157 159 | ILSVRC2012_val_00000170.JPEG 158 160 | ILSVRC2012_val_00000077.JPEG 159 161 | ILSVRC2012_val_00000082.JPEG 160 162 | ILSVRC2012_val_00000838.JPEG 161 163 | ILSVRC2012_val_00000599.JPEG 162 164 | ILSVRC2012_val_00000141.JPEG 163 165 | ILSVRC2012_val_00001097.JPEG 164 166 | ILSVRC2012_val_00000327.JPEG 165 167 | ILSVRC2012_val_00000511.JPEG 166 168 | ILSVRC2012_val_00000028.JPEG 167 169 | ILSVRC2012_val_00001267.JPEG 168 170 | ILSVRC2012_val_00001626.JPEG 169 171 | ILSVRC2012_val_00001765.JPEG 170 172 | ILSVRC2012_val_00001825.JPEG 171 173 | ILSVRC2012_val_00001846.JPEG 172 174 | ILSVRC2012_val_00000022.JPEG 173 175 | ILSVRC2012_val_00002328.JPEG 174 176 | ILSVRC2012_val_00000079.JPEG 175 177 | ILSVRC2012_val_00000263.JPEG 176 178 | ILSVRC2012_val_00000405.JPEG 177 179 | ILSVRC2012_val_00001385.JPEG 178 180 | ILSVRC2012_val_00000248.JPEG 179 181 | ILSVRC2012_val_00000191.JPEG 180 182 | ILSVRC2012_val_00000883.JPEG 181 183 | ILSVRC2012_val_00000832.JPEG 182 184 | ILSVRC2012_val_00000062.JPEG 183 185 | ILSVRC2012_val_00000588.JPEG 184 186 | ILSVRC2012_val_00000391.JPEG 185 187 | ILSVRC2012_val_00000612.JPEG 186 188 | ILSVRC2012_val_00000696.JPEG 187 189 | ILSVRC2012_val_00001001.JPEG 188 190 | ILSVRC2012_val_00002194.JPEG 189 191 | ILSVRC2012_val_00000474.JPEG 190 192 | ILSVRC2012_val_00003002.JPEG 191 193 | ILSVRC2012_val_00000379.JPEG 192 194 | ILSVRC2012_val_00001564.JPEG 193 195 | ILSVRC2012_val_00000279.JPEG 194 196 | ILSVRC2012_val_00000350.JPEG 195 197 | ILSVRC2012_val_00000176.JPEG 196 198 | ILSVRC2012_val_00002573.JPEG 197 199 | ILSVRC2012_val_00000044.JPEG 198 200 | ILSVRC2012_val_00000411.JPEG 199 201 | ILSVRC2012_val_00000802.JPEG 200 202 | ILSVRC2012_val_00000105.JPEG 201 203 | ILSVRC2012_val_00000748.JPEG 202 204 | ILSVRC2012_val_00002784.JPEG 203 205 | ILSVRC2012_val_00001481.JPEG 204 206 | ILSVRC2012_val_00000824.JPEG 205 207 | ILSVRC2012_val_00001434.JPEG 206 208 | ILSVRC2012_val_00001112.JPEG 207 209 | ILSVRC2012_val_00000375.JPEG 208 210 | ILSVRC2012_val_00000668.JPEG 209 211 | ILSVRC2012_val_00001070.JPEG 210 212 | ILSVRC2012_val_00002323.JPEG 211 213 | ILSVRC2012_val_00000172.JPEG 212 214 | ILSVRC2012_val_00000161.JPEG 213 215 | ILSVRC2012_val_00000708.JPEG 214 216 | ILSVRC2012_val_00000227.JPEG 215 217 | ILSVRC2012_val_00000643.JPEG 216 218 | ILSVRC2012_val_00000665.JPEG 217 219 | 
ILSVRC2012_val_00002298.JPEG 218 220 | ILSVRC2012_val_00000524.JPEG 219 221 | ILSVRC2012_val_00000539.JPEG 220 222 | ILSVRC2012_val_00001693.JPEG 221 223 | ILSVRC2012_val_00000507.JPEG 222 224 | ILSVRC2012_val_00000533.JPEG 223 225 | ILSVRC2012_val_00001109.JPEG 224 226 | ILSVRC2012_val_00003176.JPEG 225 227 | ILSVRC2012_val_00004938.JPEG 226 228 | ILSVRC2012_val_00000113.JPEG 227 229 | ILSVRC2012_val_00000383.JPEG 228 230 | ILSVRC2012_val_00002247.JPEG 229 231 | ILSVRC2012_val_00000003.JPEG 230 232 | ILSVRC2012_val_00003480.JPEG 231 233 | ILSVRC2012_val_00000672.JPEG 232 234 | ILSVRC2012_val_00000203.JPEG 233 235 | ILSVRC2012_val_00003623.JPEG 234 236 | ILSVRC2012_val_00000946.JPEG 235 237 | ILSVRC2012_val_00000134.JPEG 236 238 | ILSVRC2012_val_00004386.JPEG 237 239 | ILSVRC2012_val_00000531.JPEG 238 240 | ILSVRC2012_val_00002708.JPEG 239 241 | ILSVRC2012_val_00000594.JPEG 240 242 | ILSVRC2012_val_00001005.JPEG 241 243 | ILSVRC2012_val_00000635.JPEG 242 244 | ILSVRC2012_val_00000975.JPEG 243 245 | ILSVRC2012_val_00004506.JPEG 244 246 | ILSVRC2012_val_00001313.JPEG 245 247 | ILSVRC2012_val_00001408.JPEG 246 248 | ILSVRC2012_val_00001937.JPEG 247 249 | ILSVRC2012_val_00000178.JPEG 248 250 | ILSVRC2012_val_00000729.JPEG 249 251 | ILSVRC2012_val_00000269.JPEG 250 252 | ILSVRC2012_val_00001702.JPEG 251 253 | ILSVRC2012_val_00002267.JPEG 252 254 | ILSVRC2012_val_00001344.JPEG 253 255 | ILSVRC2012_val_00000220.JPEG 254 256 | ILSVRC2012_val_00001203.JPEG 255 257 | ILSVRC2012_val_00000045.JPEG 256 258 | ILSVRC2012_val_00000174.JPEG 257 259 | ILSVRC2012_val_00000590.JPEG 258 260 | ILSVRC2012_val_00000057.JPEG 259 261 | ILSVRC2012_val_00000355.JPEG 260 262 | ILSVRC2012_val_00000658.JPEG 261 263 | ILSVRC2012_val_00001568.JPEG 262 264 | ILSVRC2012_val_00001096.JPEG 263 265 | ILSVRC2012_val_00000997.JPEG 264 266 | ILSVRC2012_val_00000147.JPEG 265 267 | ILSVRC2012_val_00000867.JPEG 266 268 | ILSVRC2012_val_00000399.JPEG 267 269 | ILSVRC2012_val_00000719.JPEG 268 270 | ILSVRC2012_val_00000530.JPEG 269 271 | ILSVRC2012_val_00000027.JPEG 270 272 | ILSVRC2012_val_00003172.JPEG 271 273 | ILSVRC2012_val_00000156.JPEG 272 274 | ILSVRC2012_val_00000604.JPEG 273 275 | ILSVRC2012_val_00001148.JPEG 274 276 | ILSVRC2012_val_00000078.JPEG 275 277 | ILSVRC2012_val_00000133.JPEG 276 278 | ILSVRC2012_val_00000157.JPEG 277 279 | ILSVRC2012_val_00000756.JPEG 278 280 | ILSVRC2012_val_00001335.JPEG 279 281 | ILSVRC2012_val_00000323.JPEG 280 282 | ILSVRC2012_val_00001341.JPEG 281 283 | ILSVRC2012_val_00000796.JPEG 282 284 | ILSVRC2012_val_00000130.JPEG 283 285 | ILSVRC2012_val_00000211.JPEG 284 286 | ILSVRC2012_val_00000709.JPEG 285 287 | ILSVRC2012_val_00000012.JPEG 286 288 | ILSVRC2012_val_00000898.JPEG 287 289 | ILSVRC2012_val_00000600.JPEG 288 290 | ILSVRC2012_val_00000186.JPEG 289 291 | ILSVRC2012_val_00000610.JPEG 290 292 | ILSVRC2012_val_00001359.JPEG 291 293 | ILSVRC2012_val_00000303.JPEG 292 294 | ILSVRC2012_val_00000348.JPEG 293 295 | ILSVRC2012_val_00000871.JPEG 294 296 | ILSVRC2012_val_00000592.JPEG 295 297 | ILSVRC2012_val_00001498.JPEG 296 298 | ILSVRC2012_val_00001745.JPEG 297 299 | ILSVRC2012_val_00000518.JPEG 298 300 | ILSVRC2012_val_00000779.JPEG 299 301 | ILSVRC2012_val_00000361.JPEG 300 302 | ILSVRC2012_val_00000625.JPEG 301 303 | ILSVRC2012_val_00000956.JPEG 302 304 | ILSVRC2012_val_00000926.JPEG 303 305 | ILSVRC2012_val_00000510.JPEG 304 306 | ILSVRC2012_val_00000342.JPEG 305 307 | ILSVRC2012_val_00000692.JPEG 306 308 | ILSVRC2012_val_00001570.JPEG 307 309 | ILSVRC2012_val_00001763.JPEG 308 310 | 
ILSVRC2012_val_00001876.JPEG 309 311 | ILSVRC2012_val_00000407.JPEG 310 312 | ILSVRC2012_val_00000125.JPEG 311 313 | ILSVRC2012_val_00001873.JPEG 312 314 | ILSVRC2012_val_00001585.JPEG 313 315 | ILSVRC2012_val_00001082.JPEG 314 316 | ILSVRC2012_val_00000772.JPEG 315 317 | ILSVRC2012_val_00002649.JPEG 316 318 | ILSVRC2012_val_00003032.JPEG 317 319 | ILSVRC2012_val_00001466.JPEG 318 320 | ILSVRC2012_val_00000418.JPEG 319 321 | ILSVRC2012_val_00000639.JPEG 320 322 | ILSVRC2012_val_00001187.JPEG 321 323 | ILSVRC2012_val_00000810.JPEG 322 324 | ILSVRC2012_val_00001290.JPEG 323 325 | ILSVRC2012_val_00000031.JPEG 324 326 | ILSVRC2012_val_00000581.JPEG 325 327 | ILSVRC2012_val_00001536.JPEG 326 328 | ILSVRC2012_val_00001161.JPEG 327 329 | ILSVRC2012_val_00000138.JPEG 328 330 | ILSVRC2012_val_00000301.JPEG 329 331 | ILSVRC2012_val_00000097.JPEG 330 332 | ILSVRC2012_val_00000644.JPEG 331 333 | ILSVRC2012_val_00000010.JPEG 332 334 | ILSVRC2012_val_00000617.JPEG 333 335 | ILSVRC2012_val_00000007.JPEG 334 336 | ILSVRC2012_val_00000502.JPEG 335 337 | ILSVRC2012_val_00000675.JPEG 336 338 | ILSVRC2012_val_00003202.JPEG 337 339 | ILSVRC2012_val_00000565.JPEG 338 340 | ILSVRC2012_val_00000179.JPEG 339 341 | ILSVRC2012_val_00002070.JPEG 340 342 | ILSVRC2012_val_00000478.JPEG 341 343 | ILSVRC2012_val_00002006.JPEG 342 344 | ILSVRC2012_val_00000136.JPEG 343 345 | ILSVRC2012_val_00001066.JPEG 344 346 | ILSVRC2012_val_00001280.JPEG 345 347 | ILSVRC2012_val_00000306.JPEG 346 348 | ILSVRC2012_val_00002433.JPEG 347 349 | ILSVRC2012_val_00001027.JPEG 348 350 | ILSVRC2012_val_00000100.JPEG 349 351 | ILSVRC2012_val_00000444.JPEG 350 352 | ILSVRC2012_val_00001232.JPEG 351 353 | ILSVRC2012_val_00000127.JPEG 352 354 | ILSVRC2012_val_00003125.JPEG 353 355 | ILSVRC2012_val_00002271.JPEG 354 356 | ILSVRC2012_val_00000343.JPEG 355 357 | ILSVRC2012_val_00000723.JPEG 356 358 | ILSVRC2012_val_00002028.JPEG 357 359 | ILSVRC2012_val_00000055.JPEG 358 360 | ILSVRC2012_val_00000326.JPEG 359 361 | ILSVRC2012_val_00001407.JPEG 360 362 | ILSVRC2012_val_00000181.JPEG 361 363 | ILSVRC2012_val_00000726.JPEG 362 364 | ILSVRC2012_val_00002448.JPEG 363 365 | ILSVRC2012_val_00001442.JPEG 364 366 | ILSVRC2012_val_00001535.JPEG 365 367 | ILSVRC2012_val_00000093.JPEG 366 368 | ILSVRC2012_val_00000553.JPEG 367 369 | ILSVRC2012_val_00003388.JPEG 368 370 | ILSVRC2012_val_00000087.JPEG 369 371 | ILSVRC2012_val_00000013.JPEG 370 372 | ILSVRC2012_val_00000964.JPEG 371 373 | ILSVRC2012_val_00000486.JPEG 372 374 | ILSVRC2012_val_00000095.JPEG 373 375 | ILSVRC2012_val_00001254.JPEG 374 376 | ILSVRC2012_val_00001095.JPEG 375 377 | ILSVRC2012_val_00000206.JPEG 376 378 | ILSVRC2012_val_00000520.JPEG 377 379 | ILSVRC2012_val_00002450.JPEG 378 380 | ILSVRC2012_val_00003118.JPEG 379 381 | ILSVRC2012_val_00000340.JPEG 380 382 | ILSVRC2012_val_00003812.JPEG 381 383 | ILSVRC2012_val_00002011.JPEG 382 384 | ILSVRC2012_val_00000092.JPEG 383 385 | ILSVRC2012_val_00001025.JPEG 384 386 | ILSVRC2012_val_00000597.JPEG 385 387 | ILSVRC2012_val_00001177.JPEG 386 388 | ILSVRC2012_val_00002863.JPEG 387 389 | ILSVRC2012_val_00000481.JPEG 388 390 | ILSVRC2012_val_00000312.JPEG 389 391 | ILSVRC2012_val_00000066.JPEG 390 392 | ILSVRC2012_val_00002364.JPEG 391 393 | ILSVRC2012_val_00000124.JPEG 392 394 | ILSVRC2012_val_00001176.JPEG 393 395 | ILSVRC2012_val_00000050.JPEG 394 396 | ILSVRC2012_val_00001975.JPEG 395 397 | ILSVRC2012_val_00000196.JPEG 396 398 | ILSVRC2012_val_00000699.JPEG 397 399 | ILSVRC2012_val_00000038.JPEG 398 400 | ILSVRC2012_val_00001285.JPEG 399 401 | 
ILSVRC2012_val_00000851.JPEG 400 402 | ILSVRC2012_val_00000204.JPEG 401 403 | ILSVRC2012_val_00001416.JPEG 402 404 | ILSVRC2012_val_00000289.JPEG 403 405 | ILSVRC2012_val_00000631.JPEG 404 406 | ILSVRC2012_val_00001602.JPEG 405 407 | ILSVRC2012_val_00000932.JPEG 406 408 | ILSVRC2012_val_00001910.JPEG 407 409 | ILSVRC2012_val_00000998.JPEG 408 410 | ILSVRC2012_val_00001011.JPEG 409 411 | ILSVRC2012_val_00002391.JPEG 410 412 | ILSVRC2012_val_00000771.JPEG 411 413 | ILSVRC2012_val_00000422.JPEG 412 414 | ILSVRC2012_val_00000268.JPEG 413 415 | ILSVRC2012_val_00000459.JPEG 414 416 | ILSVRC2012_val_00000008.JPEG 415 417 | ILSVRC2012_val_00000494.JPEG 416 418 | ILSVRC2012_val_00001731.JPEG 417 419 | ILSVRC2012_val_00000721.JPEG 418 420 | ILSVRC2012_val_00000400.JPEG 419 421 | ILSVRC2012_val_00000315.JPEG 420 422 | ILSVRC2012_val_00000272.JPEG 421 423 | ILSVRC2012_val_00004406.JPEG 422 424 | ILSVRC2012_val_00000759.JPEG 423 425 | ILSVRC2012_val_00000076.JPEG 424 426 | ILSVRC2012_val_00000234.JPEG 425 427 | ILSVRC2012_val_00000287.JPEG 426 428 | ILSVRC2012_val_00000687.JPEG 427 429 | ILSVRC2012_val_00002509.JPEG 428 430 | ILSVRC2012_val_00001468.JPEG 429 431 | ILSVRC2012_val_00001517.JPEG 430 432 | ILSVRC2012_val_00000354.JPEG 431 433 | ILSVRC2012_val_00001090.JPEG 432 434 | ILSVRC2012_val_00001013.JPEG 433 435 | ILSVRC2012_val_00001384.JPEG 434 436 | ILSVRC2012_val_00002563.JPEG 435 437 | ILSVRC2012_val_00000175.JPEG 436 438 | ILSVRC2012_val_00000483.JPEG 437 439 | ILSVRC2012_val_00000253.JPEG 438 440 | ILSVRC2012_val_00001048.JPEG 439 441 | ILSVRC2012_val_00000654.JPEG 440 442 | ILSVRC2012_val_00003782.JPEG 441 443 | ILSVRC2012_val_00001103.JPEG 442 444 | ILSVRC2012_val_00001252.JPEG 443 445 | ILSVRC2012_val_00000953.JPEG 444 446 | ILSVRC2012_val_00000239.JPEG 445 447 | ILSVRC2012_val_00003446.JPEG 446 448 | ILSVRC2012_val_00000917.JPEG 447 449 | ILSVRC2012_val_00000514.JPEG 448 450 | ILSVRC2012_val_00000193.JPEG 449 451 | ILSVRC2012_val_00000601.JPEG 450 452 | ILSVRC2012_val_00000763.JPEG 451 453 | ILSVRC2012_val_00000906.JPEG 452 454 | ILSVRC2012_val_00003274.JPEG 453 455 | ILSVRC2012_val_00000879.JPEG 454 456 | ILSVRC2012_val_00000634.JPEG 455 457 | ILSVRC2012_val_00001354.JPEG 456 458 | ILSVRC2012_val_00000126.JPEG 457 459 | ILSVRC2012_val_00001145.JPEG 458 460 | ILSVRC2012_val_00000808.JPEG 459 461 | ILSVRC2012_val_00000358.JPEG 460 462 | ILSVRC2012_val_00000080.JPEG 461 463 | ILSVRC2012_val_00002937.JPEG 462 464 | ILSVRC2012_val_00000226.JPEG 463 465 | ILSVRC2012_val_00000558.JPEG 464 466 | ILSVRC2012_val_00000366.JPEG 465 467 | ILSVRC2012_val_00000562.JPEG 466 468 | ILSVRC2012_val_00000071.JPEG 467 469 | ILSVRC2012_val_00000056.JPEG 468 470 | ILSVRC2012_val_00000302.JPEG 469 471 | ILSVRC2012_val_00002596.JPEG 470 472 | ILSVRC2012_val_00000392.JPEG 471 473 | ILSVRC2012_val_00000347.JPEG 472 474 | ILSVRC2012_val_00000101.JPEG 473 475 | ILSVRC2012_val_00000061.JPEG 474 476 | ILSVRC2012_val_00000213.JPEG 475 477 | ILSVRC2012_val_00000074.JPEG 476 478 | ILSVRC2012_val_00001041.JPEG 477 479 | ILSVRC2012_val_00000019.JPEG 478 480 | ILSVRC2012_val_00000085.JPEG 479 481 | ILSVRC2012_val_00000580.JPEG 480 482 | ILSVRC2012_val_00000353.JPEG 481 483 | ILSVRC2012_val_00000557.JPEG 482 484 | ILSVRC2012_val_00000122.JPEG 483 485 | ILSVRC2012_val_00000402.JPEG 484 486 | ILSVRC2012_val_00003491.JPEG 485 487 | ILSVRC2012_val_00000108.JPEG 486 488 | ILSVRC2012_val_00000089.JPEG 487 489 | ILSVRC2012_val_00000376.JPEG 488 490 | ILSVRC2012_val_00001094.JPEG 489 491 | ILSVRC2012_val_00002280.JPEG 490 492 | 
ILSVRC2012_val_00000537.JPEG 491 493 | ILSVRC2012_val_00001168.JPEG 492 494 | ILSVRC2012_val_00000162.JPEG 493 495 | ILSVRC2012_val_00001601.JPEG 494 496 | ILSVRC2012_val_00000346.JPEG 495 497 | ILSVRC2012_val_00001908.JPEG 496 498 | ILSVRC2012_val_00003351.JPEG 497 499 | ILSVRC2012_val_00000086.JPEG 498 500 | ILSVRC2012_val_00000564.JPEG 499 501 | ILSVRC2012_val_00001426.JPEG 500 502 | ILSVRC2012_val_00000750.JPEG 501 503 | ILSVRC2012_val_00001500.JPEG 502 504 | ILSVRC2012_val_00000344.JPEG 503 505 | ILSVRC2012_val_00000164.JPEG 504 506 | ILSVRC2012_val_00001940.JPEG 505 507 | ILSVRC2012_val_00000915.JPEG 506 508 | ILSVRC2012_val_00004971.JPEG 507 509 | ILSVRC2012_val_00000321.JPEG 508 510 | ILSVRC2012_val_00001190.JPEG 509 511 | ILSVRC2012_val_00000403.JPEG 510 512 | ILSVRC2012_val_00001201.JPEG 511 513 | ILSVRC2012_val_00000678.JPEG 512 514 | ILSVRC2012_val_00002171.JPEG 513 515 | ILSVRC2012_val_00000766.JPEG 514 516 | ILSVRC2012_val_00001875.JPEG 515 517 | ILSVRC2012_val_00000005.JPEG 516 518 | ILSVRC2012_val_00000020.JPEG 517 519 | ILSVRC2012_val_00001689.JPEG 518 520 | ILSVRC2012_val_00001411.JPEG 519 521 | ILSVRC2012_val_00003036.JPEG 520 522 | ILSVRC2012_val_00000534.JPEG 521 523 | ILSVRC2012_val_00000249.JPEG 522 524 | ILSVRC2012_val_00001337.JPEG 523 525 | ILSVRC2012_val_00000730.JPEG 524 526 | ILSVRC2012_val_00000330.JPEG 525 527 | ILSVRC2012_val_00000777.JPEG 526 528 | ILSVRC2012_val_00001246.JPEG 527 529 | ILSVRC2012_val_00000137.JPEG 528 530 | ILSVRC2012_val_00000493.JPEG 529 531 | ILSVRC2012_val_00003643.JPEG 530 532 | ILSVRC2012_val_00000339.JPEG 531 533 | ILSVRC2012_val_00000241.JPEG 532 534 | ILSVRC2012_val_00000256.JPEG 533 535 | ILSVRC2012_val_00000698.JPEG 534 536 | ILSVRC2012_val_00002489.JPEG 535 537 | ILSVRC2012_val_00001729.JPEG 536 538 | ILSVRC2012_val_00002266.JPEG 537 539 | ILSVRC2012_val_00001868.JPEG 538 540 | ILSVRC2012_val_00001893.JPEG 539 541 | ILSVRC2012_val_00000504.JPEG 540 542 | ILSVRC2012_val_00001609.JPEG 541 543 | ILSVRC2012_val_00000542.JPEG 542 544 | ILSVRC2012_val_00001045.JPEG 543 545 | ILSVRC2012_val_00000182.JPEG 544 546 | ILSVRC2012_val_00000942.JPEG 545 547 | ILSVRC2012_val_00000374.JPEG 546 548 | ILSVRC2012_val_00001699.JPEG 547 549 | ILSVRC2012_val_00000199.JPEG 548 550 | ILSVRC2012_val_00000419.JPEG 549 551 | ILSVRC2012_val_00001545.JPEG 550 552 | ILSVRC2012_val_00000624.JPEG 551 553 | ILSVRC2012_val_00001256.JPEG 552 554 | ILSVRC2012_val_00000132.JPEG 553 555 | ILSVRC2012_val_00000288.JPEG 554 556 | ILSVRC2012_val_00000764.JPEG 555 557 | ILSVRC2012_val_00000743.JPEG 556 558 | ILSVRC2012_val_00000445.JPEG 557 559 | ILSVRC2012_val_00000140.JPEG 558 560 | ILSVRC2012_val_00001616.JPEG 559 561 | ILSVRC2012_val_00002562.JPEG 560 562 | ILSVRC2012_val_00001579.JPEG 561 563 | ILSVRC2012_val_00001018.JPEG 562 564 | ILSVRC2012_val_00002799.JPEG 563 565 | ILSVRC2012_val_00002945.JPEG 564 566 | ILSVRC2012_val_00000047.JPEG 565 567 | ILSVRC2012_val_00000957.JPEG 566 568 | ILSVRC2012_val_00000516.JPEG 567 569 | ILSVRC2012_val_00000641.JPEG 568 570 | ILSVRC2012_val_00001144.JPEG 569 571 | ILSVRC2012_val_00002290.JPEG 570 572 | ILSVRC2012_val_00000732.JPEG 571 573 | ILSVRC2012_val_00000311.JPEG 572 574 | ILSVRC2012_val_00000032.JPEG 573 575 | ILSVRC2012_val_00002314.JPEG 574 576 | ILSVRC2012_val_00001116.JPEG 575 577 | ILSVRC2012_val_00000913.JPEG 576 578 | ILSVRC2012_val_00001052.JPEG 577 579 | ILSVRC2012_val_00000780.JPEG 578 580 | ILSVRC2012_val_00000662.JPEG 579 581 | ILSVRC2012_val_00000295.JPEG 580 582 | ILSVRC2012_val_00000852.JPEG 581 583 | 
ILSVRC2012_val_00000372.JPEG 582 584 | ILSVRC2012_val_00000169.JPEG 583 585 | ILSVRC2012_val_00000190.JPEG 584 586 | ILSVRC2012_val_00000856.JPEG 585 587 | ILSVRC2012_val_00000035.JPEG 586 588 | ILSVRC2012_val_00000887.JPEG 587 589 | ILSVRC2012_val_00000060.JPEG 588 590 | ILSVRC2012_val_00000876.JPEG 589 591 | ILSVRC2012_val_00000149.JPEG 590 592 | ILSVRC2012_val_00000054.JPEG 591 593 | ILSVRC2012_val_00000367.JPEG 592 594 | ILSVRC2012_val_00001030.JPEG 593 595 | ILSVRC2012_val_00000448.JPEG 594 596 | ILSVRC2012_val_00000015.JPEG 595 597 | ILSVRC2012_val_00002436.JPEG 596 598 | ILSVRC2012_val_00002822.JPEG 597 599 | ILSVRC2012_val_00000387.JPEG 598 600 | ILSVRC2012_val_00001784.JPEG 599 601 | ILSVRC2012_val_00000575.JPEG 600 602 | ILSVRC2012_val_00000776.JPEG 601 603 | ILSVRC2012_val_00001827.JPEG 602 604 | ILSVRC2012_val_00000752.JPEG 603 605 | ILSVRC2012_val_00001943.JPEG 604 606 | ILSVRC2012_val_00000647.JPEG 605 607 | ILSVRC2012_val_00000208.JPEG 606 608 | ILSVRC2012_val_00000308.JPEG 607 609 | ILSVRC2012_val_00000110.JPEG 608 610 | ILSVRC2012_val_00000659.JPEG 609 611 | ILSVRC2012_val_00000219.JPEG 610 612 | ILSVRC2012_val_00000607.JPEG 611 613 | ILSVRC2012_val_00000281.JPEG 612 614 | ILSVRC2012_val_00000969.JPEG 613 615 | ILSVRC2012_val_00000900.JPEG 614 616 | ILSVRC2012_val_00000686.JPEG 615 617 | ILSVRC2012_val_00002377.JPEG 616 618 | ILSVRC2012_val_00002627.JPEG 617 619 | ILSVRC2012_val_00000684.JPEG 618 620 | ILSVRC2012_val_00000427.JPEG 619 621 | ILSVRC2012_val_00000151.JPEG 620 622 | ILSVRC2012_val_00000384.JPEG 621 623 | ILSVRC2012_val_00000359.JPEG 622 624 | ILSVRC2012_val_00001977.JPEG 623 625 | ILSVRC2012_val_00000715.JPEG 624 626 | ILSVRC2012_val_00000556.JPEG 625 627 | ILSVRC2012_val_00003417.JPEG 626 628 | ILSVRC2012_val_00000185.JPEG 627 629 | ILSVRC2012_val_00000664.JPEG 628 630 | ILSVRC2012_val_00000984.JPEG 629 631 | ILSVRC2012_val_00000385.JPEG 630 632 | ILSVRC2012_val_00001814.JPEG 631 633 | ILSVRC2012_val_00000109.JPEG 632 634 | ILSVRC2012_val_00000999.JPEG 633 635 | ILSVRC2012_val_00000760.JPEG 634 636 | ILSVRC2012_val_00001229.JPEG 635 637 | ILSVRC2012_val_00000221.JPEG 636 638 | ILSVRC2012_val_00000152.JPEG 637 639 | ILSVRC2012_val_00000117.JPEG 638 640 | ILSVRC2012_val_00001420.JPEG 639 641 | ILSVRC2012_val_00002223.JPEG 640 642 | ILSVRC2012_val_00000710.JPEG 641 643 | ILSVRC2012_val_00001796.JPEG 642 644 | ILSVRC2012_val_00000608.JPEG 643 645 | ILSVRC2012_val_00000259.JPEG 644 646 | ILSVRC2012_val_00000120.JPEG 645 647 | ILSVRC2012_val_00000118.JPEG 646 648 | ILSVRC2012_val_00000163.JPEG 647 649 | ILSVRC2012_val_00001318.JPEG 648 650 | ILSVRC2012_val_00000382.JPEG 649 651 | ILSVRC2012_val_00001982.JPEG 650 652 | ILSVRC2012_val_00000519.JPEG 651 653 | ILSVRC2012_val_00003199.JPEG 652 654 | ILSVRC2012_val_00001170.JPEG 653 655 | ILSVRC2012_val_00000314.JPEG 654 656 | ILSVRC2012_val_00001192.JPEG 655 657 | ILSVRC2012_val_00001722.JPEG 656 658 | ILSVRC2012_val_00000700.JPEG 657 659 | ILSVRC2012_val_00001574.JPEG 658 660 | ILSVRC2012_val_00000173.JPEG 659 661 | ILSVRC2012_val_00000254.JPEG 660 662 | ILSVRC2012_val_00002010.JPEG 661 663 | ILSVRC2012_val_00000222.JPEG 662 664 | ILSVRC2012_val_00000670.JPEG 663 665 | ILSVRC2012_val_00000119.JPEG 664 666 | ILSVRC2012_val_00000717.JPEG 665 667 | ILSVRC2012_val_00000168.JPEG 666 668 | ILSVRC2012_val_00000923.JPEG 667 669 | ILSVRC2012_val_00000680.JPEG 668 670 | ILSVRC2012_val_00000859.JPEG 669 671 | ILSVRC2012_val_00001892.JPEG 670 672 | ILSVRC2012_val_00001659.JPEG 671 673 | ILSVRC2012_val_00002082.JPEG 672 674 | 
ILSVRC2012_val_00000803.JPEG 673 675 | ILSVRC2012_val_00000009.JPEG 674 676 | ILSVRC2012_val_00000365.JPEG 675 677 | ILSVRC2012_val_00000277.JPEG 676 678 | ILSVRC2012_val_00000649.JPEG 677 679 | ILSVRC2012_val_00000243.JPEG 678 680 | ILSVRC2012_val_00000850.JPEG 679 681 | ILSVRC2012_val_00000540.JPEG 680 682 | ILSVRC2012_val_00000677.JPEG 681 683 | ILSVRC2012_val_00002278.JPEG 682 684 | ILSVRC2012_val_00001464.JPEG 683 685 | ILSVRC2012_val_00000521.JPEG 684 686 | ILSVRC2012_val_00001380.JPEG 685 687 | ILSVRC2012_val_00000114.JPEG 686 688 | ILSVRC2012_val_00000166.JPEG 687 689 | ILSVRC2012_val_00001016.JPEG 688 690 | ILSVRC2012_val_00001961.JPEG 689 691 | ILSVRC2012_val_00001581.JPEG 690 692 | ILSVRC2012_val_00000549.JPEG 691 693 | ILSVRC2012_val_00002069.JPEG 692 694 | ILSVRC2012_val_00003004.JPEG 693 695 | ILSVRC2012_val_00000264.JPEG 694 696 | ILSVRC2012_val_00000265.JPEG 695 697 | ILSVRC2012_val_00000360.JPEG 696 698 | ILSVRC2012_val_00001424.JPEG 697 699 | ILSVRC2012_val_00003804.JPEG 698 700 | ILSVRC2012_val_00000681.JPEG 699 701 | ILSVRC2012_val_00001833.JPEG 700 702 | ILSVRC2012_val_00001440.JPEG 701 703 | ILSVRC2012_val_00000240.JPEG 702 704 | ILSVRC2012_val_00000192.JPEG 703 705 | ILSVRC2012_val_00000201.JPEG 704 706 | ILSVRC2012_val_00000096.JPEG 705 707 | ILSVRC2012_val_00001665.JPEG 706 708 | ILSVRC2012_val_00000335.JPEG 707 709 | ILSVRC2012_val_00002193.JPEG 708 710 | ILSVRC2012_val_00001265.JPEG 709 711 | ILSVRC2012_val_00004441.JPEG 710 712 | ILSVRC2012_val_00001754.JPEG 711 713 | ILSVRC2012_val_00000465.JPEG 712 714 | ILSVRC2012_val_00000489.JPEG 713 715 | ILSVRC2012_val_00001365.JPEG 714 716 | ILSVRC2012_val_00000834.JPEG 715 717 | ILSVRC2012_val_00000626.JPEG 716 718 | ILSVRC2012_val_00000049.JPEG 717 719 | ILSVRC2012_val_00000121.JPEG 718 720 | ILSVRC2012_val_00001023.JPEG 719 721 | ILSVRC2012_val_00000112.JPEG 720 722 | ILSVRC2012_val_00000438.JPEG 721 723 | ILSVRC2012_val_00000428.JPEG 722 724 | ILSVRC2012_val_00000991.JPEG 723 725 | ILSVRC2012_val_00000966.JPEG 724 726 | ILSVRC2012_val_00000046.JPEG 725 727 | ILSVRC2012_val_00000144.JPEG 726 728 | ILSVRC2012_val_00000024.JPEG 727 729 | ILSVRC2012_val_00000722.JPEG 728 730 | ILSVRC2012_val_00001853.JPEG 729 731 | ILSVRC2012_val_00000292.JPEG 730 732 | ILSVRC2012_val_00002729.JPEG 731 733 | ILSVRC2012_val_00000765.JPEG 732 734 | ILSVRC2012_val_00000334.JPEG 733 735 | ILSVRC2012_val_00000452.JPEG 734 736 | ILSVRC2012_val_00004159.JPEG 735 737 | ILSVRC2012_val_00002385.JPEG 736 738 | ILSVRC2012_val_00001970.JPEG 737 739 | ILSVRC2012_val_00001919.JPEG 738 740 | ILSVRC2012_val_00000458.JPEG 739 741 | ILSVRC2012_val_00000589.JPEG 740 742 | ILSVRC2012_val_00000230.JPEG 741 743 | ILSVRC2012_val_00002985.JPEG 742 744 | ILSVRC2012_val_00002554.JPEG 743 745 | ILSVRC2012_val_00000468.JPEG 744 746 | ILSVRC2012_val_00001253.JPEG 745 747 | ILSVRC2012_val_00000454.JPEG 746 748 | ILSVRC2012_val_00002782.JPEG 747 749 | ILSVRC2012_val_00001343.JPEG 748 750 | ILSVRC2012_val_00001345.JPEG 749 751 | ILSVRC2012_val_00000260.JPEG 750 752 | ILSVRC2012_val_00000135.JPEG 751 753 | ILSVRC2012_val_00001786.JPEG 752 754 | ILSVRC2012_val_00000770.JPEG 753 755 | ILSVRC2012_val_00001346.JPEG 754 756 | ILSVRC2012_val_00002046.JPEG 755 757 | ILSVRC2012_val_00000042.JPEG 756 758 | ILSVRC2012_val_00000014.JPEG 757 759 | ILSVRC2012_val_00000473.JPEG 758 760 | ILSVRC2012_val_00001085.JPEG 759 761 | ILSVRC2012_val_00000331.JPEG 760 762 | ILSVRC2012_val_00000275.JPEG 761 763 | ILSVRC2012_val_00000591.JPEG 762 764 | ILSVRC2012_val_00000158.JPEG 763 765 | 
ILSVRC2012_val_00000244.JPEG 764 766 | ILSVRC2012_val_00000421.JPEG 765 767 | ILSVRC2012_val_00001625.JPEG 766 768 | ILSVRC2012_val_00000453.JPEG 767 769 | ILSVRC2012_val_00001497.JPEG 768 770 | ILSVRC2012_val_00001210.JPEG 769 771 | ILSVRC2012_val_00001465.JPEG 770 772 | ILSVRC2012_val_00000143.JPEG 771 773 | ILSVRC2012_val_00002068.JPEG 772 774 | ILSVRC2012_val_00000792.JPEG 773 775 | ILSVRC2012_val_00002426.JPEG 774 776 | ILSVRC2012_val_00000442.JPEG 775 777 | ILSVRC2012_val_00000614.JPEG 776 778 | ILSVRC2012_val_00000039.JPEG 777 779 | ILSVRC2012_val_00001091.JPEG 778 780 | ILSVRC2012_val_00001736.JPEG 779 781 | ILSVRC2012_val_00000094.JPEG 780 782 | ILSVRC2012_val_00000167.JPEG 781 783 | ILSVRC2012_val_00001269.JPEG 782 784 | ILSVRC2012_val_00001639.JPEG 783 785 | ILSVRC2012_val_00000310.JPEG 784 786 | ILSVRC2012_val_00001792.JPEG 785 787 | ILSVRC2012_val_00000377.JPEG 786 788 | ILSVRC2012_val_00000232.JPEG 787 789 | ILSVRC2012_val_00000083.JPEG 788 790 | ILSVRC2012_val_00000159.JPEG 789 791 | ILSVRC2012_val_00000629.JPEG 790 792 | ILSVRC2012_val_00001316.JPEG 791 793 | ILSVRC2012_val_00002065.JPEG 792 794 | ILSVRC2012_val_00002394.JPEG 793 795 | ILSVRC2012_val_00000527.JPEG 794 796 | ILSVRC2012_val_00000568.JPEG 795 797 | ILSVRC2012_val_00000550.JPEG 796 798 | ILSVRC2012_val_00001395.JPEG 797 799 | ILSVRC2012_val_00000266.JPEG 798 800 | ILSVRC2012_val_00002286.JPEG 799 801 | ILSVRC2012_val_00000857.JPEG 800 802 | ILSVRC2012_val_00000508.JPEG 801 803 | ILSVRC2012_val_00000131.JPEG 802 804 | ILSVRC2012_val_00000767.JPEG 803 805 | ILSVRC2012_val_00003585.JPEG 804 806 | ILSVRC2012_val_00002349.JPEG 805 807 | ILSVRC2012_val_00000936.JPEG 806 808 | ILSVRC2012_val_00000897.JPEG 807 809 | ILSVRC2012_val_00002335.JPEG 808 810 | ILSVRC2012_val_00000004.JPEG 809 811 | ILSVRC2012_val_00001672.JPEG 810 812 | ILSVRC2012_val_00004382.JPEG 811 813 | ILSVRC2012_val_00000559.JPEG 812 814 | ILSVRC2012_val_00001587.JPEG 813 815 | ILSVRC2012_val_00000901.JPEG 814 816 | ILSVRC2012_val_00002128.JPEG 815 817 | ILSVRC2012_val_00000349.JPEG 816 818 | ILSVRC2012_val_00001247.JPEG 817 819 | ILSVRC2012_val_00000345.JPEG 818 820 | ILSVRC2012_val_00001388.JPEG 819 821 | ILSVRC2012_val_00000544.JPEG 820 822 | ILSVRC2012_val_00000271.JPEG 821 823 | ILSVRC2012_val_00000842.JPEG 822 824 | ILSVRC2012_val_00001183.JPEG 823 825 | ILSVRC2012_val_00001059.JPEG 824 826 | ILSVRC2012_val_00000171.JPEG 825 827 | ILSVRC2012_val_00000290.JPEG 826 828 | ILSVRC2012_val_00000205.JPEG 827 829 | ILSVRC2012_val_00001470.JPEG 828 830 | ILSVRC2012_val_00000714.JPEG 829 831 | ILSVRC2012_val_00000535.JPEG 830 832 | ILSVRC2012_val_00001811.JPEG 831 833 | ILSVRC2012_val_00000522.JPEG 832 834 | ILSVRC2012_val_00000595.JPEG 833 835 | ILSVRC2012_val_00000380.JPEG 834 836 | ILSVRC2012_val_00002015.JPEG 835 837 | ILSVRC2012_val_00000413.JPEG 836 838 | ILSVRC2012_val_00000815.JPEG 837 839 | ILSVRC2012_val_00001102.JPEG 838 840 | ILSVRC2012_val_00000778.JPEG 839 841 | ILSVRC2012_val_00000701.JPEG 840 842 | ILSVRC2012_val_00000070.JPEG 841 843 | ILSVRC2012_val_00000065.JPEG 842 844 | ILSVRC2012_val_00001086.JPEG 843 845 | ILSVRC2012_val_00000053.JPEG 844 846 | ILSVRC2012_val_00002160.JPEG 845 847 | ILSVRC2012_val_00000026.JPEG 846 848 | ILSVRC2012_val_00002120.JPEG 847 849 | ILSVRC2012_val_00000627.JPEG 848 850 | ILSVRC2012_val_00000464.JPEG 849 851 | ILSVRC2012_val_00000602.JPEG 850 852 | ILSVRC2012_val_00001532.JPEG 851 853 | ILSVRC2012_val_00000123.JPEG 852 854 | ILSVRC2012_val_00003312.JPEG 853 855 | ILSVRC2012_val_00002206.JPEG 854 856 | 
ILSVRC2012_val_00001263.JPEG 855 857 | ILSVRC2012_val_00000416.JPEG 856 858 | ILSVRC2012_val_00001160.JPEG 857 859 | ILSVRC2012_val_00000030.JPEG 858 860 | ILSVRC2012_val_00000737.JPEG 859 861 | ILSVRC2012_val_00000561.JPEG 860 862 | ILSVRC2012_val_00000449.JPEG 861 863 | ILSVRC2012_val_00001008.JPEG 862 864 | ILSVRC2012_val_00000525.JPEG 863 865 | ILSVRC2012_val_00000671.JPEG 864 866 | ILSVRC2012_val_00000261.JPEG 865 867 | ILSVRC2012_val_00000294.JPEG 866 868 | ILSVRC2012_val_00000187.JPEG 867 869 | ILSVRC2012_val_00000513.JPEG 868 870 | ILSVRC2012_val_00000291.JPEG 869 871 | ILSVRC2012_val_00000069.JPEG 870 872 | ILSVRC2012_val_00001519.JPEG 871 873 | ILSVRC2012_val_00000059.JPEG 872 874 | ILSVRC2012_val_00001325.JPEG 873 875 | ILSVRC2012_val_00000573.JPEG 874 876 | ILSVRC2012_val_00000146.JPEG 875 877 | ILSVRC2012_val_00000874.JPEG 876 878 | ILSVRC2012_val_00001750.JPEG 877 879 | ILSVRC2012_val_00000104.JPEG 878 880 | ILSVRC2012_val_00000381.JPEG 879 881 | ILSVRC2012_val_00002599.JPEG 880 882 | ILSVRC2012_val_00000284.JPEG 881 883 | ILSVRC2012_val_00000712.JPEG 882 884 | ILSVRC2012_val_00000895.JPEG 883 885 | ILSVRC2012_val_00001002.JPEG 884 886 | ILSVRC2012_val_00000433.JPEG 885 887 | ILSVRC2012_val_00001881.JPEG 886 888 | ILSVRC2012_val_00000036.JPEG 887 889 | ILSVRC2012_val_00000296.JPEG 888 890 | ILSVRC2012_val_00000212.JPEG 889 891 | ILSVRC2012_val_00001680.JPEG 890 892 | ILSVRC2012_val_00000797.JPEG 891 893 | ILSVRC2012_val_00004307.JPEG 892 894 | ILSVRC2012_val_00000866.JPEG 893 895 | ILSVRC2012_val_00000482.JPEG 894 896 | ILSVRC2012_val_00000369.JPEG 895 897 | ILSVRC2012_val_00000538.JPEG 896 898 | ILSVRC2012_val_00000280.JPEG 897 899 | ILSVRC2012_val_00000745.JPEG 898 900 | ILSVRC2012_val_00000286.JPEG 899 901 | ILSVRC2012_val_00000586.JPEG 900 902 | ILSVRC2012_val_00002867.JPEG 901 903 | ILSVRC2012_val_00000497.JPEG 902 904 | ILSVRC2012_val_00000854.JPEG 903 905 | ILSVRC2012_val_00000351.JPEG 904 906 | ILSVRC2012_val_00000587.JPEG 905 907 | ILSVRC2012_val_00000106.JPEG 906 908 | ILSVRC2012_val_00000805.JPEG 907 909 | ILSVRC2012_val_00001901.JPEG 908 910 | ILSVRC2012_val_00001328.JPEG 909 911 | ILSVRC2012_val_00002685.JPEG 910 912 | ILSVRC2012_val_00001208.JPEG 911 913 | ILSVRC2012_val_00000437.JPEG 912 914 | ILSVRC2012_val_00001034.JPEG 913 915 | ILSVRC2012_val_00003145.JPEG 914 916 | ILSVRC2012_val_00000460.JPEG 915 917 | ILSVRC2012_val_00000830.JPEG 916 918 | ILSVRC2012_val_00000985.JPEG 917 919 | ILSVRC2012_val_00002362.JPEG 918 920 | ILSVRC2012_val_00000903.JPEG 919 921 | ILSVRC2012_val_00000274.JPEG 920 922 | ILSVRC2012_val_00000980.JPEG 921 923 | ILSVRC2012_val_00000209.JPEG 922 924 | ILSVRC2012_val_00001611.JPEG 923 925 | ILSVRC2012_val_00000728.JPEG 924 926 | ILSVRC2012_val_00000267.JPEG 925 927 | ILSVRC2012_val_00001512.JPEG 926 928 | ILSVRC2012_val_00000620.JPEG 927 929 | ILSVRC2012_val_00000574.JPEG 928 930 | ILSVRC2012_val_00002619.JPEG 929 931 | ILSVRC2012_val_00000820.JPEG 930 932 | ILSVRC2012_val_00001405.JPEG 931 933 | ILSVRC2012_val_00000891.JPEG 932 934 | ILSVRC2012_val_00002226.JPEG 933 935 | ILSVRC2012_val_00000129.JPEG 934 936 | ILSVRC2012_val_00000183.JPEG 935 937 | ILSVRC2012_val_00001826.JPEG 936 938 | ILSVRC2012_val_00000155.JPEG 937 939 | ILSVRC2012_val_00000332.JPEG 938 940 | ILSVRC2012_val_00000554.JPEG 939 941 | ILSVRC2012_val_00000461.JPEG 940 942 | ILSVRC2012_val_00001932.JPEG 941 943 | ILSVRC2012_val_00000674.JPEG 942 944 | ILSVRC2012_val_00000463.JPEG 943 945 | ILSVRC2012_val_00000443.JPEG 944 946 | ILSVRC2012_val_00000423.JPEG 945 947 | 
ILSVRC2012_val_00000881.JPEG 946 948 | ILSVRC2012_val_00002490.JPEG 947 949 | ILSVRC2012_val_00000023.JPEG 948 950 | ILSVRC2012_val_00000099.JPEG 949 951 | ILSVRC2012_val_00001406.JPEG 950 952 | ILSVRC2012_val_00000711.JPEG 951 953 | ILSVRC2012_val_00002124.JPEG 952 954 | ILSVRC2012_val_00001479.JPEG 953 955 | ILSVRC2012_val_00000605.JPEG 954 956 | ILSVRC2012_val_00000691.JPEG 955 957 | ILSVRC2012_val_00000798.JPEG 956 958 | ILSVRC2012_val_00001514.JPEG 957 959 | ILSVRC2012_val_00000816.JPEG 958 960 | ILSVRC2012_val_00000116.JPEG 959 961 | ILSVRC2012_val_00000782.JPEG 960 962 | ILSVRC2012_val_00000543.JPEG 961 963 | ILSVRC2012_val_00001115.JPEG 962 964 | ILSVRC2012_val_00000773.JPEG 963 965 | ILSVRC2012_val_00000491.JPEG 964 966 | ILSVRC2012_val_00000572.JPEG 965 967 | ILSVRC2012_val_00001064.JPEG 966 968 | ILSVRC2012_val_00002934.JPEG 967 969 | ILSVRC2012_val_00001726.JPEG 968 970 | ILSVRC2012_val_00000139.JPEG 969 971 | ILSVRC2012_val_00000002.JPEG 970 972 | ILSVRC2012_val_00000697.JPEG 971 973 | ILSVRC2012_val_00001444.JPEG 972 974 | ILSVRC2012_val_00000235.JPEG 973 975 | ILSVRC2012_val_00000394.JPEG 974 976 | ILSVRC2012_val_00000150.JPEG 975 977 | ILSVRC2012_val_00002476.JPEG 976 978 | ILSVRC2012_val_00000145.JPEG 977 979 | ILSVRC2012_val_00000214.JPEG 978 980 | ILSVRC2012_val_00000313.JPEG 979 981 | ILSVRC2012_val_00002103.JPEG 980 982 | ILSVRC2012_val_00000034.JPEG 981 983 | ILSVRC2012_val_00000341.JPEG 982 984 | ILSVRC2012_val_00000255.JPEG 983 985 | ILSVRC2012_val_00000216.JPEG 984 986 | ILSVRC2012_val_00004514.JPEG 985 987 | ILSVRC2012_val_00000551.JPEG 986 988 | ILSVRC2012_val_00000484.JPEG 987 989 | ILSVRC2012_val_00001058.JPEG 988 990 | ILSVRC2012_val_00000251.JPEG 989 991 | ILSVRC2012_val_00001120.JPEG 990 992 | ILSVRC2012_val_00000304.JPEG 991 993 | ILSVRC2012_val_00000731.JPEG 992 994 | ILSVRC2012_val_00000200.JPEG 993 995 | ILSVRC2012_val_00000058.JPEG 994 996 | ILSVRC2012_val_00000755.JPEG 995 997 | ILSVRC2012_val_00002167.JPEG 996 998 | ILSVRC2012_val_00001362.JPEG 997 999 | ILSVRC2012_val_00000408.JPEG 998 1000 | ILSVRC2012_val_00001079.JPEG 999 1001 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import torch.optim 8 | import torchvision.utils as vutils 9 | from torch.utils.data import DataLoader 10 | 11 | import models 12 | import utils 13 | from dataset import ImageDataset 14 | 15 | 16 | class Trainer(object): 17 | 18 | def __init__(self, config): 19 | self.rank, self.world_size = 0, 1 20 | if config['dist']: 21 | self.rank = dist.get_rank() 22 | self.world_size = dist.get_world_size() 23 | 24 | self.mode = config['dgp_mode'] 25 | assert self.mode in [ 26 | 'reconstruct', 'colorization', 'SR', 'hybrid', 'inpainting', 27 | 'morphing', 'defence', 'jitter' 28 | ] 29 | 30 | if self.rank == 0: 31 | # mkdir path 32 | if not os.path.exists('{}/images'.format(config['exp_path'])): 33 | os.makedirs('{}/images'.format(config['exp_path'])) 34 | if not os.path.exists('{}/images_sheet'.format( 35 | config['exp_path'])): 36 | os.makedirs('{}/images_sheet'.format(config['exp_path'])) 37 | if not os.path.exists('{}/logs'.format(config['exp_path'])): 38 | os.makedirs('{}/logs'.format(config['exp_path'])) 39 | 40 | # prepare logger 41 | if not config['no_tb']: 42 | try: 43 | from tensorboardX import SummaryWriter 44 | except 
ImportError: 45 | raise Exception("Please switch off \"tensorboard\" " 46 | "in your config file if you do not " 47 | "want to use it, otherwise install it.") 48 | self.tb_logger = SummaryWriter('{}'.format(config['exp_path'])) 49 | else: 50 | self.tb_logger = None 51 | 52 | self.logger = utils.create_logger( 53 | 'global_logger', 54 | '{}/logs/log_train.txt'.format(config['exp_path'])) 55 | 56 | self.model = models.DGP(config) 57 | if self.mode == 'morphing': 58 | self.model2 = models.DGP(config) 59 | self.model_interp = models.DGP(config) 60 | 61 | # Data loader 62 | train_dataset = ImageDataset( 63 | config['root_dir'], 64 | config['list_file'], 65 | image_size=config['resolution'], 66 | normalize=True) 67 | sampler = utils.DistributedSampler( 68 | train_dataset) if config['dist'] else None 69 | self.train_loader = DataLoader( 70 | train_dataset, 71 | batch_size=1, 72 | shuffle=False, 73 | sampler=sampler, 74 | num_workers=1, 75 | pin_memory=False) 76 | self.config = config 77 | 78 | def run(self): 79 | # train 80 | if self.mode == 'morphing': 81 | self.train_morphing() 82 | else: 83 | self.train() 84 | 85 | def train(self): 86 | btime_rec = utils.AverageMeter() 87 | dtime_rec = utils.AverageMeter() 88 | recorder = {} 89 | end = time.time() 90 | for i, (image, category, img_path) in enumerate(self.train_loader): 91 | # measure data loading time 92 | dtime_rec.update(time.time() - end) 93 | 94 | torch.cuda.empty_cache() 95 | 96 | image = image.cuda() 97 | category = category.cuda() 98 | img_path = img_path[0] 99 | 100 | self.model.reset_G() 101 | self.model.set_target(image, category, img_path) 102 | # when the category is unknown (category=-1), it will be selected from samples 103 | self.model.select_z(select_y=True if category.item() < 0 else False) 104 | loss_dict = self.model.run(save_interval=self.config['save_interval']) 105 | 106 | # average loss if distributed 107 | if self.config['dist']: 108 | for k, v in loss_dict.items(): 109 | reduced = v.data.clone() / self.world_size 110 | dist.all_reduce_multigpu([reduced]) 111 | loss_dict[k] = reduced 112 | 113 | if len(recorder) == 0: 114 | for k in loss_dict.keys(): 115 | recorder[k] = utils.AverageMeter() 116 | for k in loss_dict.keys(): 117 | recorder[k].update(loss_dict[k].item()) 118 | 119 | btime_rec.update(time.time() - end) 120 | end = time.time() 121 | 122 | # logging 123 | loss_str = "" 124 | if self.rank == 0: 125 | for k in recorder.keys(): 126 | if self.tb_logger is not None: 127 | self.tb_logger.add_scalar('train_{}'.format(k), 128 | recorder[k].avg, i + 1) 129 | loss_str += '{}: {loss.val:.4g} ({loss.avg:.4g}) '.format( 130 | k, loss=recorder[k]) 131 | 132 | self.logger.info( 133 | 'Iter: [{0}/{1}] '.format(i + 1, len(self.train_loader)) + 134 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '. 
135 | format(batch_time=btime_rec) + 136 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) '.format( 137 | data_time=dtime_rec) + 'Image {} '.format(img_path) + 138 | loss_str) 139 | 140 | def train_morphing(self): 141 | btime_rec = utils.AverageMeter() 142 | dtime_rec = utils.AverageMeter() 143 | recorder = {} 144 | last_category = -1 145 | end = time.time() 146 | for i, (image, category, img_path) in enumerate(self.train_loader): 147 | # measure data loading time 148 | dtime_rec.update(time.time() - end) 149 | 150 | assert image.shape[0] > 0 151 | image = image.cuda() 152 | category = category.cuda() 153 | img_path = img_path[0] 154 | 155 | self.model.reset_G() 156 | self.model.set_target(image, category, img_path) 157 | self.model.select_z() 158 | loss_dict = self.model.run( 159 | save_interval=self.config['save_interval']) 160 | 161 | # apply image morphing within the same category 162 | if category == last_category: 163 | self.morphing() 164 | torch.cuda.empty_cache() 165 | 166 | with torch.no_grad(): 167 | self.model2.G.load_state_dict(self.model.G.state_dict()) 168 | self.model2.z.copy_(self.model.z) 169 | self.model2.img_name = self.model.img_name 170 | self.model2.target = self.model.target 171 | self.model2.category = self.model.category 172 | 173 | if category == last_category: 174 | # average loss if distributed 175 | if self.config['dist']: 176 | for k, v in loss_dict.items(): 177 | reduced = v.data.clone() / self.world_size 178 | dist.all_reduce_multigpu([reduced]) 179 | loss_dict[k] = reduced 180 | 181 | if len(recorder) < len(loss_dict): 182 | for k in loss_dict.keys(): 183 | recorder[k] = utils.AverageMeter() 184 | for k in loss_dict.keys(): 185 | recorder[k].update(loss_dict[k].item()) 186 | 187 | btime_rec.update(time.time() - end) 188 | end = time.time() 189 | 190 | # logging 191 | loss_str = "" 192 | if self.rank == 0: 193 | for k in recorder.keys(): 194 | if self.tb_logger is not None: 195 | self.tb_logger.add_scalar('train_{}'.format(k), 196 | recorder[k].avg, i + 1) 197 | loss_str += '{}: {loss.val:.4g} ({loss.avg:.4g}) '.format( 198 | k, loss=recorder[k]) 199 | 200 | self.logger.info( 201 | 'Iter: [{0}/{1}] '.format(i, len(self.train_loader)) + 202 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '. 203 | format(batch_time=btime_rec) + 204 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) '. 
205 | format(data_time=dtime_rec) + 206 | 'Image {} '.format(img_path) + loss_str) 207 | 208 | last_category = category 209 | 210 | def morphing(self): 211 | weight1 = self.model.G.state_dict() 212 | weight2 = self.model2.G.state_dict() 213 | weight_interp = OrderedDict() 214 | imgs = [] 215 | with torch.no_grad(): 216 | for i in range(11): 217 | alpha = i / 10 218 | # interpolate between both latent vector and generator weight 219 | z_interp = alpha * self.model.z + (1 - alpha) * self.model2.z 220 | for k, w1 in weight1.items(): 221 | w2 = weight2[k] 222 | weight_interp[k] = alpha * w1 + (1 - alpha) * w2 223 | self.model_interp.G.load_state_dict(weight_interp) 224 | x_interp = self.model_interp.G( 225 | z_interp, self.model_interp.G.shared(self.model.y)) 226 | imgs.append(x_interp.cpu()) 227 | # save image 228 | save_path = '%s/images/%s_%s' % (self.config['exp_path'], 229 | self.model.img_name, 230 | self.model2.img_name) 231 | if not os.path.exists(save_path): 232 | os.makedirs(save_path) 233 | utils.save_img(x_interp[0], '%s/%03d.jpg' % (save_path, i + 1)) 234 | imgs = torch.cat(imgs, 0) 235 | vutils.save_image( 236 | imgs, 237 | '%s/images_sheet/morphing_class%d_%s_%s.jpg' % 238 | (self.config['exp_path'], self.model.category, self.model.img_name, 239 | self.model2.img_name), 240 | nrow=int(imgs.size(0)**0.5), 241 | normalize=True) 242 | del weight_interp 243 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .biggan_utils import * 2 | from .common_utils import * 3 | from .distributed_utils import * 4 | from .losses import * 5 | -------------------------------------------------------------------------------- /utils/biggan_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' Utilities file 4 | This file contains utility functions for bookkeeping, logging, and data loading. 5 | Methods which directly affect training should either go in layers, the model, 6 | or train_fns.py. 7 | ''' 8 | 9 | from __future__ import print_function 10 | 11 | import math 12 | import os 13 | from argparse import ArgumentParser 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | import torchvision.transforms as transforms 19 | from torch.optim.optimizer import Optimizer 20 | 21 | 22 | def prepare_parser(): 23 | usage = 'Parser for all scripts.' 24 | parser = ArgumentParser(description=usage) 25 | 26 | ### Pipeline stuff ### 27 | parser.add_argument( 28 | '--eval_mode', action='store_true', default=False, 29 | help='Evaluation mode? 
(do not save logs) ' 30 | ' (default: %(default)s)') 31 | 32 | ### Model stuff ### 33 | parser.add_argument( 34 | '--model', type=str, default='BigGAN', 35 | help='Name of the model module (default: %(default)s)') 36 | parser.add_argument( 37 | '--G_param', type=str, default='SN', 38 | help='Parameterization style to use for G, spectral norm (SN) or SVD (SVD)' 39 | ' or None (default: %(default)s)') 40 | parser.add_argument( 41 | '--D_param', type=str, default='SN', 42 | help='Parameterization style to use for D, spectral norm (SN) or SVD (SVD)' 43 | ' or None (default: %(default)s)') 44 | parser.add_argument( 45 | '--G_ch', type=int, default=64, 46 | help='Channel multiplier for G (default: %(default)s)') 47 | parser.add_argument( 48 | '--D_ch', type=int, default=64, 49 | help='Channel multiplier for D (default: %(default)s)') 50 | parser.add_argument( 51 | '--G_depth', type=int, default=1, 52 | help='Number of resblocks per stage in G? (default: %(default)s)') 53 | parser.add_argument( 54 | '--D_depth', type=int, default=1, 55 | help='Number of resblocks per stage in D? (default: %(default)s)') 56 | parser.add_argument( 57 | '--D_thin', action='store_false', dest='D_wide', default=True, 58 | help='Use the SN-GAN channel pattern for D? (default: %(default)s)') 59 | parser.add_argument( 60 | '--G_shared', action='store_true', default=False, 61 | help='Use shared embeddings in G? (default: %(default)s)') 62 | parser.add_argument( 63 | '--shared_dim', type=int, default=0, 64 | help='G''s shared embedding dimensionality; if 0, will be equal to dim_z. ' 65 | '(default: %(default)s)') 66 | parser.add_argument( 67 | '--dim_z', type=int, default=128, 68 | help='Noise dimensionality: %(default)s)') 69 | parser.add_argument( 70 | '--hier', action='store_true', default=False, 71 | help='Use hierarchical z in G? (default: %(default)s)') 72 | parser.add_argument( 73 | '--n_classes', type=int, default=1000, 74 | help='Number of class conditions %(default)s)') 75 | parser.add_argument( 76 | '--cross_replica', action='store_true', default=False, 77 | help='Cross_replica batchnorm in G?(default: %(default)s)') 78 | parser.add_argument( 79 | '--mybn', action='store_true', default=False, 80 | help='Use my batchnorm (which supports standing stats?) %(default)s)') 81 | parser.add_argument( 82 | '--G_nl', type=str, default='inplace_relu', 83 | help='Activation function for G (default: %(default)s)') 84 | parser.add_argument( 85 | '--D_nl', type=str, default='inplace_relu', 86 | help='Activation function for D (default: %(default)s)') 87 | parser.add_argument( 88 | '--G_attn', type=str, default='64', 89 | help='What resolutions to use attention on for G (underscore separated) ' 90 | '(default: %(default)s)') 91 | parser.add_argument( 92 | '--D_attn', type=str, default='64', 93 | help='What resolutions to use attention on for D (underscore separated) ' 94 | '(default: %(default)s)') 95 | parser.add_argument( 96 | '--norm_style', type=str, default='bn', 97 | help='Normalizer style for G, one of bn [batchnorm], in [instancenorm], ' 98 | 'ln [layernorm], gn [groupnorm] (default: %(default)s)') 99 | 100 | ### Model init stuff ### 101 | parser.add_argument( 102 | '--seed', type=int, default=0, 103 | help='Random seed to use; affects both initialization and ' 104 | ' dataloading. 
(default: %(default)s)') 105 | parser.add_argument( 106 | '--G_init', type=str, default='ortho', 107 | help='Init style to use for G (default: %(default)s)') 108 | parser.add_argument( 109 | '--D_init', type=str, default='ortho', 110 | help='Init style to use for D(default: %(default)s)') 111 | parser.add_argument( 112 | '--skip_init', action='store_true', default=False, 113 | help='Skip initialization, ideal for testing when ortho init was used ' 114 | '(default: %(default)s)') 115 | 116 | ### Optimizer stuff ### 117 | parser.add_argument( 118 | '--optimizer', type=str, default='Adam', 119 | help='Optimizer, Adam or SGD (default: %(default)s)') 120 | parser.add_argument( 121 | '--G_lr', type=float, default=5e-5, 122 | help='Learning rate to use for Generator (default: %(default)s)') 123 | parser.add_argument( 124 | '--D_lr', type=float, default=2e-4, 125 | help='Learning rate to use for Discriminator (default: %(default)s)') 126 | parser.add_argument( 127 | '--Z_lr_mult', type=float, default=50, 128 | help='Learning rate multiplication to use for Z (default: %(default)s)') 129 | parser.add_argument( 130 | '--G_B1', type=float, default=0.0, 131 | help='Beta1 to use for Generator (default: %(default)s)') 132 | parser.add_argument( 133 | '--D_B1', type=float, default=0.0, 134 | help='Beta1 to use for Discriminator (default: %(default)s)') 135 | parser.add_argument( 136 | '--G_B2', type=float, default=0.999, 137 | help='Beta2 to use for Generator (default: %(default)s)') 138 | parser.add_argument( 139 | '--D_B2', type=float, default=0.999, 140 | help='Beta2 to use for Discriminator (default: %(default)s)') 141 | 142 | ### Batch size, parallel, and precision stuff ### 143 | parser.add_argument( 144 | '--G_fp16', action='store_true', default=False, 145 | help='Train with half-precision in G? (default: %(default)s)') 146 | parser.add_argument( 147 | '--D_fp16', action='store_true', default=False, 148 | help='Train with half-precision in D? (default: %(default)s)') 149 | parser.add_argument( 150 | '--D_mixed_precision', action='store_true', default=False, 151 | help='Train with half-precision activations but fp32 params in D? ' 152 | '(default: %(default)s)') 153 | parser.add_argument( 154 | '--G_mixed_precision', action='store_true', default=False, 155 | help='Train with half-precision activations but fp32 params in G? ' 156 | '(default: %(default)s)') 157 | parser.add_argument( 158 | '--accumulate_stats', action='store_true', default=False, 159 | help='Accumulate "standing" batchnorm stats? (default: %(default)s)') 160 | parser.add_argument( 161 | '--num_standing_accumulations', type=int, default=16, 162 | help='Number of forward passes to use in accumulating standing stats? ' 163 | '(default: %(default)s)') 164 | 165 | ### Bookkeping stuff ### 166 | parser.add_argument( 167 | '--weights_root', type=str, default='weights', 168 | help='Default location to store weights (default: %(default)s)') 169 | 170 | ### EMA Stuff ### 171 | parser.add_argument( 172 | '--use_ema', action='store_true', default=False, 173 | help='Use the EMA parameters of G for evaluation? 
(default: %(default)s)') 174 | 175 | ### Numerical precision and SV stuff ### 176 | parser.add_argument( 177 | '--adam_eps', type=float, default=1e-6, 178 | help='epsilon value to use for Adam (default: %(default)s)') 179 | parser.add_argument( 180 | '--BN_eps', type=float, default=1e-5, 181 | help='epsilon value to use for BatchNorm (default: %(default)s)') 182 | parser.add_argument( 183 | '--SN_eps', type=float, default=1e-6, 184 | help='epsilon value to use for Spectral Norm(default: %(default)s)') 185 | parser.add_argument( 186 | '--num_G_SVs', type=int, default=1, 187 | help='Number of SVs to track in G (default: %(default)s)') 188 | parser.add_argument( 189 | '--num_D_SVs', type=int, default=1, 190 | help='Number of SVs to track in D (default: %(default)s)') 191 | parser.add_argument( 192 | '--num_G_SV_itrs', type=int, default=1, 193 | help='Number of SV itrs in G (default: %(default)s)') 194 | parser.add_argument( 195 | '--num_D_SV_itrs', type=int, default=1, 196 | help='Number of SV itrs in D (default: %(default)s)') 197 | 198 | ### Resume training stuff 199 | parser.add_argument( 200 | '--load_weights', type=str, default='', 201 | help='Suffix for which weights to load (e.g. best0, copy0) ' 202 | '(default: %(default)s)') 203 | 204 | ### Log stuff ### 205 | parser.add_argument( 206 | '--no_tb', action='store_true', default=False, 207 | help='Do not use tensorboard? ' 208 | '(default: %(default)s)') 209 | return parser 210 | 211 | 212 | activation_dict = {'inplace_relu': nn.ReLU(inplace=True), 213 | 'relu': nn.ReLU(inplace=False), 214 | 'ir': nn.ReLU(inplace=True)} 215 | 216 | 217 | def dgp_update_config(config): 218 | config['G_activation'] = activation_dict[config['G_nl']] 219 | config['D_activation'] = activation_dict[config['D_nl']] 220 | 221 | 222 | class CenterCropLongEdge(object): 223 | """Crops the given PIL Image on the long edge. 224 | Args: 225 | size (sequence or int): Desired output size of the crop. If size is an 226 | int instead of sequence like (h, w), a square crop (size, size) is 227 | made. 228 | """ 229 | 230 | def __call__(self, img): 231 | """ 232 | Args: 233 | img (PIL Image): Image to be cropped. 234 | Returns: 235 | PIL Image: Cropped image. 236 | """ 237 | return transforms.functional.center_crop(img, min(img.size)) 238 | 239 | def __repr__(self): 240 | return self.__class__.__name__ 241 | 242 | 243 | class RandomCropLongEdge(object): 244 | """Crops the given PIL Image on the long edge with a random start point. 245 | Args: 246 | size (sequence or int): Desired output size of the crop. If size is an 247 | int instead of sequence like (h, w), a square crop (size, size) is 248 | made. 249 | """ 250 | 251 | def __call__(self, img): 252 | """ 253 | Args: 254 | img (PIL Image): Image to be cropped. 255 | Returns: 256 | PIL Image: Cropped image. 
257 | """ 258 | size = (min(img.size), min(img.size)) 259 | # Only step forward along this edge if it's the long edge 260 | i = (0 if size[0] == img.size[0] else np.random.randint( 261 | low=0, high=img.size[0] - size[0])) 262 | j = (0 if size[1] == img.size[1] else np.random.randint( 263 | low=0, high=img.size[1] - size[1])) 264 | return transforms.functional.crop(img, i, j, size[0], size[1]) 265 | 266 | def __repr__(self): 267 | return self.__class__.__name__ 268 | 269 | 270 | # Utility file to seed rngs 271 | def seed_rng(seed): 272 | torch.manual_seed(seed) 273 | torch.cuda.manual_seed(seed) 274 | np.random.seed(seed) 275 | 276 | 277 | # Apply modified ortho reg to a model 278 | # This function is an optimized version that directly computes the gradient, 279 | # instead of computing and then differentiating the loss. 280 | def ortho(model, strength=1e-4, blacklist=[], invconvonly=False): 281 | with torch.no_grad(): 282 | for name, param in model.named_parameters(): 283 | if invconvonly: 284 | if 'weight' not in name or 'NN' in name: 285 | continue 286 | # Only apply this to parameters with at least 2 axes, and not in the blacklist 287 | if len(param.shape) < 2 or any([param is item for item in blacklist]): 288 | continue 289 | w = param.view(param.shape[0], -1) 290 | grad = (2 * torch.mm( 291 | torch.mm(w, w.t()) * 292 | (1. - torch.eye(w.shape[0], device=w.device)), w)) 293 | param.grad.data += strength * grad.view(param.shape) 294 | 295 | 296 | # Default ortho reg 297 | # This function is an optimized version that directly computes the gradient, 298 | # instead of computing and then differentiating the loss. 299 | def default_ortho(model, strength=1e-4, blacklist=[]): 300 | with torch.no_grad(): 301 | for param in model.parameters(): 302 | # Only apply this to parameters with at least 2 axes & not in blacklist 303 | if len(param.shape) < 2 or param in blacklist: 304 | continue 305 | w = param.view(param.shape[0], -1) 306 | grad = (2 * torch.mm( 307 | torch.mm(w, w.t()) - torch.eye(w.shape[0], device=w.device), 308 | w)) 309 | param.grad.data += strength * grad.view(param.shape) 310 | 311 | 312 | # Convenience utility to switch off requires_grad 313 | def toggle_grad(model, on_or_off): 314 | for param in model.parameters(): 315 | param.requires_grad = on_or_off 316 | 317 | 318 | # Function to join strings or ignore them 319 | # Base string is the string to link "strings," while strings 320 | # is a list of strings or Nones. 321 | def join_strings(base_string, strings): 322 | return base_string.join([item for item in strings if item]) 323 | 324 | 325 | # Load a model's weights 326 | def load_weights(G, 327 | D, 328 | weights_root, 329 | name_suffix=None, 330 | G_ema=None, 331 | strict=False): 332 | def map_func(storage, location): 333 | return storage.cuda() 334 | 335 | if name_suffix: 336 | print('Loading %s weights from %s...' % (name_suffix, weights_root)) 337 | else: 338 | print('Loading weights from %s...' 
% weights_root) 339 | if G is not None: 340 | G.load_state_dict( 341 | torch.load( 342 | '%s/%s.pth' % 343 | (weights_root, join_strings('_', ['G', name_suffix])), 344 | map_location=map_func), 345 | strict=strict) 346 | if D is not None: 347 | D.load_state_dict( 348 | torch.load( 349 | '%s/%s.pth' % 350 | (weights_root, join_strings('_', ['D', name_suffix])), 351 | map_location=map_func), 352 | strict=strict) 353 | if G_ema is not None: 354 | print('Loading ema generator...') 355 | G_ema.load_state_dict( 356 | torch.load( 357 | '%s/%s.pth' % 358 | (weights_root, join_strings('_', ['G_ema', name_suffix])), 359 | map_location=map_func), 360 | strict=strict) 361 | 362 | 363 | # Get singular values to log. This will use the state dict to find them 364 | # and substitute underscores for dots. 365 | def get_SVs(net, prefix): 366 | d = net.state_dict() 367 | return {('%s_%s' % (prefix, key)).replace('.', '_'): float(d[key].item()) 368 | for key in d if 'sv' in key} 369 | 370 | 371 | # Name an experiment based on its config 372 | def name_from_config(config): 373 | name = '_'.join([ 374 | item for item in [ 375 | 'Big%s' % config['which_train_fn'], 376 | config['dataset'], 377 | config['model'] if config['model'] != 'BigGAN' else None, 378 | 'seed%d' % config['seed'], 379 | 'Gch%d' % config['G_ch'], 380 | 'Dch%d' % config['D_ch'], 381 | 'Gd%d' % config['G_depth'] if config['G_depth'] > 1 else None, 382 | 'Dd%d' % config['D_depth'] if config['D_depth'] > 1 else None, 383 | 'bs%d' % config['batch_size'], 384 | 'Gfp16' if config['G_fp16'] else None, 385 | 'Dfp16' if config['D_fp16'] else None, 386 | 'nDs%d' % config['num_D_steps'] if config['num_D_steps'] > 1 else None, 387 | 'nDa%d' % config['num_D_accumulations'] if config['num_D_accumulations'] > 1 else None, 388 | 'nGa%d' % config['num_G_accumulations'] if config['num_G_accumulations'] > 1 else None, 389 | 'Glr%2.1e' % config['G_lr'], 390 | 'Dlr%2.1e' % config['D_lr'], 391 | 'GB%3.3f' % config['G_B1'] if config['G_B1'] != 0.0 else None, 392 | 'GBB%3.3f' % config['G_B2'] if config['G_B2'] != 0.999 else None, 393 | 'DB%3.3f' % config['D_B1'] if config['D_B1'] != 0.0 else None, 394 | 'DBB%3.3f' % config['D_B2'] if config['D_B2'] != 0.999 else None, 395 | 'Gnl%s' % config['G_nl'], 396 | 'Dnl%s' % config['D_nl'], 397 | 'Ginit%s' % config['G_init'], 398 | 'Dinit%s' % config['D_init'], 399 | 'G%s' % config['G_param'] if config['G_param'] != 'SN' else None, 400 | 'D%s' % config['D_param'] if config['D_param'] != 'SN' else None, 401 | 'Gattn%s' % config['G_attn'] if config['G_attn'] != '0' else None, 402 | 'Dattn%s' % config['D_attn'] if config['D_attn'] != '0' else None, 403 | 'Gortho%2.1e' % config['G_ortho'] if config['G_ortho'] > 0.0 else None, 404 | 'Dortho%2.1e' % config['D_ortho'] if config['D_ortho'] > 0.0 else None, 405 | config['norm_style'] if config['norm_style'] != 'bn' else None, 406 | 'cr' if config['cross_replica'] else None, 407 | 'Gshared' if config['G_shared'] else None, 408 | 'hier' if config['hier'] else None, 409 | 'ema' if config['ema'] else None, 410 | 'Glow' if config['Glow'] else None, 411 | 'nlevels%d' % config['n_levels'] if config['n_levels'] else None, 412 | 'depth%d' % config['depth'] if config['depth'] else None, 413 | config['name_suffix'] if config['name_suffix'] else None, 414 | ] if item is not None 415 | ]) 416 | return name 417 | 418 | 419 | # Get GPU memory, -i is the index 420 | def query_gpu(indices): 421 | os.system('nvidia-smi -i 0 --query-gpu=memory.free --format=csv') 422 | 423 | 424 | # Convenience function 
to count the number of parameters in a module 425 | def count_parameters(module): 426 | print('Number of parameters: {}'.format( 427 | sum([p.data.nelement() for p in module.parameters()]))) 428 | 429 | 430 | # Convenience function to sample an index, not actually a 1-hot 431 | def sample_1hot(batch_size, num_classes, device='cuda'): 432 | return torch.randint( 433 | low=0, 434 | high=num_classes, 435 | size=(batch_size, ), 436 | device=device, 437 | dtype=torch.int64, 438 | requires_grad=False) 439 | 440 | 441 | # A highly simplified convenience class for sampling from distributions 442 | # One could also use PyTorch's inbuilt distributions package. 443 | # Note that this class requires initialization to proceed as 444 | # x = Distribution(torch.randn(size)) 445 | # x.init_distribution(dist_type, **dist_kwargs) 446 | # x = x.to(device,dtype) 447 | # This is partially based on https://discuss.pytorch.org/t/subclassing-torch-tensor/23754/2 448 | class Distribution(torch.Tensor): 449 | # Init the params of the distribution 450 | def init_distribution(self, dist_type, **kwargs): 451 | self.dist_type = dist_type 452 | self.dist_kwargs = kwargs 453 | if self.dist_type == 'normal': 454 | self.mean, self.var = kwargs['mean'], kwargs['var'] 455 | elif self.dist_type == 'categorical': 456 | self.num_categories = kwargs['num_categories'] 457 | 458 | def sample_(self): 459 | if self.dist_type == 'normal': 460 | self.normal_(self.mean, self.var) 461 | elif self.dist_type == 'categorical': 462 | self.random_(0, self.num_categories) 463 | 464 | # update the mean and covariance for multi-variate distribution 465 | def update(self, mean, cov): 466 | self.mean, self.cov = mean, cov 467 | 468 | # Silly hack: overwrite the to() method to wrap the new object 469 | # in a distribution as well 470 | def to(self, *args, **kwargs): 471 | new_obj = Distribution(self) 472 | new_obj.init_distribution(self.dist_type, **self.dist_kwargs) 473 | new_obj.data = super().to(*args, **kwargs) 474 | return new_obj 475 | 476 | 477 | # Convenience function to prepare a z and y vector 478 | def prepare_z_y(G_batch_size, 479 | dim_z, 480 | nclasses, 481 | device='cuda', 482 | fp16=False, 483 | z_var=1.0): 484 | z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False)) 485 | z_.init_distribution('normal', mean=0, var=z_var) 486 | z_ = z_.to(device, torch.float16 if fp16 else torch.float32) 487 | 488 | if fp16: 489 | z_ = z_.half() 490 | 491 | y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False)) 492 | y_.init_distribution('categorical', num_categories=nclasses) 493 | y_ = y_.to(device, torch.int64) 494 | return z_, y_ 495 | 496 | 497 | def initiate_standing_stats(net): 498 | for module in net.modules(): 499 | if hasattr(module, 'accumulate_standing'): 500 | module.reset_stats() 501 | module.accumulate_standing = True 502 | 503 | 504 | def accumulate_standing_stats(net, z, y, nclasses, num_accumulations=16): 505 | initiate_standing_stats(net) 506 | net.train() 507 | for i in range(num_accumulations): 508 | with torch.no_grad(): 509 | z.sample_() 510 | y.random_(0, nclasses) 511 | x = net(z, net.shared(y)) # No need to parallelize here unless using syncbn 512 | # Set to eval mode 513 | net.eval() 514 | 515 | 516 | # This version of Adam keeps an fp32 copy of the parameters and 517 | # does all of the parameter updates in fp32, while still doing the 518 | # forwards and backwards passes using fp16 (i.e. fp16 copies of the 519 | # parameters and fp16 activations). 
520 | # 521 | # Note that this calls .float().cuda() on the params. 522 | 523 | 524 | class Adam16(Optimizer): 525 | 526 | def __init__(self, 527 | params, 528 | lr=1e-3, 529 | betas=(0.9, 0.999), 530 | eps=1e-8, 531 | weight_decay=0): 532 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 533 | params = list(params) 534 | super(Adam16, self).__init__(params, defaults) 535 | 536 | # Safety modification to make sure we floatify our state 537 | def load_state_dict(self, state_dict): 538 | super(Adam16, self).load_state_dict(state_dict) 539 | for group in self.param_groups: 540 | for p in group['params']: 541 | self.state[p]['exp_avg'] = self.state[p]['exp_avg'].float() 542 | self.state[p]['exp_avg_sq'] = self.state[p][ 543 | 'exp_avg_sq'].float() 544 | self.state[p]['fp32_p'] = self.state[p]['fp32_p'].float() 545 | 546 | def step(self, closure=None): 547 | """Performs a single optimization step. 548 | Arguments: 549 | closure (callable, optional): A closure that reevaluates the model 550 | and returns the loss. 551 | """ 552 | loss = None 553 | if closure is not None: 554 | loss = closure() 555 | 556 | for group in self.param_groups: 557 | for p in group['params']: 558 | if p.grad is None: 559 | continue 560 | 561 | grad = p.grad.data.float() 562 | state = self.state[p] 563 | 564 | # State initialization 565 | if len(state) == 0: 566 | state['step'] = 0 567 | # Exponential moving average of gradient values 568 | state['exp_avg'] = grad.new().resize_as_(grad).zero_() 569 | # Exponential moving average of squared gradient values 570 | state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_() 571 | # Fp32 copy of the weights 572 | state['fp32_p'] = p.data.float() 573 | 574 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 575 | beta1, beta2 = group['betas'] 576 | 577 | state['step'] += 1 578 | 579 | if group['weight_decay'] != 0: 580 | grad = grad.add(group['weight_decay'], state['fp32_p']) 581 | 582 | # Decay the first and second moment running average coefficient 583 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 584 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 585 | 586 | denom = exp_avg_sq.sqrt().add_(group['eps']) 587 | 588 | bias_correction1 = 1 - beta1**state['step'] 589 | bias_correction2 = 1 - beta2**state['step'] 590 | step_size = group['lr'] * math.sqrt( 591 | bias_correction2) / bias_correction1 592 | 593 | state['fp32_p'].addcdiv_(-step_size, exp_avg, denom) 594 | p.data = state['fp32_p'].half() 595 | 596 | return loss 597 | -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as transforms 7 | from PIL import Image 8 | 9 | import dataset 10 | 11 | from .biggan_utils import CenterCropLongEdge 12 | 13 | 14 | # Arguments for DGP 15 | def add_dgp_parser(parser): 16 | parser.add_argument( 17 | '--dist', action='store_true', default=False, 18 | help='Train with distributed implementation (default: %(default)s)') 19 | parser.add_argument( 20 | '--port', type=str, default='12345', 21 | help='Port id for distributed training (default: %(default)s)') 22 | parser.add_argument( 23 | '--exp_path', type=str, default='', 24 | help='Experiment path (default: %(default)s)') 25 | parser.add_argument( 26 | '--root_dir', type=str, default='', 27 | help='Root path of dataset (default: %(default)s)') 28 | 
parser.add_argument( 29 | '--list_file', type=str, default='', 30 | help='List file of the dataset (default: %(default)s)') 31 | parser.add_argument( 32 | '--resolution', type=int, default=256, 33 | help='Resolution to resize the input image (default: %(default)s)') 34 | parser.add_argument( 35 | '--dgp_mode', type=str, default='reconstruct', 36 | help='DGP mode (default: %(default)s)') 37 | parser.add_argument( 38 | '--random_G', action='store_true', default=False, 39 | help='Use randomly initialized generator? (default: %(default)s)') 40 | parser.add_argument( 41 | '--update_G', action='store_true', default=False, 42 | help='Finetune Generator? (default: %(default)s)') 43 | parser.add_argument( 44 | '--update_embed', action='store_true', default=False, 45 | help='Finetune class embedding? (default: %(default)s)') 46 | parser.add_argument( 47 | '--save_G', action='store_true', default=False, 48 | help='Save fine-tuned generator and latent vector? (default: %(default)s)') 49 | parser.add_argument( 50 | '--ftr_type', type=str, default='Discriminator', 51 | choices=['Discriminator', 'VGG'], 52 | help='Feature loss type, choose from Discriminator and VGG (default: %(default)s)') 53 | parser.add_argument( 54 | '--ftr_num', type=int, default=[3], nargs='+', 55 | help='Number of features used to compute the feature loss (default: %(default)s)') 56 | parser.add_argument( 57 | '--ft_num', type=int, default=[2], nargs='+', 58 | help='Number of parameter groups to finetune (default: %(default)s)') 59 | parser.add_argument( 60 | '--print_interval', type=int, default=100, nargs='+', 61 | help='Interval (in iterations) for printing the training loss (default: %(default)s)') 62 | parser.add_argument( 63 | '--save_interval', type=int, default=None, nargs='+', 64 | help='Interval (in iterations) for saving images') 65 | parser.add_argument( 66 | '--lr_ratio', type=float, default=[1.0, 1.0, 1.0, 1.0], nargs='+', 67 | help='Decay ratio for the learning rate of different blocks (default: %(default)s)') 68 | parser.add_argument( 69 | '--w_D_loss', type=float, default=[0.1], nargs='+', 70 | help='Discriminator feature loss weight (default: %(default)s)') 71 | parser.add_argument( 72 | '--w_nll', type=float, default=0.001, 73 | help='Weight for the negative log-likelihood loss (default: %(default)s)') 74 | parser.add_argument( 75 | '--w_mse', type=float, default=[0.1], nargs='+', 76 | help='MSE loss weight (default: %(default)s)') 77 | parser.add_argument( 78 | '--select_num', type=int, default=500, 79 | help='Size of the image pool to select from (default: %(default)s)') 80 | parser.add_argument( 81 | '--sample_std', type=float, default=1.0, 82 | help='Std of the Gaussian distribution used for sampling (default: %(default)s)') 83 | parser.add_argument( 84 | '--iterations', type=int, default=[200, 200, 200, 200], nargs='+', 85 | help='Training iterations for each stage') 86 | parser.add_argument( 87 | '--G_lrs', type=float, default=[1e-6, 2e-5, 1e-5, 1e-6], nargs='+', 88 | help='Learning rates of the generator for each stage') 89 | parser.add_argument( 90 | '--z_lrs', type=float, default=[1e-1, 1e-3, 1e-5, 1e-6], nargs='+', 91 | help='Learning rates of the latent code z for each stage') 92 | parser.add_argument( 93 | '--warm_up', type=int, default=0, 94 | help='Number of warmup iterations (default: %(default)s)') 95 | parser.add_argument( 96 | '--use_in', type=str2bool, default=[False, False, False, False], nargs='+', 97 | help='Whether to use instance normalization in the generator') 98 | parser.add_argument( 99 | '--stop_mse', type=float, default=0.0, 100 | help='MSE threshold
for stopping training (default: %(default)s)') 101 | parser.add_argument( 102 | '--stop_ftr', type=float, default=0.0, 103 | help='Feature loss threshold for stopping training (default: %(default)s)') 104 | return parser 105 | 106 | 107 | def str2bool(v): 108 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 109 | return True 110 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 111 | return False 112 | else: 113 | raise argparse.ArgumentTypeError('Boolean value expected.') 114 | 115 | 116 | def get_img(img_path, resolution): 117 | img = dataset.default_loader(img_path) 118 | norm_mean = [0.5, 0.5, 0.5] 119 | norm_std = [0.5, 0.5, 0.5] 120 | transform = transforms.Compose([ 121 | CenterCropLongEdge(), 122 | transforms.Resize(resolution), 123 | transforms.ToTensor(), 124 | transforms.Normalize(norm_mean, norm_std) 125 | ]) 126 | img = transform(img) 127 | return img.unsqueeze(0) 128 | 129 | 130 | def save_img(image, path): 131 | image = np.uint8(255 * (image.cpu().detach().numpy() + 1) / 2.) 132 | image = np.transpose(image, (1, 2, 0)) 133 | image = Image.fromarray(image) 134 | image.save(path) 135 | 136 | 137 | class AverageMeter(object): 138 | """Computes and stores the average and current value""" 139 | 140 | def __init__(self): 141 | self.reset() 142 | 143 | def reset(self): 144 | self.val = 0 145 | self.avg = 0 146 | self.sum = 0 147 | self.count = 0 148 | 149 | def update(self, val, n=1): 150 | self.val = val 151 | self.sum += val * n 152 | self.count += n 153 | self.avg = self.sum / self.count 154 | 155 | 156 | def size_splits(tensor, split_sizes, dim=0): 157 | """Splits the tensor according to chunks of split_sizes. 158 | 159 | Arguments: 160 | tensor (Tensor): tensor to split. 161 | split_sizes (list(int)): sizes of chunks 162 | dim (int): dimension along which to split the tensor. 163 | """ 164 | if dim < 0: 165 | dim += tensor.dim() 166 | 167 | dim_size = tensor.size(dim) 168 | if dim_size != torch.sum(torch.Tensor(split_sizes)): 169 | raise KeyError("Sum of split sizes exceeds tensor dim") 170 | 171 | splits = torch.cumsum(torch.Tensor([0] + split_sizes), dim=0)[:-1] 172 | 173 | return tuple( 174 | tensor.narrow(int(dim), int(start), int(length)) 175 | for start, length in zip(splits, split_sizes)) 176 | 177 | 178 | class LRScheduler(object): 179 | 180 | def __init__(self, optimizer, warm_up): 181 | super(LRScheduler, self).__init__() 182 | self.optimizer = optimizer 183 | self.warm_up = warm_up 184 | 185 | def update(self, iteration, learning_rate, num_group=1000, ratio=1): 186 | if iteration < self.warm_up: 187 | learning_rate *= iteration / self.warm_up 188 | for i, param_group in enumerate(self.optimizer.param_groups): 189 | if i >= num_group: 190 | param_group['lr'] = 0 191 | else: 192 | param_group['lr'] = learning_rate * ratio**i 193 | 194 | 195 | def create_logger(name, log_file, level=logging.INFO): 196 | l = logging.getLogger(name) 197 | formatter = logging.Formatter('[%(asctime)s] %(message)s') 198 | fh = logging.FileHandler(log_file) 199 | fh.setFormatter(formatter) 200 | sh = logging.StreamHandler() 201 | sh.setFormatter(formatter) 202 | l.setLevel(level) 203 | l.addHandler(fh) 204 | l.addHandler(sh) 205 | return l 206 | 207 | 208 | def map_func(storage, location): 209 | return storage.cuda() 210 | 211 | 212 | def pil_to_np(img_PIL): 213 | '''Converts image in PIL format to np.array. 
214 | 215 | From W x H x C [0...255] to C x W x H [0..1] 216 | ''' 217 | ar = np.array(img_PIL) 218 | 219 | if len(ar.shape) == 3: 220 | ar = ar.transpose(2, 0, 1) 221 | else: 222 | ar = ar[None, ...] 223 | 224 | return ar.astype(np.float32) / 255. 225 | 226 | 227 | def np_to_pil(img_np): 228 | '''Converts image in np.array format to PIL image. 229 | 230 | From C x W x H [0..1] to W x H x C [0...255] 231 | ''' 232 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) 233 | 234 | if img_np.shape[0] == 1: 235 | ar = ar[0] 236 | else: 237 | ar = ar.transpose(1, 2, 0) 238 | 239 | return Image.fromarray(ar) 240 | 241 | 242 | def np_to_torch(img_np): 243 | '''Converts image in numpy.array to torch.Tensor. 244 | 245 | From C x W x H [0..1] to C x W x H [0..1] 246 | ''' 247 | return torch.from_numpy(img_np)[None, :] 248 | 249 | 250 | def torch_to_np(img_var): 251 | '''Converts an image in torch.Tensor format to np.array. 252 | 253 | From 1 x C x W x H [0..1] to C x W x H [0..1] 254 | ''' 255 | return img_var.detach().cpu().numpy()[0] 256 | 257 | 258 | def bicubic_torch(img, size): 259 | img_pil = np_to_pil(torch_to_np(img)) 260 | img_bicubic_pil = img_pil.resize(size, Image.BICUBIC) 261 | img_bicubic_pth = np_to_torch(pil_to_np(img_bicubic_pil)) 262 | return img_bicubic_pth 263 | -------------------------------------------------------------------------------- /utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import multiprocessing as mp 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from torch.nn import Module 8 | from torch.utils.data import Sampler 9 | 10 | 11 | class DistModule(Module): 12 | 13 | def __init__(self, module): 14 | super(DistModule, self).__init__() 15 | self.module = module 16 | broadcast_params(self.module) 17 | 18 | def forward(self, *inputs, **kwargs): 19 | return self.module(*inputs, **kwargs) 20 | 21 | def train(self, mode=True): 22 | super(DistModule, self).train(mode) 23 | self.module.train(mode) 24 | 25 | 26 | def average_gradients(model): 27 | """ average gradients """ 28 | for param in model.parameters(): 29 | if param.requires_grad and param.grad is not None: 30 | dist.all_reduce(param.grad.data) 31 | 32 | 33 | def broadcast_params(model): 34 | """ broadcast model parameters """ 35 | for p in model.state_dict().values(): 36 | dist.broadcast(p, 0) 37 | 38 | 39 | def average_params(model): 40 | """ broadcast model parameters """ 41 | worldsize = dist.get_world_size() 42 | for p in model.state_dict().values(): 43 | dist.all_reduce(p) 44 | p /= worldsize 45 | 46 | 47 | def dist_init(port): 48 | if mp.get_start_method(allow_none=True) != 'spawn': 49 | mp.set_start_method('spawn') 50 | proc_id = int(os.environ['SLURM_PROCID']) 51 | ntasks = int(os.environ['SLURM_NTASKS']) 52 | node_list = os.environ['SLURM_NODELIST'] 53 | num_gpus = torch.cuda.device_count() 54 | torch.cuda.set_device(proc_id % num_gpus) 55 | 56 | if '[' in node_list: 57 | beg = node_list.find('[') 58 | pos1 = node_list.find('-', beg) 59 | if pos1 < 0: 60 | pos1 = 1000 61 | pos2 = node_list.find(',', beg) 62 | if pos2 < 0: 63 | pos2 = 1000 64 | node_list = node_list[:min(pos1, pos2)].replace('[', '') 65 | addr = node_list[8:].replace('-', '.') 66 | print(addr) 67 | 68 | os.environ['MASTER_PORT'] = port 69 | os.environ['MASTER_ADDR'] = addr 70 | os.environ['WORLD_SIZE'] = str(ntasks) 71 | os.environ['RANK'] = str(proc_id) 72 | dist.init_process_group(backend='nccl') 73 | 74 | rank = dist.get_rank() 75 
| world_size = dist.get_world_size() 76 | return rank, world_size 77 | 78 | 79 | class DistributedSampler(Sampler): 80 | """Sampler that restricts data loading to a subset of the dataset. 81 | 82 | It is especially useful in conjunction with 83 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 84 | process can pass a DistributedSampler instance as a DataLoader sampler, 85 | and load a subset of the original dataset that is exclusive to it. 86 | 87 | .. note:: 88 | Dataset is assumed to be of constant size. 89 | 90 | Arguments: 91 | dataset: Dataset used for sampling. 92 | num_replicas (optional): Number of processes participating in 93 | distributed training. 94 | rank (optional): Rank of the current process within num_replicas. 95 | """ 96 | 97 | def __init__(self, dataset, num_replicas=None, rank=None): 98 | if num_replicas is None: 99 | if not dist.is_available(): 100 | raise RuntimeError( 101 | "Requires distributed package to be available") 102 | num_replicas = dist.get_world_size() 103 | if rank is None: 104 | if not dist.is_available(): 105 | raise RuntimeError( 106 | "Requires distributed package to be available") 107 | rank = dist.get_rank() 108 | self.dataset = dataset 109 | self.num_replicas = num_replicas 110 | self.rank = rank 111 | self.epoch = 0 112 | self.num_samples = int( 113 | math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 114 | self.total_size = self.num_samples * self.num_replicas 115 | 116 | def __iter__(self): 117 | # deterministically shuffle based on epoch 118 | indices = [i for i in range(len(self.dataset))] 119 | 120 | # add extra samples to make it evenly divisible 121 | indices += indices[:(self.total_size - len(indices))] 122 | assert len(indices) == self.total_size 123 | 124 | # subsample 125 | indices = indices[self.rank * self.num_samples:(self.rank + 1) * 126 | self.num_samples] 127 | assert len(indices) == self.num_samples 128 | 129 | return iter(indices) 130 | 131 | def __len__(self): 132 | return self.num_samples 133 | 134 | def set_epoch(self, epoch): 135 | self.epoch = epoch 136 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class PerceptLoss(object): 7 | 8 | def __init__(self): 9 | pass 10 | 11 | def __call__(self, LossNet, fake_img, real_img): 12 | with torch.no_grad(): 13 | real_feature = LossNet(real_img.detach()) 14 | fake_feature = LossNet(fake_img) 15 | perceptual_penalty = F.mse_loss(fake_feature, real_feature) 16 | return perceptual_penalty 17 | 18 | def set_ftr_num(self, ftr_num): 19 | pass 20 | 21 | 22 | class DiscriminatorLoss(object): 23 | 24 | def __init__(self, ftr_num=4, data_parallel=False): 25 | self.data_parallel = data_parallel 26 | self.ftr_num = ftr_num 27 | 28 | def __call__(self, D, fake_img, real_img): 29 | if self.data_parallel: 30 | with torch.no_grad(): 31 | d, real_feature = nn.parallel.data_parallel( 32 | D, real_img.detach()) 33 | d, fake_feature = nn.parallel.data_parallel(D, fake_img) 34 | else: 35 | with torch.no_grad(): 36 | d, real_feature = D(real_img.detach()) 37 | d, fake_feature = D(fake_img) 38 | D_penalty = 0 39 | for i in range(self.ftr_num): 40 | f_id = -i - 1 41 | D_penalty = D_penalty + F.l1_loss(fake_feature[f_id], 42 | real_feature[f_id]) 43 | return D_penalty 44 | 45 | def set_ftr_num(self, ftr_num): 46 | self.ftr_num = ftr_num 47 | 
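# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for exposition; NOT part of the original
# utils/losses.py). It shows the calling convention DiscriminatorLoss assumes:
# the discriminator returns a tuple (logits, list_of_feature_maps), and the
# loss is an L1 feature-matching penalty over the last `ftr_num` feature maps.
# `ToyDiscriminator` is a made-up stand-in, not BigGAN's discriminator.
if __name__ == '__main__':

    class ToyDiscriminator(nn.Module):

        def __init__(self):
            super(ToyDiscriminator, self).__init__()
            self.conv1 = nn.Conv2d(3, 8, 3, stride=2, padding=1)
            self.conv2 = nn.Conv2d(8, 16, 3, stride=2, padding=1)
            self.fc = nn.Linear(16, 1)

        def forward(self, x):
            f1 = F.relu(self.conv1(x))
            f2 = F.relu(self.conv2(f1))
            logits = self.fc(f2.mean(dim=(2, 3)))
            # (logits, feature list) mirrors the interface assumed above
            return logits, [f1, f2]

    fake = torch.rand(2, 3, 32, 32)
    real = torch.rand(2, 3, 32, 32)
    criterion = DiscriminatorLoss(ftr_num=2)
    print('feature-matching loss:', criterion(ToyDiscriminator(), fake, real).item())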
--------------------------------------------------------------------------------
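As a closing reference, here is a minimal usage sketch for two of the helpers above. Both snippets assume they are run from the repository root (so that `utils` and `dataset` are importable); the batch size, latent dimensionality, class count, and learning rates are illustrative values, not ones prescribed by the repository.

First, drawing latent codes and class labels with `prepare_z_y` from `utils/biggan_utils.py`, which wraps the tensors in the `Distribution` class so they can be resampled in place:

```python
import torch

from utils.biggan_utils import prepare_z_y

# Illustrative sizes only: 4 samples, a 128-d latent, 1000 classes.
z, y = prepare_z_y(G_batch_size=4, dim_z=128, nclasses=1000, device='cpu')
z.sample_()  # redraw the latent batch in place
y.sample_()  # redraw the class labels in place
print(z.shape, y.shape)  # torch.Size([4, 128]) torch.Size([4])
```

Second, `LRScheduler` in `utils/common_utils.py` applies a linear warm-up, a per-group decay of `ratio**i`, and freezes any parameter group whose index is at least `num_group`. The toy optimizer below is hypothetical, chosen only to make that behaviour visible:

```python
import torch

from utils.common_utils import LRScheduler

# Three parameter groups standing in for three generator blocks.
params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(3)]
opt = torch.optim.SGD([{'params': [p]} for p in params], lr=1e-4)
sched = LRScheduler(opt, warm_up=10)

sched.update(iteration=5, learning_rate=1e-4, num_group=2, ratio=0.5)
print([g['lr'] for g in opt.param_groups])
# [5e-05, 2.5e-05, 0]: halfway through warm-up, decayed by 0.5 per group,
# and the third group is frozen because num_group=2.
```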