├── .gitignore
├── .style.yapf
├── LICENSE
├── README.md
├── data
│   ├── ILSVRC2012_val_00000525.JPEG
│   ├── ILSVRC2012_val_00001172.JPEG
│   ├── ILSVRC2012_val_00001970.JPEG
│   ├── ILSVRC2012_val_00003004.JPEG
│   ├── ILSVRC2012_val_00004291.JPEG
│   ├── ILSVRC2012_val_00008229.JPEG
│   ├── ILSVRC2012_val_00020814.JPEG
│   ├── ILSVRC2012_val_00022130.JPEG
│   ├── ILSVRC2012_val_00042095.JPEG
│   ├── ILSVRC2012_val_00044065.JPEG
│   ├── ILSVRC2012_val_00044640.JPEG
│   ├── manipulation.gif
│   ├── others
│   │   ├── library1.jpg
│   │   ├── library2.jpeg
│   │   ├── list.txt
│   │   ├── stones.jpg
│   │   ├── window1.jpg
│   │   ├── window2.jpg
│   │   └── window3.jpeg
│   ├── restoration.gif
│   └── windows.png
├── dataset.py
├── example.py
├── experiments
│   ├── examples
│   │   ├── run_SR.sh
│   │   ├── run_category_transfer.sh
│   │   ├── run_colorization.sh
│   │   ├── run_inpainting.sh
│   │   ├── run_inpainting_list.sh
│   │   ├── run_jitter.sh
│   │   └── run_morphing.sh
│   └── imagenet1k_128
│       ├── SR
│       │   ├── D_biased
│       │   │   ├── train.sh
│       │   │   └── train_slurm.sh
│       │   └── MSE_biased
│       │       ├── train.sh
│       │       └── train_slurm.sh
│       ├── colorization
│       │   ├── train.sh
│       │   └── train_slurm.sh
│       └── inpainting
│           ├── train.sh
│           └── train_slurm.sh
├── main.py
├── models
│   ├── __init__.py
│   ├── biggan.py
│   ├── dgp.py
│   ├── downsampler.py
│   ├── layers.py
│   └── nethook.py
├── pretrained
│   └── .gitignore
├── requirements.txt
├── scripts
│   └── imagenet_val_1k.txt
├── trainer.py
└── utils
    ├── __init__.py
    ├── biggan_utils.py
    ├── common_utils.py
    ├── distributed_utils.py
    └── losses.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | __pycache__
4 | *.tar
5 | checkpoint*
6 | images*
7 | logs*
8 | event*
9 | *.npy
10 | *.pth
11 |
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | BASED_ON_STYLE = pep8
3 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true
4 | SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true
5 | COLUMN_LIMIT = 100
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Xingang Pan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Deep Generative Prior (DGP)
2 |
3 | ### Paper
4 |
5 | Xingang Pan, Xiaohang Zhan, Bo Dai, Dahua Lin, Chen Change Loy, Ping Luo, "[Exploiting Deep Generative Prior for Versatile Image Restoration and Manipulation](https://arxiv.org/abs/2003.13659)", ECCV 2020 (**Oral**)
6 |
7 | Video: https://youtu.be/p7ToqtwfVko
8 |
9 | ### Demos
10 |
11 | DGP exploits the image prior of an off-the-shelf GAN for various image restoration and manipulation tasks.
12 |
13 | **Image restoration**:
14 |
15 | ![Image restoration](data/restoration.gif)
16 |
17 |
18 |
19 | **Image manipulation**:
20 |
21 | ![Image manipulation](data/manipulation.gif)
22 |
23 |
24 |
25 | A **learned prior** helps **internal learning**:
26 |
27 | ![A learned prior helps internal learning](data/windows.png)
28 |
29 |
30 |
31 | ### Requirements
32 |
33 | * python>=3.6
34 | * pytorch>=1.0.1
35 | * other dependencies listed in requirements.txt:
36 |
37 | ```sh
38 | pip install -r requirements.txt
39 | ```
40 |
41 | ### Get Started
42 |
43 | Before starting, please download the pretrained BigGAN weights from [Google Drive](https://drive.google.com/drive/folders/1buQ2BtbnUhkh4PEPXOgdPuVo2iRK7gvI?usp=sharing) or [Baidu cloud](https://pan.baidu.com/s/10GKkWt7kSClvhnEGQU4ckA) (password: uqtw), and put them in the `pretrained` folder.
44 |
45 | Example 1: run the image colorization example:
46 |
47 | sh experiments/examples/run_colorization.sh
48 |
49 | The results will be saved in `experiments/examples/images` and `experiments/examples/images_sheet`.
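
Each example script also accepts an image path and an ImageNet class index as positional arguments, e.g. `sh experiments/examples/run_colorization.sh data/ILSVRC2012_val_00004291.JPEG 442` (these values appear as the commented-out alternative inside the script); set the class to `-1` if it is unknown.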
50 |
51 | Example 2: process images given an image list:
52 |
53 | sh experiments/examples/run_inpainting_list.sh
54 |
55 | Example 3: evaluate on 1k ImageNet validation images via distributed training based on [slurm](https://slurm.schedmd.com/):
56 |
57 |     # need to specify the root path of the ImageNet validation set in --root_dir
58 | sh experiments/imagenet1k_128/colorization/train_slurm.sh
59 |
60 | Note:
61 | \- BigGAN needs a class condition as input. If no class condition is provided, it will be selected from a set of random samples.
62 | \- The provided hyperparameters may not be optimal; feel free to tune them.
63 |
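For programmatic use, the example scripts above simply wrap `example.py`. The following is a minimal sketch of the same flow, distilled from `example.py`; it assumes the usual DGP/BigGAN command-line flags (e.g. those passed in `run_colorization.sh`) and a CUDA-capable GPU:

```python
import os
import torch
import utils
from models import DGP

# Parse the standard DGP/BigGAN flags (see experiments/examples/*.sh for typical values)
parser = utils.add_dgp_parser(utils.prepare_parser())
config = vars(parser.parse_args())
utils.dgp_update_config(config)
utils.seed_rng(config['seed'])
os.makedirs(os.path.join(config['exp_path'], 'images'), exist_ok=True)
os.makedirs(os.path.join(config['exp_path'], 'images_sheet'), exist_ok=True)

dgp = DGP(config)  # builds BigGAN G (and D) and loads the pretrained weights
img = utils.get_img('data/ILSVRC2012_val_00003004.JPEG', config['resolution']).cuda()
category = torch.tensor([693]).long().cuda()  # ImageNet class index; use -1 if unknown
dgp.set_target(img, category, 'data/ILSVRC2012_val_00003004.JPEG')
dgp.select_z(select_y=False)  # pass select_y=True when the class is unknown
loss_dict = dgp.run()
```
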
64 | ### Acknowledgement
65 |
66 | The code of BigGAN is borrowed from [https://github.com/ajbrock/BigGAN-PyTorch](https://github.com/ajbrock/BigGAN-PyTorch).
67 |
68 | ### Citation
69 |
70 | ```
71 | @inproceedings{pan2020dgp,
72 | author = {Pan, Xingang and Zhan, Xiaohang and Dai, Bo and Lin, Dahua and Loy, Chen Change and Luo, Ping},
73 | title = {Exploiting Deep Generative Prior for Versatile Image Restoration and Manipulation},
74 | booktitle = {European Conference on Computer Vision (ECCV)},
75 | year = {2020}
76 | }
77 |
78 | @ARTICLE{pan2020dgp_pami,
79 | author={Pan, Xingang and Zhan, Xiaohang and Dai, Bo and Lin, Dahua and Loy, Chen Change and Luo, Ping},
80 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
81 | title={Exploiting Deep Generative Prior for Versatile Image Restoration and Manipulation},
82 | year={2021},
83 | volume={},
84 | number={},
85 | pages={1-1},
86 | doi={10.1109/TPAMI.2021.3115428}
87 | }
88 | ```
89 |
--------------------------------------------------------------------------------
/data/ILSVRC2012_val_00000525.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00000525.JPEG
--------------------------------------------------------------------------------
/data/ILSVRC2012_val_00001172.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00001172.JPEG
--------------------------------------------------------------------------------
/data/ILSVRC2012_val_00001970.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00001970.JPEG
--------------------------------------------------------------------------------
/data/ILSVRC2012_val_00003004.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00003004.JPEG
--------------------------------------------------------------------------------
/data/ILSVRC2012_val_00004291.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00004291.JPEG
--------------------------------------------------------------------------------
/data/ILSVRC2012_val_00008229.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00008229.JPEG
--------------------------------------------------------------------------------
/data/ILSVRC2012_val_00020814.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00020814.JPEG
--------------------------------------------------------------------------------
/data/ILSVRC2012_val_00022130.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00022130.JPEG
--------------------------------------------------------------------------------
/data/ILSVRC2012_val_00042095.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00042095.JPEG
--------------------------------------------------------------------------------
/data/ILSVRC2012_val_00044065.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00044065.JPEG
--------------------------------------------------------------------------------
/data/ILSVRC2012_val_00044640.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/ILSVRC2012_val_00044640.JPEG
--------------------------------------------------------------------------------
/data/manipulation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/manipulation.gif
--------------------------------------------------------------------------------
/data/others/library1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/library1.jpg
--------------------------------------------------------------------------------
/data/others/library2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/library2.jpeg
--------------------------------------------------------------------------------
/data/others/list.txt:
--------------------------------------------------------------------------------
1 | window1.jpg
2 | window2.jpg
3 | window3.jpeg
4 | library1.jpg
5 | library2.jpeg
6 | stones.jpg
7 |
--------------------------------------------------------------------------------
/data/others/stones.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/stones.jpg
--------------------------------------------------------------------------------
/data/others/window1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/window1.jpg
--------------------------------------------------------------------------------
/data/others/window2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/window2.jpg
--------------------------------------------------------------------------------
/data/others/window3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/others/window3.jpeg
--------------------------------------------------------------------------------
/data/restoration.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/restoration.gif
--------------------------------------------------------------------------------
/data/windows.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/XingangPan/deep-generative-prior/c503a56ad0ed62b18a5fe05ee3d1ee27c0dc9d19/data/windows.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | import torchvision.transforms as transforms
3 | from PIL import Image
4 |
5 | import utils
6 |
7 |
8 | def pil_loader(path):
9 | # open path as file to avoid ResourceWarning
10 | # (https://github.com/python-pillow/Pillow/issues/835)
11 | with open(path, 'rb') as f:
12 | img = Image.open(f)
13 | return img.convert('RGB')
14 |
15 |
16 | def accimage_loader(path):
17 | import accimage
18 | try:
19 | return accimage.Image(path)
20 | except IOError:
21 | # Potentially a decoding problem, fall back to PIL.Image
22 | return pil_loader(path)
23 |
24 |
25 | def default_loader(path):
26 | from torchvision import get_image_backend
27 | if get_image_backend() == 'accimage':
28 | return accimage_loader(path)
29 | else:
30 | return pil_loader(path)
31 |
32 |
33 | class ImageDataset(data.Dataset):
34 |
35 | def __init__(self,
36 | root_dir,
37 | meta_file,
38 | transform=None,
39 | image_size=128,
40 | normalize=True):
41 | self.root_dir = root_dir
42 | if transform is not None:
43 | self.transform = transform
44 | else:
45 | norm_mean = [0.5, 0.5, 0.5]
46 | norm_std = [0.5, 0.5, 0.5]
47 | if normalize:
48 | self.transform = transforms.Compose([
49 | utils.CenterCropLongEdge(),
50 | transforms.Resize(image_size),
51 | transforms.ToTensor(),
52 | transforms.Normalize(norm_mean, norm_std)
53 | ])
54 | else:
55 | self.transform = transforms.Compose([
56 | utils.CenterCropLongEdge(),
57 | transforms.Resize(image_size),
58 | transforms.ToTensor()
59 | ])
60 | with open(meta_file) as f:
61 | lines = f.readlines()
62 | print("building dataset from %s" % meta_file)
63 | self.num = len(lines)
64 | self.metas = []
65 | self.classifier = None
66 | for line in lines:
67 | line_split = line.rstrip().split()
68 | if len(line_split) == 2:
69 | self.metas.append((line_split[0], int(line_split[1])))
70 | else:
71 | self.metas.append((line_split[0], -1))
72 | print("read meta done")
73 |
74 | def __len__(self):
75 | return self.num
76 |
77 | def __getitem__(self, idx):
78 | filename = self.root_dir + '/' + self.metas[idx][0]
79 | cls = self.metas[idx][1]
80 | img = default_loader(filename)
81 |
82 | # transform
83 | if self.transform is not None:
84 | img = self.transform(img)
85 |
86 | return img, cls, self.metas[idx][0]
87 |
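# Example usage (illustrative sketch; each meta-file line is "relative/path [class]",
# with the class read as -1 when omitted, as in data/others/list.txt):
if __name__ == '__main__':
    dataset = ImageDataset(root_dir='data/others', meta_file='data/others/list.txt',
                           image_size=128)
    img, cls, name = dataset[0]
    print(img.shape, cls, name)  # e.g. torch.Size([3, 128, 128]) -1 window1.jpg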
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
4 | from collections import OrderedDict
5 |
6 | import torch
7 | import torchvision.utils as vutils
8 |
9 | import utils
10 | from models import DGP
11 |
12 | sys.path.append("./")
13 |
14 |
15 | # Arguments for demo
16 | def add_example_parser(parser):
17 | parser.add_argument(
18 | '--image_path', type=str, default='',
19 | help='Path of the image to be processed (default: %(default)s)')
20 | parser.add_argument(
21 | '--class', type=int, default=-1,
22 | help='class index of the image (default: %(default)s)')
23 | parser.add_argument(
24 | '--image_path2', type=str, default='',
25 | help='Path of the 2nd image to be processed, used in "morphing" mode (default: %(default)s)')
26 | parser.add_argument(
27 | '--class2', type=int, default=-1,
28 | help='class index of the 2nd image, used in "morphing" mode (default: %(default)s)')
29 | return parser
30 |
31 |
32 | # prepare arguments and save in config
33 | parser = utils.prepare_parser()
34 | parser = utils.add_dgp_parser(parser)
35 | parser = add_example_parser(parser)
36 | config = vars(parser.parse_args())
37 | utils.dgp_update_config(config)
38 |
39 | # set random seed
40 | utils.seed_rng(config['seed'])
41 |
42 | if not os.path.exists('{}/images'.format(config['exp_path'])):
43 | os.makedirs('{}/images'.format(config['exp_path']))
44 | if not os.path.exists('{}/images_sheet'.format(config['exp_path'])):
45 | os.makedirs('{}/images_sheet'.format(config['exp_path']))
46 |
47 | # initialize DGP model
48 | dgp = DGP(config)
49 |
50 | # prepare the target image
51 | img = utils.get_img(config['image_path'], config['resolution']).cuda()
52 | category = torch.Tensor([config['class']]).long().cuda()
53 | dgp.set_target(img, category, config['image_path'])
54 |
55 | # prepare initial latent vector
56 | dgp.select_z(select_y=True if config['class'] < 0 else False)
57 | # start reconstruction
58 | loss_dict = dgp.run()
59 |
60 | if config['dgp_mode'] == 'category_transfer':
61 | save_imgs = img.clone().cpu()
62 | for i in range(151, 294): # dog & cat
63 | # for i in range(7, 25): # bird
64 | with torch.no_grad():
65 | x = dgp.G(dgp.z, dgp.G.shared(dgp.y.fill_(i)))
66 | utils.save_img(
67 | x[0],
68 | '%s/images/%s_class%d.jpg' % (config['exp_path'], dgp.img_name, i))
69 | save_imgs = torch.cat((save_imgs, x.cpu()), dim=0)
70 | vutils.save_image(
71 | save_imgs,
72 | '%s/images_sheet/%s_categories.jpg' % (config['exp_path'], dgp.img_name),
73 | nrow=int(save_imgs.size(0)**0.5),
74 | normalize=True)
75 |
76 | elif config['dgp_mode'] == 'morphing':
77 | dgp2 = DGP(config)
78 | dgp_interp = DGP(config)
79 |
80 | img2 = utils.get_img(config['image_path2'], config['resolution']).cuda()
81 | category2 = torch.Tensor([config['class2']]).long().cuda()
82 |
83 | dgp2.set_target(img2, category2, config['image_path2'])
84 | dgp2.select_z(select_y=True if config['class2'] < 0 else False)
85 | loss_dict = dgp2.run()
86 |
87 | weight1 = dgp.G.state_dict()
88 | weight2 = dgp2.G.state_dict()
89 | weight_interp = OrderedDict()
90 | save_imgs = []
91 | with torch.no_grad():
92 | for i in range(11):
93 | alpha = i / 10
94 | # interpolate between both latent vector and generator weight
95 | z_interp = alpha * dgp.z + (1 - alpha) * dgp2.z
96 | y_interp = alpha * dgp.G.shared(dgp.y) + (1 - alpha) * dgp2.G.shared(dgp2.y)
97 | for k, w1 in weight1.items():
98 | w2 = weight2[k]
99 | weight_interp[k] = alpha * w1 + (1 - alpha) * w2
100 | dgp_interp.G.load_state_dict(weight_interp)
101 | x_interp = dgp_interp.G(z_interp, y_interp)
102 | save_imgs.append(x_interp.cpu())
103 | # save images
104 | save_path = '%s/images/%s_%s' % (config['exp_path'], dgp.img_name, dgp2.img_name)
105 | if not os.path.exists(save_path):
106 | os.makedirs(save_path)
107 | utils.save_img(x_interp[0], '%s/%03d.jpg' % (save_path, i + 1))
108 | save_imgs = torch.cat(save_imgs, 0)
109 | vutils.save_image(
110 | save_imgs,
111 | '%s/images_sheet/morphing_%s_%s.jpg' % (config['exp_path'], dgp.img_name, dgp2.img_name),
112 | nrow=int(save_imgs.size(0)**0.5),
113 | normalize=True)
114 |
--------------------------------------------------------------------------------
/experiments/examples/run_SR.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | WORK_PATH=$(dirname $0)
4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00042095.JPEG}
5 | CLASS=${2:-260}
6 | #IMAGE_PATH=data/ILSVRC2012_val_00000525.JPEG
7 | #CLASS=863
8 |
9 | python -u -W ignore example.py \
10 | --exp_path $WORK_PATH \
11 | --image_path $IMAGE_PATH \
12 | --class $CLASS \
13 | --seed 0 \
14 | --dgp_mode SR \
15 | --update_G \
16 | --ftr_num 8 8 8 8 8 \
17 | --ft_num 2 3 4 5 7 \
18 | --lr_ratio 1.0 1.0 1.0 1.0 1.0 \
19 | --w_D_loss 1 1 1 1 1 \
20 | --w_nll 0.02 \
21 | --w_mse 1 1 1 1 1 \
22 | --select_num 500 \
23 | --sample_std 0.3 \
24 | --iterations 200 200 200 200 200 \
25 | --G_lrs 5e-5 5e-5 2e-5 1e-5 1e-5 \
26 | --z_lrs 2e-3 1e-3 2e-5 1e-5 1e-5 \
27 | --use_in False False False False False \
28 | --resolution 256 \
29 | --weights_root pretrained \
30 | --load_weights 256 \
31 | --G_ch 96 --D_ch 96 \
32 | --G_shared \
33 | --hier --dim_z 120 --shared_dim 128 \
34 | --skip_init --use_ema
35 |
--------------------------------------------------------------------------------
/experiments/examples/run_category_transfer.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | WORK_PATH=$(dirname $0)
4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00008229.JPEG}
5 | CLASS=${2:-174}
6 |
7 | python -u -W ignore example.py \
8 | --exp_path $WORK_PATH \
9 | --image_path $IMAGE_PATH \
10 | --class $CLASS \
11 | --seed 4 \
12 | --dgp_mode category_transfer \
13 | --update_G \
14 | --ftr_num 8 8 8 \
15 | --ft_num 7 7 7 \
16 | --lr_ratio 1 1 1 \
17 | --w_D_loss 1 1 1 \
18 | --w_nll 0.2 \
19 | --w_mse 0 0 0 \
20 | --select_num 500 \
21 | --sample_std 0.5 \
22 | --iterations 125 125 100 \
23 | --G_lrs 2e-7 2e-5 2e-6 \
24 | --z_lrs 1e-1 1e-2 2e-4 \
25 | --use_in False False False \
26 | --resolution 256 \
27 | --weights_root pretrained \
28 | --load_weights 256 \
29 | --G_ch 96 --D_ch 96 \
30 | --G_shared \
31 | --hier --dim_z 120 --shared_dim 128 \
32 | --skip_init --use_ema
33 |
--------------------------------------------------------------------------------
/experiments/examples/run_colorization.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | WORK_PATH=$(dirname $0)
4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00003004.JPEG}
5 | CLASS=${2:-693} # set CLASS=-1 if you don't know the class
6 | #IMAGE_PATH=data/ILSVRC2012_val_00004291.JPEG
7 | #CLASS=442
8 |
9 | python -u -W ignore example.py \
10 | --exp_path $WORK_PATH \
11 | --image_path $IMAGE_PATH \
12 | --class $CLASS \
13 | --seed 0 \
14 | --dgp_mode colorization \
15 | --update_G \
16 | --ftr_num 7 7 7 7 7 \
17 | --ft_num 2 3 4 5 6 \
18 | --lr_ratio 0.7 0.7 0.8 0.9 1.0 \
19 | --w_D_loss 1 1 1 1 1 \
20 | --w_nll 0.02 \
21 | --w_mse 0 0 0 0 0 \
22 | --select_num 500 \
23 | --sample_std 0.5 \
24 | --iterations 200 200 300 400 300 \
25 | --G_lrs 5e-5 5e-5 5e-5 5e-5 2e-5 \
26 | --z_lrs 2e-3 1e-3 5e-4 5e-5 2e-5 \
27 | --use_in False False False False False \
28 | --resolution 256 \
29 | --weights_root pretrained \
30 | --load_weights 256 \
31 | --G_ch 96 --D_ch 96 \
32 | --G_shared \
33 | --hier --dim_z 120 --shared_dim 128 \
34 | --skip_init --use_ema
35 |
--------------------------------------------------------------------------------
/experiments/examples/run_inpainting.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | WORK_PATH=$(dirname $0)
4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00001970.JPEG}
5 | CLASS=${2:--1} # set CLASS=-1 if you don't know the class
6 | #IMAGE_PATH=data/ILSVRC2012_val_00022130.JPEG
7 | #CLASS=511
8 |
9 | python -u -W ignore example.py \
10 | --exp_path $WORK_PATH \
11 | --image_path $IMAGE_PATH \
12 | --class $CLASS \
13 | --seed 4 \
14 | --dgp_mode inpainting \
15 | --update_G \
16 | --update_embed \
17 | --ftr_num 8 8 8 8 8 \
18 | --ft_num 7 7 7 7 7 \
19 | --lr_ratio 1.0 1.0 1.0 1.0 1.0 \
20 | --w_D_loss 1 1 1 1 0.5 \
21 | --w_nll 0.02 \
22 | --w_mse 1 1 1 1 10 \
23 | --select_num 500 \
24 | --sample_std 0.3 \
25 | --iterations 200 200 200 200 200 \
26 | --G_lrs 5e-5 5e-5 2e-5 2e-5 1e-5 \
27 | --z_lrs 2e-3 1e-3 2e-5 2e-5 1e-5 \
28 | --use_in False False False False False \
29 | --resolution 256 \
30 | --weights_root pretrained \
31 | --load_weights 256 \
32 | --G_ch 96 --D_ch 96 \
33 | --G_shared \
34 | --hier --dim_z 120 --shared_dim 128 \
35 | --skip_init --use_ema
36 |
--------------------------------------------------------------------------------
/experiments/examples/run_inpainting_list.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | WORK_PATH=$(dirname $0)
4 |
5 | python -u -W ignore main.py \
6 | --exp_path $WORK_PATH \
7 | --root_dir data/others \
8 | --list_file data/others/list.txt \
9 | --seed 2 \
10 | --dgp_mode inpainting \
11 | --update_G \
12 | --update_embed \
13 | --ftr_num 8 8 8 8 8 \
14 | --ft_num 7 7 7 7 7 \
15 | --lr_ratio 1.0 1.0 1.0 1.0 1.0 \
16 | --w_D_loss 1 1 1 1 0.5 \
17 | --w_nll 0.02 \
18 | --w_mse 1 1 1 1 10 \
19 | --select_num 1000 \
20 | --sample_std 0.3 \
21 | --iterations 200 200 200 200 200 \
22 | --G_lrs 5e-5 5e-5 2e-5 2e-5 1e-5 \
23 | --z_lrs 2e-3 1e-3 2e-5 2e-5 1e-5 \
24 | --use_in False False False False False \
25 | --resolution 256 \
26 | --weights_root pretrained \
27 | --load_weights 256 \
28 | --G_ch 96 --D_ch 96 \
29 | --G_shared \
30 | --hier --dim_z 120 --shared_dim 128 \
31 | --skip_init --use_ema --no_tb
32 |
--------------------------------------------------------------------------------
/experiments/examples/run_jitter.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | WORK_PATH=$(dirname $0)
4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00001172.JPEG}
5 | CLASS=${2:-269}
6 |
7 | python -u -W ignore example.py \
8 | --exp_path $WORK_PATH \
9 | --image_path $IMAGE_PATH \
10 | --class $CLASS \
11 | --seed 0 \
12 | --dgp_mode jitter \
13 | --update_G \
14 | --ftr_num 8 8 8 \
15 | --ft_num 8 8 8 \
16 | --lr_ratio 1 1 1 \
17 | --w_D_loss 1 1 1 \
18 | --w_nll 0.2 \
19 | --w_mse 0 0 0 \
20 | --select_num 500 \
21 | --sample_std 0.5 \
22 | --iterations 125 125 100 \
23 | --G_lrs 2e-7 2e-5 2e-6 \
24 | --z_lrs 1e-1 1e-2 2e-4 \
25 | --use_in False False False \
26 | --resolution 256 \
27 | --weights_root pretrained \
28 | --load_weights 256 \
29 | --G_ch 96 --D_ch 96 \
30 | --G_shared \
31 | --hier --dim_z 120 --shared_dim 128 \
32 | --skip_init --use_ema
33 |
--------------------------------------------------------------------------------
/experiments/examples/run_morphing.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | WORK_PATH=$(dirname $0)
4 | IMAGE_PATH=${1:-data/ILSVRC2012_val_00001172.JPEG}
5 | CLASS=${2:-269}
6 | IMAGE_PATH2=${3:-data/ILSVRC2012_val_00020814.JPEG}
7 | CLASS2=${4:-185}
8 | #IMAGE_PATH=data/ILSVRC2012_val_00044065.JPEG
9 | #CLASS=425
10 | #IMAGE_PATH2=data/ILSVRC2012_val_00044640.JPEG
11 | #CLASS2=425
12 |
13 | python -u -W ignore example.py \
14 | --exp_path $WORK_PATH \
15 | --seed 0 \
16 | --image_path $IMAGE_PATH \
17 | --class $CLASS \
18 | --image_path2 $IMAGE_PATH2 \
19 | --class2 $CLASS2 \
20 | --dgp_mode morphing \
21 | --update_G \
22 | --update_embed \
23 | --ftr_num 8 8 8 \
24 | --ft_num 8 8 8 \
25 | --lr_ratio 1 1 1 \
26 | --w_D_loss 1 1 1 \
27 | --w_nll 0.2 \
28 | --w_mse 0 0 0 \
29 | --select_num 500 \
30 | --sample_std 0.5 \
31 | --iterations 125 125 100 \
32 | --G_lrs 2e-7 2e-5 2e-6 \
33 | --z_lrs 1e-1 1e-2 2e-4 \
34 | --use_in False False False \
35 | --resolution 256 \
36 | --weights_root pretrained \
37 | --load_weights ch64_256 \
38 | --G_ch 64 --D_ch 64 \
39 | --G_shared \
40 | --hier --dim_z 120 --shared_dim 128 \
41 | --skip_init --use_ema
42 |
--------------------------------------------------------------------------------
/experiments/imagenet1k_128/SR/D_biased/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | WORK_PATH=$(dirname $0)
4 |
5 | python -u -W ignore main.py \
6 | --seed 0 \
7 | --exp_path $WORK_PATH \
8 | --root_dir /path_to_your_ImageNet/val \
9 | --list_file scripts/imagenet_val_1k.txt \
10 | --dgp_mode SR \
11 | --update_G \
12 | --ftr_num 8 8 8 8 8 \
13 | --ft_num 2 3 4 5 6 \
14 | --lr_ratio 1 1 1 1 1 \
15 | --w_D_loss 1 1 1 1 1 \
16 | --w_nll 0.02 \
17 | --w_mse 1 1 1 1 1 \
18 | --select_num 500 \
19 | --sample_std 0.3 \
20 | --iterations 200 200 200 200 200 \
21 | --G_lrs 5e-5 5e-5 2e-5 1e-5 1e-5 \
22 | --z_lrs 2e-3 1e-3 2e-5 1e-5 1e-5 \
23 | --use_in False False False False False \
24 | --resolution 128 \
25 | --weights_root pretrained \
26 | --load_weights 128 \
27 | --G_ch 96 --D_ch 96 \
28 | --G_shared \
29 | --hier --dim_z 120 --shared_dim 128 \
30 | --skip_init --use_ema \
31 | 2>&1 | tee $WORK_PATH/log.txt
32 |
--------------------------------------------------------------------------------
/experiments/imagenet1k_128/SR/D_biased/train_slurm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | PARTITION=$1 # input your partition on lustre
4 | WORK_PATH=$(dirname $0)
5 |
6 | # make sure that the length of the image list is divisible by
7 | # the number of tasks, e.g., 1000 % 4 == 0
8 | srun -p $PARTITION -n4 --gres=gpu:4 --ntasks-per-node=4 \
9 | python -u -W ignore main.py \
10 | --dist \
11 | --port 12348 \
12 | --seed 0 \
13 | --exp_path $WORK_PATH \
14 | --root_dir /path_to_your_ImageNet/val \
15 | --list_file scripts/imagenet_val_1k.txt \
16 | --dgp_mode SR \
17 | --update_G \
18 | --ftr_num 8 8 8 8 8 \
19 | --ft_num 2 3 4 5 6 \
20 | --lr_ratio 1 1 1 1 1 \
21 | --w_D_loss 1 1 1 1 1 \
22 | --w_nll 0.02 \
23 | --w_mse 1 1 1 1 1 \
24 | --select_num 500 \
25 | --sample_std 0.3 \
26 | --iterations 200 200 200 200 200 \
27 | --G_lrs 5e-5 5e-5 2e-5 1e-5 1e-5 \
28 | --z_lrs 2e-3 1e-3 2e-5 1e-5 1e-5 \
29 | --use_in False False False False False \
30 | --resolution 128 \
31 | --weights_root pretrained \
32 | --load_weights 128 \
33 | --G_ch 96 --D_ch 96 \
34 | --G_shared \
35 | --hier --dim_z 120 --shared_dim 128 \
36 | --skip_init --use_ema \
37 | 2>&1 | tee $WORK_PATH/log.txt
38 |
--------------------------------------------------------------------------------
/experiments/imagenet1k_128/SR/MSE_biased/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | WORK_PATH=$(dirname $0)
4 |
5 | python -u -W ignore main.py \
6 | --seed 0 \
7 | --exp_path $WORK_PATH \
8 | --root_dir /path_to_your_ImageNet/val \
9 | --list_file scripts/imagenet_val_1k.txt \
10 | --dgp_mode SR \
11 | --update_G \
12 | --ftr_num 8 8 8 8 8 \
13 | --ft_num 2 3 4 5 6 \
14 | --lr_ratio 1 1 1 1 1 \
15 | --w_D_loss 1 1 1 0.5 0.1 \
16 | --w_nll 0.02 \
17 | --w_mse 10 10 10 50 100 \
18 | --select_num 500 \
19 | --sample_std 0.3 \
20 | --iterations 200 200 200 200 200 \
21 | --G_lrs 2e-4 2e-4 1e-4 1e-4 1e-5 \
22 | --z_lrs 1e-3 1e-3 1e-4 1e-4 1e-5 \
23 | --use_in True True True True True \
24 | --resolution 128 \
25 | --weights_root pretrained \
26 | --load_weights 128 \
27 | --G_ch 96 --D_ch 96 \
28 | --G_shared \
29 | --hier --dim_z 120 --shared_dim 128 \
30 | --skip_init --use_ema \
31 | 2>&1 | tee $WORK_PATH/log.txt
32 |
--------------------------------------------------------------------------------
/experiments/imagenet1k_128/SR/MSE_biased/train_slurm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | PARTITION=$1 # input your partition on lustre
4 | WORK_PATH=$(dirname $0)
5 |
6 | # make sure that the length of the image list is divisible by
7 | # the number of tasks, e.g., 1000 % 4 == 0
8 | srun -p $PARTITION -n4 --gres=gpu:4 --ntasks-per-node=4 \
9 | python -u -W ignore main.py \
10 | --dist \
11 | --port 12347 \
12 | --seed 0 \
13 | --exp_path $WORK_PATH \
14 | --root_dir /path_to_your_ImageNet/val \
15 | --list_file scripts/imagenet_val_1k.txt \
16 | --dgp_mode SR \
17 | --update_G \
18 | --ftr_num 8 8 8 8 8 \
19 | --ft_num 2 3 4 5 6 \
20 | --lr_ratio 1 1 1 1 1 \
21 | --w_D_loss 1 1 1 0.5 0.1 \
22 | --w_nll 0.02 \
23 | --w_mse 10 10 10 50 100 \
24 | --select_num 500 \
25 | --sample_std 0.3 \
26 | --iterations 200 200 200 200 200 \
27 | --G_lrs 2e-4 2e-4 1e-4 1e-4 1e-5 \
28 | --z_lrs 1e-3 1e-3 1e-4 1e-4 1e-5 \
29 | --use_in True True True True True \
30 | --resolution 128 \
31 | --weights_root pretrained \
32 | --load_weights 128 \
33 | --G_ch 96 --D_ch 96 \
34 | --G_shared \
35 | --hier --dim_z 120 --shared_dim 128 \
36 | --skip_init --use_ema \
37 | 2>&1 | tee $WORK_PATH/log.txt
38 |
--------------------------------------------------------------------------------
/experiments/imagenet1k_128/colorization/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | WORK_PATH=$(dirname $0)
4 |
5 | python -u -W ignore main.py \
6 | --seed 0 \
7 | --exp_path $WORK_PATH \
8 | --root_dir /path_to_your_ImageNet/val \
9 | --list_file scripts/imagenet_val_1k.txt \
10 | --dgp_mode colorization \
11 | --update_G \
12 | --ftr_num 7 7 7 7 7 \
13 | --ft_num 2 3 4 5 6 \
14 | --lr_ratio 0.7 0.7 0.8 0.9 1 \
15 | --w_D_loss 1 1 1 1 1 \
16 | --w_nll 0.02 \
17 | --w_mse 0 0 0 0 0 \
18 | --select_num 500 \
19 | --sample_std 0.5 \
20 | --iterations 200 200 300 400 300 \
21 | --G_lrs 5e-5 5e-5 5e-5 5e-5 2e-5 \
22 | --z_lrs 2e-3 1e-3 5e-4 5e-5 2e-5 \
23 | --use_in False False False False False \
24 | --resolution 128 \
25 | --weights_root pretrained \
26 | --load_weights 128 \
27 | --G_ch 96 --D_ch 96 \
28 | --G_shared \
29 | --hier --dim_z 120 --shared_dim 128 \
30 | --skip_init --use_ema \
31 | 2>&1 | tee $WORK_PATH/log.txt
32 |
--------------------------------------------------------------------------------
/experiments/imagenet1k_128/colorization/train_slurm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | PARTITION=$1 # input your partition on lustre
4 | WORK_PATH=$(dirname $0)
5 |
6 | # make sure that the length of the image list is divisible by
7 | # the number of tasks, e.g., 1000 % 4 == 0
8 | srun -p $PARTITION -n4 --gres=gpu:4 --ntasks-per-node=4 \
9 | python -u -W ignore main.py \
10 | --dist \
11 | --port 12345 \
12 | --seed 0 \
13 | --exp_path $WORK_PATH \
14 | --root_dir /path_to_your_ImageNet/val \
15 | --list_file scripts/imagenet_val_1k.txt \
16 | --dgp_mode colorization \
17 | --update_G \
18 | --ftr_num 7 7 7 7 7 \
19 | --ft_num 2 3 4 5 6 \
20 | --lr_ratio 0.7 0.7 0.8 0.9 1 \
21 | --w_D_loss 1 1 1 1 1 \
22 | --w_nll 0.02 \
23 | --w_mse 0 0 0 0 0 \
24 | --select_num 500 \
25 | --sample_std 0.5 \
26 | --iterations 200 200 300 400 300 \
27 | --G_lrs 5e-5 5e-5 5e-5 5e-5 2e-5 \
28 | --z_lrs 2e-3 1e-3 5e-4 5e-5 2e-5 \
29 | --use_in False False False False False \
30 | --resolution 128 \
31 | --weights_root pretrained \
32 | --load_weights 128 \
33 | --G_ch 96 --D_ch 96 \
34 | --G_shared \
35 | --hier --dim_z 120 --shared_dim 128 \
36 | --skip_init --use_ema \
37 | 2>&1 | tee $WORK_PATH/log.txt
38 |
--------------------------------------------------------------------------------
/experiments/imagenet1k_128/inpainting/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | WORK_PATH=$(dirname $0)
4 |
5 | python -u -W ignore main.py \
6 | --seed 0 \
7 | --exp_path $WORK_PATH \
8 | --root_dir /path_to_your_ImageNet/val \
9 | --list_file scripts/imagenet_val_1k.txt \
10 | --dgp_mode inpainting \
11 | --update_G \
12 | --update_embed \
13 | --ftr_num 8 8 8 8 8 \
14 | --ft_num 6 6 6 6 6 \
15 | --lr_ratio 1 1 1 1 1 \
16 | --w_D_loss 1 1 1 0.1 0.1 \
17 | --w_nll 0.02 \
18 | --w_mse 10 10 10 100 100 \
19 | --select_num 500 \
20 | --sample_std 0.3 \
21 | --iterations 200 200 200 200 200 \
22 | --G_lrs 2e-4 2e-4 1e-4 1e-4 1e-5 \
23 | --z_lrs 1e-3 1e-3 1e-4 1e-4 1e-5 \
24 | --use_in True True True True True \
25 | --resolution 128 \
26 | --weights_root pretrained \
27 | --load_weights 128 \
28 | --G_ch 96 --D_ch 96 \
29 | --G_shared \
30 | --hier --dim_z 120 --shared_dim 128 \
31 | --skip_init --use_ema \
32 | 2>&1 | tee $WORK_PATH/log.txt
33 |
--------------------------------------------------------------------------------
/experiments/imagenet1k_128/inpainting/train_slurm.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env sh
2 |
3 | PARTITION=$1 # input your partition on lustre
4 | WORK_PATH=$(dirname $0)
5 |
6 | # make sure that the length of the image list is divisible by
7 | # the number of tasks, e.g., 1000 % 4 == 0
8 | srun -p $PARTITION -n4 --gres=gpu:4 --ntasks-per-node=4 \
9 | python -u -W ignore main.py \
10 | --dist \
11 | --port 12346 \
12 | --seed 0 \
13 | --exp_path $WORK_PATH \
14 | --root_dir /path_to_your_ImageNet/val \
15 | --list_file scripts/imagenet_val_1k.txt \
16 | --dgp_mode inpainting \
17 | --update_G \
18 | --update_embed \
19 | --ftr_num 8 8 8 8 8 \
20 | --ft_num 6 6 6 6 6 \
21 | --lr_ratio 1 1 1 1 1 \
22 | --w_D_loss 1 1 1 0.1 0.1 \
23 | --w_nll 0.02 \
24 | --w_mse 10 10 10 100 100 \
25 | --select_num 500 \
26 | --sample_std 0.3 \
27 | --iterations 200 200 200 200 200 \
28 | --G_lrs 2e-4 2e-4 1e-4 1e-4 1e-5 \
29 | --z_lrs 1e-3 1e-3 1e-4 1e-4 1e-5 \
30 | --use_in True True True True True \
31 | --resolution 128 \
32 | --weights_root pretrained \
33 | --load_weights 128 \
34 | --G_ch 96 --D_ch 96 \
35 | --G_shared \
36 | --hier --dim_z 120 --shared_dim 128 \
37 | --skip_init --use_ema \
38 | 2>&1 | tee $WORK_PATH/log.txt
39 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.multiprocessing as mp
3 |
4 | import utils
5 | from trainer import Trainer
6 | from utils import dist_init
7 |
8 |
9 | def main():
10 | parser = utils.prepare_parser()
11 | parser = utils.add_dgp_parser(parser)
12 | config = vars(parser.parse_args())
13 | utils.dgp_update_config(config)
14 | print(config)
15 |
16 | rank = 0
17 | if mp.get_start_method(allow_none=True) != 'spawn':
18 | mp.set_start_method('spawn', force=True)
19 | if config['dist']:
20 | rank, world_size = dist_init(config['port'])
21 |
22 | # Seed RNG
23 | utils.seed_rng(rank + config['seed'])
24 |
25 | # Setup cudnn.benchmark for free speed
26 | torch.backends.cudnn.benchmark = True
27 |
28 | # train
29 | trainer = Trainer(config)
30 | trainer.run()
31 |
32 |
33 | if __name__ == '__main__':
34 | main()
35 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .biggan import *
2 | from .dgp import *
3 | from .nethook import *
4 |
--------------------------------------------------------------------------------
/models/biggan.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | from torch.nn import init
8 |
9 | from . import layers
10 |
11 |
12 | # Architectures for G
13 | # Attention is passed in the format '32_64', meaning an attention block is
14 | # applied at both resolutions 32x32 and 64x64. Just '64' will apply it at 64x64 only.
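# For example, G_arch(attention='32_64')[128]['attention'] evaluates to
# {8: False, 16: False, 32: True, 64: True, 128: False}.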
15 | def G_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
16 | arch = {}
17 | arch[512] = {
18 | 'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
19 | 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
20 | 'upsample': [True] * 7,
21 | 'resolution': [8, 16, 32, 64, 128, 256, 512],
22 | 'attention':
23 | {2**i: (2**i in [int(item) for item in attention.split('_')])
24 | for i in range(3, 10)}
25 | }
26 | arch[256] = {
27 | 'in_channels': [ch * item for item in [16, 16, 8, 8, 4, 2]],
28 | 'out_channels': [ch * item for item in [16, 8, 8, 4, 2, 1]],
29 | 'upsample': [True] * 6,
30 | 'resolution': [8, 16, 32, 64, 128, 256],
31 | 'attention':
32 | {2**i: (2**i in [int(item) for item in attention.split('_')])
33 | for i in range(3, 9)}
34 | }
35 | arch[128] = {
36 | 'in_channels': [ch * item for item in [16, 16, 8, 4, 2]],
37 | 'out_channels': [ch * item for item in [16, 8, 4, 2, 1]],
38 | 'upsample': [True] * 5,
39 | 'resolution': [8, 16, 32, 64, 128],
40 | 'attention':
41 | {2**i: (2**i in [int(item) for item in attention.split('_')])
42 | for i in range(3, 8)}
43 | }
44 | arch[64] = {
45 | 'in_channels': [ch * item for item in [16, 16, 8, 4]],
46 | 'out_channels': [ch * item for item in [16, 8, 4, 2]],
47 | 'upsample': [True] * 4,
48 | 'resolution': [8, 16, 32, 64],
49 | 'attention':
50 | {2**i: (2**i in [int(item) for item in attention.split('_')])
51 | for i in range(3, 7)}
52 | }
53 | arch[32] = {
54 | 'in_channels': [ch * item for item in [4, 4, 4]],
55 | 'out_channels': [ch * item for item in [4, 4, 4]],
56 | 'upsample': [True] * 3,
57 | 'resolution': [8, 16, 32],
58 | 'attention':
59 | {2**i: (2**i in [int(item) for item in attention.split('_')])
60 | for i in range(3, 6)}
61 | }
62 |
63 | return arch
64 |
65 |
66 | class Generator(nn.Module):
67 |
68 | def __init__(self,
69 | G_ch=64,
70 | dim_z=128,
71 | bottom_width=4,
72 | resolution=128,
73 | G_kernel_size=3,
74 | G_attn='64',
75 | n_classes=1000,
76 | num_G_SVs=1,
77 | num_G_SV_itrs=1,
78 | G_shared=True,
79 | shared_dim=0,
80 | hier=False,
81 | cross_replica=False,
82 | mybn=False,
83 | G_activation=nn.ReLU(inplace=True),
84 | optimizer='Adam',
85 | G_lr=5e-5,
86 | G_B1=0.0,
87 | G_B2=0.999,
88 | adam_eps=1e-8,
89 | BN_eps=1e-5,
90 | SN_eps=1e-12,
91 | G_mixed_precision=False,
92 | G_fp16=False,
93 | G_init='ortho',
94 | skip_init=False,
95 | no_optim=False,
96 | G_param='SN',
97 | norm_style='bn',
98 | **kwargs):
99 | super(Generator, self).__init__()
100 | # Channel width multiplier
101 | self.ch = G_ch
102 | # Dimensionality of the latent space
103 | self.dim_z = dim_z
104 | # The initial spatial dimensions
105 | self.bottom_width = bottom_width
106 | # Resolution of the output
107 | self.resolution = resolution
108 | # Kernel size?
109 | self.kernel_size = G_kernel_size
110 | # Attention?
111 | self.attention = G_attn
112 | # number of classes, for use in categorical conditional generation
113 | self.n_classes = n_classes
114 | # Use shared embeddings?
115 | self.G_shared = G_shared
116 | # Dimensionality of the shared embedding? Unused if not using G_shared
117 | self.shared_dim = shared_dim if shared_dim > 0 else dim_z
118 | # Hierarchical latent space?
119 | self.hier = hier
120 | # Cross replica batchnorm?
121 | self.cross_replica = cross_replica
122 | # Use my batchnorm?
123 | self.mybn = mybn
124 | # nonlinearity for residual blocks
125 | self.activation = G_activation
126 | # Initialization style
127 | self.init = G_init
128 | # Parameterization style
129 | self.G_param = G_param
130 | # Normalization style
131 | self.norm_style = norm_style
132 | # Epsilon for BatchNorm?
133 | self.BN_eps = BN_eps
134 | # Epsilon for Spectral Norm?
135 | self.SN_eps = SN_eps
136 | # fp16?
137 | self.fp16 = G_fp16
138 | # Architecture dict
139 | self.arch = G_arch(self.ch, self.attention)[resolution]
140 |
141 | # If using hierarchical latents, adjust z
142 | if self.hier:
143 | # Number of places z slots into
144 | self.num_slots = len(self.arch['in_channels']) + 1
145 | self.z_chunk_size = (self.dim_z // self.num_slots)
146 | # Recalculate latent dimensionality for even splitting into chunks
147 | self.dim_z = self.z_chunk_size * self.num_slots
148 | else:
149 | self.num_slots = 1
150 | self.z_chunk_size = 0
151 |
152 | # Which convs, batchnorms, and linear layers to use
153 | if self.G_param == 'SN':
154 | self.which_conv = functools.partial(
155 | layers.SNConv2d,
156 | kernel_size=3,
157 | padding=1,
158 | num_svs=num_G_SVs,
159 | num_itrs=num_G_SV_itrs,
160 | eps=self.SN_eps)
161 | self.which_linear = functools.partial(
162 | layers.SNLinear, num_svs=num_G_SVs, num_itrs=num_G_SV_itrs, eps=self.SN_eps)
163 | else:
164 | self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
165 | self.which_linear = nn.Linear
166 |
167 | # We use a non-spectral-normed embedding here regardless;
168 | # For some reason applying SN to G's embedding seems to randomly cripple G
169 | self.which_embedding = nn.Embedding
170 | bn_linear = (
171 | functools.partial(self.which_linear, bias=False)
172 | if self.G_shared else self.which_embedding)
173 | self.which_bn = functools.partial(
174 | layers.ccbn,
175 | which_linear=bn_linear,
176 | cross_replica=self.cross_replica,
177 | mybn=self.mybn,
178 | input_size=(self.shared_dim + self.z_chunk_size if self.G_shared else self.n_classes),
179 | norm_style=self.norm_style,
180 | eps=self.BN_eps)
181 |
182 | # Prepare model
183 | # If not using shared embeddings, self.shared is just a passthrough
184 | self.shared = (
185 | self.which_embedding(n_classes, self.shared_dim) if G_shared else layers.identity())
186 |
187 | # First linear layer
188 | self.linear = self.which_linear(self.dim_z // self.num_slots,
189 | self.arch['in_channels'][0] * (self.bottom_width**2))
190 |
191 | # self.blocks is a doubly-nested list of modules, the outer loop intended
192 | # to be over blocks at a given resolution (resblocks and/or self-attention)
193 | # while the inner loop is over a given block
194 | self.blocks = []
195 | for index in range(len(self.arch['out_channels'])):
196 | self.blocks += [[
197 | layers.GBlock(
198 | in_channels=self.arch['in_channels'][index],
199 | out_channels=self.arch['out_channels'][index],
200 | which_conv=self.which_conv,
201 | which_bn=self.which_bn,
202 | activation=self.activation,
203 | upsample=(functools.partial(F.interpolate, scale_factor=2)
204 | if self.arch['upsample'][index] else None))
205 | ]]
206 |
207 | # If attention on this block, attach it to the end
208 | if self.arch['attention'][self.arch['resolution'][index]]:
209 | print('Adding attention layer in G at resolution %d' %
210 | self.arch['resolution'][index])
211 | self.blocks[-1] += [
212 | layers.Attention(self.arch['out_channels'][index], self.which_conv)
213 | ]
214 |
215 | # Turn self.blocks into a ModuleList so that it's all properly registered.
216 | self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
217 |
218 | # output layer: batchnorm-relu-conv.
219 | # Consider using a non-spectral conv here
220 | self.output_layer = nn.Sequential(
221 | layers.bn(
222 | self.arch['out_channels'][-1], cross_replica=self.cross_replica, mybn=self.mybn),
223 | self.activation, self.which_conv(self.arch['out_channels'][-1], 3))
224 |
225 | # Initialize weights. Optionally skip init for testing.
226 | if not skip_init:
227 | self.init_weights()
228 |
229 | # Set up optimizer
230 | # If this is an EMA copy, no need for an optim, so just return now
231 | if no_optim:
232 | return
233 | self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
234 | if G_mixed_precision:
235 | print('Using fp16 adam in G...')
236 | import utils
237 | self.optim = utils.Adam16(
238 | params=self.parameters(),
239 | lr=self.lr,
240 | betas=(self.B1, self.B2),
241 | weight_decay=0,
242 | eps=self.adam_eps)
243 | else:
244 | if optimizer == 'Adam':
245 | self.optim = optim.Adam(
246 | params=self.parameters(),
247 | lr=self.lr,
248 | betas=(self.B1, self.B2),
249 | weight_decay=0,
250 | eps=self.adam_eps)
251 | elif optimizer == 'SGD':
252 | self.optim = optim.SGD(
253 | params=self.parameters(), lr=self.lr, momentum=0.9, weight_decay=0)
254 | else:
255 | raise ValueError("optim has to be Adam or SGD, " "but got {}".format(optimizer))
256 |
257 | # LR scheduling, left here for forward compatibility
258 | # self.lr_sched = {'itr' : 0}# if self.progressive else {}
259 | # self.j = 0
260 |
261 | # Initialize
262 | def init_weights(self):
263 | self.param_count = 0
264 | for module in self.modules():
265 | if (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)
266 | or isinstance(module, nn.Embedding)):
267 | if self.init == 'ortho':
268 | init.orthogonal_(module.weight)
269 | elif self.init == 'N02':
270 | init.normal_(module.weight, 0, 0.02)
271 | elif self.init in ['glorot', 'xavier']:
272 | init.xavier_uniform_(module.weight)
273 | else:
274 | print('Init style not recognized...')
275 | self.param_count += sum([p.data.nelement() for p in module.parameters()])
276 | print("Param count for G's initialized parameters: %d" % self.param_count)
277 |
278 | def reset_in_init(self):
279 | for index, blocklist in enumerate(self.blocks):
280 | for block in blocklist:
281 | if isinstance(block, layers.GBlock):
282 | block.reset_in_init()
283 |
284 | def get_params(self, index=0, update_embed=False):
285 | if index == 0:
286 | for param in self.linear.parameters():
287 | yield param
288 | if update_embed:
289 | for param in self.shared.parameters():
290 | yield param
291 | elif index < len(self.blocks) + 1:
292 | for param in self.blocks[index - 1].parameters():
293 | yield param
294 | elif index == len(self.blocks) + 1:
295 | for param in self.output_layer.parameters():
296 | yield param
297 | else:
298 | raise ValueError('Index out of range')
299 |
300 | # Note on this forward function: we pass in a y vector which has
301 | # already been passed through G.shared to enable easy class-wise
302 | # interpolation later. If we passed in the one-hot and then ran it through
303 | # G.shared in this forward function, it would be harder to handle.
304 | def forward(self, z, y, use_in=False):
305 | # If hierarchical, concatenate zs and ys
306 | if self.hier:
307 | zs = torch.split(z, self.z_chunk_size, 1)
308 | z = zs[0]
309 | ys = [torch.cat([y, item], 1) for item in zs[1:]]
310 | else:
311 | ys = [y] * len(self.blocks)
312 |
313 | # First linear layer
314 | h = self.linear(z)
315 | # Reshape
316 | h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
317 |
318 | # Loop over blocks
319 | for index, blocklist in enumerate(self.blocks):
320 | # Second inner loop in case block has multiple layers
321 | for block in blocklist:
322 | h = block(h, ys[index], use_in)
323 |
324 | # Apply batchnorm-relu-conv-tanh at output
325 | return torch.tanh(self.output_layer(h))
326 |
327 |
328 | # Discriminator architecture, same paradigm as G's above
329 | def D_arch(ch=64, attention='64', ksize='333333', dilation='111111'):
330 | arch = {}
331 | arch[256] = {
332 | 'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
333 | 'out_channels': [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
334 | 'downsample': [True] * 6 + [False],
335 | 'resolution': [128, 64, 32, 16, 8, 4, 4],
336 | 'attention':
337 | {2**i: 2**i in [int(item) for item in attention.split('_')]
338 | for i in range(2, 8)}
339 | }
340 | arch[128] = {
341 | 'in_channels': [3] + [ch * item for item in [1, 2, 4, 8, 16]],
342 | 'out_channels': [item * ch for item in [1, 2, 4, 8, 16, 16]],
343 | 'downsample': [True] * 5 + [False],
344 | 'resolution': [64, 32, 16, 8, 4, 4],
345 | 'attention':
346 | {2**i: 2**i in [int(item) for item in attention.split('_')]
347 | for i in range(2, 8)}
348 | }
349 | arch[64] = {
350 | 'in_channels': [3] + [ch * item for item in [1, 2, 4, 8]],
351 | 'out_channels': [item * ch for item in [1, 2, 4, 8, 16]],
352 | 'downsample': [True] * 4 + [False],
353 | 'resolution': [32, 16, 8, 4, 4],
354 | 'attention':
355 | {2**i: 2**i in [int(item) for item in attention.split('_')]
356 | for i in range(2, 7)}
357 | }
358 | arch[32] = {
359 | 'in_channels': [3] + [item * ch for item in [4, 4, 4]],
360 | 'out_channels': [item * ch for item in [4, 4, 4, 4]],
361 | 'downsample': [True, True, False, False],
362 | 'resolution': [16, 16, 16, 16],
363 | 'attention':
364 | {2**i: 2**i in [int(item) for item in attention.split('_')]
365 | for i in range(2, 6)}
366 | }
367 | return arch
368 |
369 |
370 | class Discriminator(nn.Module):
371 |
372 | def __init__(self,
373 | D_ch=64,
374 | D_wide=True,
375 | resolution=128,
376 | D_kernel_size=3,
377 | D_attn='64',
378 | n_classes=1000,
379 | num_D_SVs=1,
380 | num_D_SV_itrs=1,
381 | D_activation=nn.ReLU(inplace=False),
382 | D_lr=2e-4,
383 | D_B1=0.0,
384 | D_B2=0.999,
385 | adam_eps=1e-8,
386 | SN_eps=1e-12,
387 | output_dim=1,
388 | D_mixed_precision=False,
389 | D_fp16=False,
390 | D_init='ortho',
391 | skip_init=False,
392 | D_param='SN',
393 | **kwargs):
394 | super(Discriminator, self).__init__()
395 | # Width multiplier
396 | self.ch = D_ch
397 | # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
398 | self.D_wide = D_wide
399 | # Resolution
400 | self.resolution = resolution
401 | # Kernel size
402 | self.kernel_size = D_kernel_size
403 | # Attention?
404 | self.attention = D_attn
405 | # Number of classes
406 | self.n_classes = n_classes
407 | # Activation
408 | self.activation = D_activation
409 | # Initialization style
410 | self.init = D_init
411 | # Parameterization style
412 | self.D_param = D_param
413 | # Epsilon for Spectral Norm?
414 | self.SN_eps = SN_eps
415 | # Fp16?
416 | self.fp16 = D_fp16
417 | # Architecture
418 | self.arch = D_arch(self.ch, self.attention)[resolution]
419 |
420 | # Which convs, batchnorms, and linear layers to use
421 | # No option to turn off SN in D right now
422 | if self.D_param == 'SN':
423 | self.which_conv = functools.partial(
424 | layers.SNConv2d,
425 | kernel_size=3,
426 | padding=1,
427 | num_svs=num_D_SVs,
428 | num_itrs=num_D_SV_itrs,
429 | eps=self.SN_eps)
430 | self.which_linear = functools.partial(
431 | layers.SNLinear, num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps)
432 | self.which_embedding = functools.partial(
433 | layers.SNEmbedding, num_svs=num_D_SVs, num_itrs=num_D_SV_itrs, eps=self.SN_eps)
434 | else:
435 | self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
436 | self.which_linear = nn.Linear
437 | self.which_embedding = nn.Embedding
438 |
439 | # Prepare model
440 | # self.blocks is a doubly-nested list of modules, the outer loop intended
441 | # to be over blocks at a given resolution (resblocks and/or self-attention)
442 | self.blocks = []
443 | for index in range(len(self.arch['out_channels'])):
444 | self.blocks += [[
445 | layers.DBlock(
446 | in_channels=self.arch['in_channels'][index],
447 | out_channels=self.arch['out_channels'][index],
448 | which_conv=self.which_conv,
449 | wide=self.D_wide,
450 | activation=self.activation,
451 | preactivation=(index > 0),
452 | downsample=(nn.AvgPool2d(2) if self.arch['downsample'][index] else None))
453 | ]]
454 | # If attention on this block, attach it to the end
455 | if self.arch['attention'][self.arch['resolution'][index]]:
456 | print('Adding attention layer in D at resolution %d' %
457 | self.arch['resolution'][index])
458 | self.blocks[-1] += [
459 | layers.Attention(self.arch['out_channels'][index], self.which_conv)
460 | ]
461 | # Turn self.blocks into a ModuleList so that it's all properly registered.
462 | self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
463 | # Linear output layer. The output dimension is typically 1, but may be
464 | # larger if we're e.g. turning this into a VAE with an inference output
465 | self.linear = self.which_linear(self.arch['out_channels'][-1], output_dim)
466 | # Embedding for projection discrimination
467 | self.embed = self.which_embedding(self.n_classes, self.arch['out_channels'][-1])
468 |
469 | # Initialize weights
470 | if not skip_init:
471 | self.init_weights()
472 |
473 | # Set up optimizer
474 | self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
475 | if D_mixed_precision:
476 | print('Using fp16 adam in D...')
477 | import utils
478 | self.optim = utils.Adam16(
479 | params=self.parameters(),
480 | lr=self.lr,
481 | betas=(self.B1, self.B2),
482 | weight_decay=0,
483 | eps=self.adam_eps)
484 | else:
485 | self.optim = optim.Adam(
486 | params=self.parameters(),
487 | lr=self.lr,
488 | betas=(self.B1, self.B2),
489 | weight_decay=0,
490 | eps=self.adam_eps)
491 | # LR scheduling, left here for forward compatibility
492 | # self.lr_sched = {'itr' : 0}# if self.progressive else {}
493 | # self.j = 0
494 |
495 | # Initialize
496 | def init_weights(self):
497 | self.param_count = 0
498 | for module in self.modules():
499 | if (isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear)
500 | or isinstance(module, nn.Embedding)):
501 | if self.init == 'ortho':
502 | init.orthogonal_(module.weight)
503 | elif self.init == 'N02':
504 | init.normal_(module.weight, 0, 0.02)
505 | elif self.init in ['glorot', 'xavier']:
506 | init.xavier_uniform_(module.weight)
507 | else:
508 | print('Init style not recognized...')
509 | self.param_count += sum([p.data.nelement() for p in module.parameters()])
510 |         print("Param count for D's initialized parameters: %d" % self.param_count)
511 |
512 | def forward(self, x, y=None):
513 | # Stick x into h for cleaner for loops without flow control
514 | h = x
515 | h_list = []
516 | # Loop over blocks
517 | for index, blocklist in enumerate(self.blocks):
518 | for block in blocklist:
519 | h = block(h)
520 | h_list.append(h)
521 | # Apply global sum pooling as in SN-GAN
522 | h = torch.sum(self.activation(h), [2, 3])
523 | h_list.append(h)
524 | # Get initial class-unconditional output
525 | out = self.linear(h)
526 | if y is not None:
527 | # Get projection of final featureset onto class vectors and add to evidence
528 | out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
529 |
530 | return out, h_list
531 |
--------------------------------------------------------------------------------
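The `forward` above returns both the projection-discrimination score and `h_list`, the feature maps collected after every block; DGP uses those intermediate features as a learned distance metric rather than the score itself. Below is a minimal sketch of such a feature-matching loss over the last few entries of `h_list` (the function name, the L1 distance, and the equal weighting are illustrative assumptions; the actual loss lives in `utils/losses.py`):

```python
import torch.nn.functional as F

def discriminator_feature_loss(d, x, target, y=None, ftr_num=3):
    # Run both images through D, keeping the per-block feature maps.
    _, feats_x = d(x, y)
    _, feats_t = d(target, y)
    # Compare only the deepest `ftr_num` feature maps.
    loss = x.new_zeros(())
    for fx, ft in zip(feats_x[-ftr_num:], feats_t[-ftr_num:]):
        loss = loss + F.l1_loss(fx, ft.detach())
    return loss
```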
/models/dgp.py:
--------------------------------------------------------------------------------
1 | import os
2 | from copy import deepcopy
3 |
4 | import numpy as np
5 | import torch
6 | import torch.distributed as dist
7 | import torch.nn.functional as F
8 | import torchvision
9 | from PIL import Image
10 | from skimage import color
11 | from skimage.measure import compare_psnr, compare_ssim
12 | from torch.autograd import Variable
13 |
14 | import models
15 | import utils
16 | from models.downsampler import Downsampler
17 |
18 |
19 | class DGP(object):
20 |
21 | def __init__(self, config):
22 | self.rank, self.world_size = 0, 1
23 | if config['dist']:
24 | self.rank = dist.get_rank()
25 | self.world_size = dist.get_world_size()
26 | self.config = config
27 | self.mode = config['dgp_mode']
28 | self.update_G = config['update_G']
29 | self.update_embed = config['update_embed']
30 | self.iterations = config['iterations']
31 | self.ftr_num = config['ftr_num']
32 | self.ft_num = config['ft_num']
33 | self.lr_ratio = config['lr_ratio']
34 | self.G_lrs = config['G_lrs']
35 | self.z_lrs = config['z_lrs']
36 | self.use_in = config['use_in']
37 | self.select_num = config['select_num']
38 | self.factor = 2 if self.mode == 'hybrid' else 4 # Downsample factor
39 |
40 | # create model
41 | self.G = models.Generator(**config).cuda()
42 | self.D = models.Discriminator(
43 | **config).cuda() if config['ftr_type'] == 'Discriminator' else None
44 | self.G.optim = torch.optim.Adam(
45 | [{'params': self.G.get_params(i, self.update_embed)}
46 | for i in range(len(self.G.blocks) + 1)],
47 | lr=config['G_lr'],
48 | betas=(config['G_B1'], config['G_B2']),
49 | weight_decay=0,
50 | eps=1e-8)
51 |
52 | # load weights
53 | if config['random_G']:
54 | self.random_G()
55 | else:
56 | utils.load_weights(
57 | self.G if not (config['use_ema']) else None,
58 | self.D,
59 | config['weights_root'],
60 | name_suffix=config['load_weights'],
61 | G_ema=self.G if config['use_ema'] else None,
62 | strict=False)
63 |
64 | self.G.eval()
65 | if self.D is not None:
66 | self.D.eval()
67 | self.G_weight = deepcopy(self.G.state_dict())
68 |
69 | # prepare latent variable and optimizer
70 | self._prepare_latent()
71 | # prepare learning rate scheduler
72 | self.G_scheduler = utils.LRScheduler(self.G.optim, config['warm_up'])
73 | self.z_scheduler = utils.LRScheduler(self.z_optim, config['warm_up'])
74 |
75 | # loss functions
76 | self.mse = torch.nn.MSELoss()
77 | if config['ftr_type'] == 'Discriminator':
78 | self.ftr_net = self.D
79 | self.criterion = utils.DiscriminatorLoss(ftr_num=config['ftr_num'][0])
80 | else:
81 | vgg = torchvision.models.vgg16(pretrained=True).cuda().eval()
82 | self.ftr_net = models.subsequence(vgg.features, last_layer='20')
83 | self.criterion = utils.PerceptLoss()
84 |
85 | # Downsampler for producing low-resolution image
86 | self.downsampler = Downsampler(
87 | n_planes=3,
88 | factor=self.factor,
89 | kernel_type='lanczos2',
90 | phase=0.5,
91 | preserve_size=True).type(torch.cuda.FloatTensor)
92 |
93 | def _prepare_latent(self):
94 | self.z = torch.zeros((1, self.G.dim_z)).normal_().cuda()
95 | self.z = Variable(self.z, requires_grad=True)
96 | self.z_optim = torch.optim.Adam(
97 | [{'params': self.z, 'lr': self.z_lrs[0]}],
98 | betas=(self.config['G_B1'], self.config['G_B2']),
99 | weight_decay=0,
100 | eps=1e-8
101 | )
102 | self.y = torch.zeros(1).long().cuda()
103 |
104 | def reset_G(self):
105 | self.G.load_state_dict(self.G_weight, strict=False)
106 | self.G.reset_in_init()
107 | if self.config['random_G']:
108 | self.G.train()
109 | else:
110 | self.G.eval()
111 |
112 | def random_G(self):
113 | self.G.init_weights()
114 |
115 | def set_target(self, target, category, img_path):
116 | self.target_origin = target
117 | # apply degradation transform to the original image
118 | self.target = self.pre_process(target, True)
119 | self.y.fill_(category.item())
120 | self.img_name = img_path[img_path.rfind('/') + 1:img_path.rfind('.')]
121 |
122 | def run(self, save_interval=None):
123 | save_imgs = self.target.clone()
124 | save_imgs2 = save_imgs.cpu().clone()
125 | loss_dict = {}
126 | curr_step = 0
127 | count = 0
128 | for stage, iteration in enumerate(self.iterations):
129 | # setup the number of features to use in discriminator
130 | self.criterion.set_ftr_num(self.ftr_num[stage])
131 |
132 | for i in range(iteration):
133 | curr_step += 1
134 | # setup learning rate
135 | self.G_scheduler.update(curr_step, self.G_lrs[stage],
136 | self.ft_num[stage], self.lr_ratio[stage])
137 | self.z_scheduler.update(curr_step, self.z_lrs[stage])
138 |
139 | self.z_optim.zero_grad()
140 | if self.update_G:
141 | self.G.optim.zero_grad()
142 | x = self.G(self.z, self.G.shared(self.y), use_in=self.use_in[stage])
143 | # apply degradation transform
144 | x_map = self.pre_process(x, False)
145 |
146 | # calculate losses in the degradation space
147 | ftr_loss = self.criterion(self.ftr_net, x_map, self.target)
148 | mse_loss = self.mse(x_map, self.target)
149 |                 # nll is the negative log-likelihood of z under a standard Gaussian prior (up to a constant)
150 | nll = self.z**2 / 2
151 | nll = nll.mean()
152 | l1_loss = F.l1_loss(x_map, self.target)
153 | loss = ftr_loss * self.config['w_D_loss'][stage] + \
154 | mse_loss * self.config['w_mse'][stage] + \
155 | nll * self.config['w_nll']
156 | loss.backward()
157 |
158 | self.z_optim.step()
159 | if self.update_G:
160 | self.G.optim.step()
161 |
162 | # These losses are calculated in the [-1,1] image scale
163 | # We record the rescaled MSE and L1 loss, corresponding to [0,1] image scale
164 | loss_dict = {
165 | 'ftr_loss': ftr_loss,
166 | 'nll': nll,
167 | 'mse_loss': mse_loss / 4,
168 | 'l1_loss': l1_loss / 2
169 | }
170 |
171 | # calculate losses in the non-degradation space
172 | if self.mode in ['reconstruct', 'colorization', 'SR', 'inpainting']:
173 |                     # x2 holds the post-processed result for colorization
174 | metrics, x2 = self.get_metrics(x)
175 | loss_dict = {**loss_dict, **metrics}
176 |
177 | if i == 0 or (i + 1) % self.config['print_interval'] == 0:
178 | if self.rank == 0:
179 | print(', '.join(
180 | ['Stage: [{0}/{1}]'.format(stage + 1, len(self.iterations))] +
181 | ['Iter: [{0}/{1}]'.format(i + 1, iteration)] +
182 | ['%s : %+4.4f' % (key, loss_dict[key]) for key in loss_dict]
183 | ))
184 | # save image sheet of the reconstruction process
185 | save_imgs = torch.cat((save_imgs, x), dim=0)
186 | torchvision.utils.save_image(
187 | save_imgs.float(),
188 | '%s/images_sheet/%s_%s.jpg' %
189 | (self.config['exp_path'], self.img_name, self.mode),
190 | nrow=int(save_imgs.size(0)**0.5),
191 | normalize=True)
192 | if self.mode == 'colorization':
193 | save_imgs2 = torch.cat((save_imgs2, x2), dim=0)
194 | torchvision.utils.save_image(
195 | save_imgs2.float(),
196 | '%s/images_sheet/%s_%s_2.jpg' %
197 | (self.config['exp_path'], self.img_name, self.mode),
198 | nrow=int(save_imgs.size(0)**0.5),
199 | normalize=True)
200 |
201 | if save_interval is not None:
202 | if i == 0 or (i + 1) % save_interval[stage] == 0:
203 | count += 1
204 | save_path = '%s/images/%s' % (self.config['exp_path'],
205 | self.img_name)
206 | if not os.path.exists(save_path):
207 | os.makedirs(save_path)
208 | img_path = os.path.join(
209 | save_path, '%s_%03d.jpg' % (self.img_name, count))
210 | utils.save_img(x[0], img_path)
211 |
212 |                 # stop the reconstruction early once either loss drops below its threshold
213 | if mse_loss.item() < self.config['stop_mse'] or ftr_loss.item(
214 | ) < self.config['stop_ftr']:
215 | break
216 |
217 | # save images
218 | utils.save_img(
219 | self.target[0], '%s/images/%s_%s_target.png' %
220 | (self.config['exp_path'], self.img_name, self.mode))
221 | utils.save_img(
222 | self.target_origin[0],
223 | '%s/images/%s_%s_target_origin.png' %
224 | (self.config['exp_path'], self.img_name, self.mode))
225 | utils.save_img(
226 | x[0], '%s/images/%s_%s.png' %
227 | (self.config['exp_path'], self.img_name, self.mode))
228 | if self.mode == 'colorization':
229 | utils.save_img(
230 | x2[0], '%s/images/%s_%s2.png' %
231 | (self.config['exp_path'], self.img_name, self.mode))
232 |
233 | if self.mode == 'jitter':
234 | # conduct random jittering
235 | self.jitter(x)
236 | if self.config['save_G']:
237 | torch.save(
238 | self.G.state_dict(), '%s/G_%s_%s.pth' %
239 | (self.config['exp_path'], self.img_name, self.mode))
240 | torch.save(
241 | self.z, '%s/z_%s_%s.pth' %
242 | (self.config['exp_path'], self.img_name, self.mode))
243 | return loss_dict
244 |
245 | def select_z(self, select_y=False):
246 | with torch.no_grad():
247 | if self.select_num == 0:
248 | self.z.zero_()
249 | return
250 | elif self.select_num == 1:
251 | self.z.normal_()
252 | return
253 | z_all, y_all, loss_all = [], [], []
254 | if self.rank == 0:
255 | print('Selecting z from {} samples'.format(self.select_num))
256 | # only use last 3 discriminator features to compare
257 | self.criterion.set_ftr_num(3)
258 | for i in range(self.select_num):
259 | self.z.normal_(mean=0, std=self.config['sample_std'])
260 | z_all.append(self.z.cpu())
261 | if select_y:
262 | self.y.random_(0, self.config['n_classes'])
263 | y_all.append(self.y.cpu())
264 | x = self.G(self.z, self.G.shared(self.y))
265 | x = self.pre_process(x)
266 | ftr_loss = self.criterion(self.ftr_net, x, self.target)
267 | loss_all.append(ftr_loss.view(1).cpu())
268 | if self.rank == 0 and (i + 1) % 100 == 0:
269 | print('Generating {}th sample'.format(i + 1))
270 | loss_all = torch.cat(loss_all)
271 | idx = torch.argmin(loss_all)
272 | self.z.copy_(z_all[idx])
273 | if select_y:
274 | self.y.copy_(y_all[idx])
275 | self.criterion.set_ftr_num(self.ftr_num[0])
276 |
277 | def pre_process(self, image, target=True):
278 | if self.mode in ['SR', 'hybrid']:
279 |             # apply downsampling; this part is the same as in deep image prior
280 | if target:
281 | image_pil = utils.np_to_pil(
282 | utils.torch_to_np((image.cpu() + 1) / 2))
283 | LR_size = [
284 | image_pil.size[0] // self.factor,
285 | image_pil.size[1] // self.factor
286 | ]
287 | img_LR_pil = image_pil.resize(LR_size, Image.ANTIALIAS)
288 | image = utils.np_to_torch(utils.pil_to_np(img_LR_pil)).cuda()
289 | image = image * 2 - 1
290 | else:
291 | image = self.downsampler((image + 1) / 2)
292 | image = image * 2 - 1
293 |                 # interpolate back to the original resolution via bilinear interpolation
294 | image = F.interpolate(
295 | image, scale_factor=self.factor, mode='bilinear')
296 | n, _, h, w = image.size()
297 | if self.mode in ['colorization', 'hybrid']:
298 | # transform the image to gray-scale
299 | r = image[:, 0, :, :]
300 | g = image[:, 1, :, :]
301 | b = image[:, 2, :, :]
302 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
303 | image = gray.view(n, 1, h, w).expand(n, 3, h, w)
304 | if self.mode in ['inpainting', 'hybrid']:
305 | # remove the center part of the image
306 | hole = min(h, w) // 3
307 | begin = (h - hole) // 2
308 | end = h - begin
309 | self.begin, self.end = begin, end
310 | mask = torch.ones(1, 1, h, w).cuda()
311 | mask[0, 0, begin:end, begin:end].zero_()
312 | image = image * mask
313 | return image
314 |
315 | def get_metrics(self, x):
316 | with torch.no_grad():
317 | l1_loss_origin = F.l1_loss(x, self.target_origin) / 2
318 | mse_loss_origin = self.mse(x, self.target_origin) / 4
319 | metrics = {
320 | 'l1_loss_origin': l1_loss_origin,
321 | 'mse_loss_origin': mse_loss_origin
322 | }
323 | # transfer to numpy array and scale to [0, 1]
324 | target_np = (self.target_origin.detach().cpu().numpy()[0] + 1) / 2
325 | x_np = (x.detach().cpu().numpy()[0] + 1) / 2
326 | target_np = np.transpose(target_np, (1, 2, 0))
327 | x_np = np.transpose(x_np, (1, 2, 0))
328 | if self.mode == 'colorization':
329 | # combine the 'ab' dim of x with the 'L' dim of target image
330 | x_lab = color.rgb2lab(x_np)
331 | target_lab = color.rgb2lab(target_np)
332 | x_lab[:, :, 0] = target_lab[:, :, 0]
333 | x_np = color.lab2rgb(x_lab)
334 | x = torch.Tensor(np.transpose(x_np, (2, 0, 1))) * 2 - 1
335 | x = x.unsqueeze(0)
336 | elif self.mode == 'inpainting':
337 | # only use the inpainted area to calculate ssim and psnr
338 | x_np = x_np[self.begin:self.end, self.begin:self.end, :]
339 | target_np = target_np[self.begin:self.end,
340 | self.begin:self.end, :]
341 | ssim = compare_ssim(target_np, x_np, multichannel=True)
342 | psnr = compare_psnr(target_np, x_np)
343 | metrics['psnr'] = torch.Tensor([psnr]).cuda()
344 | metrics['ssim'] = torch.Tensor([ssim]).cuda()
345 | return metrics, x
346 |
347 | def jitter(self, x):
348 | save_imgs = x.clone().cpu()
349 | z_rand = self.z.clone()
350 | stds = [0.3, 0.5, 0.7]
351 | save_path = '%s/images/%s_jitter' % (self.config['exp_path'],
352 | self.img_name)
353 | if not os.path.exists(save_path):
354 | os.makedirs(save_path)
355 | with torch.no_grad():
356 | for std in stds:
357 | for i in range(30):
358 | # add random noise to the latent vector
359 | z_rand.normal_()
360 | z = self.z + std * z_rand
361 | x_jitter = self.G(z, self.G.shared(self.y))
362 | utils.save_img(
363 | x_jitter[0], '%s/std%.1f_%d.jpg' % (save_path, std, i))
364 | save_imgs = torch.cat((save_imgs, x_jitter.cpu()), dim=0)
365 |
366 | torchvision.utils.save_image(
367 | save_imgs.float(),
368 | '%s/images_sheet/%s_jitters.jpg' %
369 | (self.config['exp_path'], self.img_name),
370 | nrow=int(save_imgs.size(0)**0.5),
371 | normalize=True)
372 |
--------------------------------------------------------------------------------
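The loop in `DGP.run` above boils down to a joint update of the latent `z` and the generator weights against the degraded target. A condensed sketch of one such step follows; `G`, `degrade`, `feature_loss`, and the weights `w_d`, `w_mse`, `w_nll` are placeholders standing in for the configured objects and the `config['w_D_loss']`, `config['w_mse']`, `config['w_nll']` values:

```python
import torch.nn.functional as F

def dgp_step(G, z, y, target, z_optim, g_optim, degrade, feature_loss,
             w_d=1.0, w_mse=1.0, w_nll=0.02):
    z_optim.zero_grad()
    g_optim.zero_grad()
    x = G(z, G.shared(y))                    # generate in [-1, 1]
    x_map = degrade(x)                       # same degradation applied to the target
    loss = (w_d * feature_loss(x_map, target)
            + w_mse * F.mse_loss(x_map, target)
            + w_nll * (z ** 2 / 2).mean())   # Gaussian prior on z
    loss.backward()
    z_optim.step()                           # update the latent
    g_optim.step()                           # progressively fine-tune the generator
    return loss.item()
```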
/models/downsampler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 |
5 | class Downsampler(nn.Module):
6 | '''
7 | http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf
8 | '''
9 | def __init__(self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False):
10 | super(Downsampler, self).__init__()
11 |
12 | assert phase in [0, 0.5], 'phase should be 0 or 0.5'
13 |
14 | if kernel_type == 'lanczos2':
15 | support = 2
16 | kernel_width = 4 * factor + 1
17 | kernel_type_ = 'lanczos'
18 |
19 | elif kernel_type == 'lanczos3':
20 | support = 3
21 | kernel_width = 6 * factor + 1
22 | kernel_type_ = 'lanczos'
23 |
24 | elif kernel_type == 'gauss12':
25 | kernel_width = 7
26 | sigma = 1/2
27 | kernel_type_ = 'gauss'
28 |
29 | elif kernel_type == 'gauss1sq2':
30 | kernel_width = 9
31 | sigma = 1./np.sqrt(2)
32 | kernel_type_ = 'gauss'
33 |
34 | elif kernel_type in ['lanczos', 'gauss', 'box']:
35 | kernel_type_ = kernel_type
36 |
37 | else:
38 |             assert False, 'wrong kernel name'
39 |
40 |
41 |         # note that `kernel_width` will differ from the actual kernel size for phase = 1/2
42 | self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma)
43 |
44 | downsampler = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0)
45 | downsampler.weight.data[:] = 0
46 | downsampler.bias.data[:] = 0
47 |
48 | kernel_torch = torch.from_numpy(self.kernel)
49 | for i in range(n_planes):
50 | downsampler.weight.data[i, i] = kernel_torch
51 |
52 | self.downsampler_ = downsampler
53 |
54 | if preserve_size:
55 |
56 | if self.kernel.shape[0] % 2 == 1:
57 | pad = int((self.kernel.shape[0] - 1) / 2.)
58 | else:
59 | pad = int((self.kernel.shape[0] - factor) / 2.)
60 |
61 | self.padding = nn.ReplicationPad2d(pad)
62 |
63 | self.preserve_size = preserve_size
64 |
65 | def forward(self, input):
66 | if self.preserve_size:
67 | x = self.padding(input)
68 | else:
69 |             x = input
70 | self.x = x
71 | return self.downsampler_(x)
72 |
73 | def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None):
74 | assert kernel_type in ['lanczos', 'gauss', 'box']
75 |
76 | # factor = float(factor)
77 | if phase == 0.5 and kernel_type != 'box':
78 | kernel = np.zeros([kernel_width - 1, kernel_width - 1])
79 | else:
80 | kernel = np.zeros([kernel_width, kernel_width])
81 |
82 |
83 | if kernel_type == 'box':
84 | assert phase == 0.5, 'Box filter is always half-phased'
85 | kernel[:] = 1./(kernel_width * kernel_width)
86 |
87 | elif kernel_type == 'gauss':
88 | assert sigma, 'sigma is not specified'
89 | assert phase != 0.5, 'phase 1/2 for gauss not implemented'
90 |
91 | center = (kernel_width + 1.)/2.
92 | print(center, kernel_width)
93 | sigma_sq = sigma * sigma
94 |
95 | for i in range(1, kernel.shape[0] + 1):
96 | for j in range(1, kernel.shape[1] + 1):
97 | di = (i - center)/2.
98 | dj = (j - center)/2.
99 | kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj)/(2 * sigma_sq))
100 | kernel[i - 1][j - 1] = kernel[i - 1][j - 1]/(2. * np.pi * sigma_sq)
101 | elif kernel_type == 'lanczos':
102 | assert support, 'support is not specified'
103 | center = (kernel_width + 1) / 2.
104 |
105 | for i in range(1, kernel.shape[0] + 1):
106 | for j in range(1, kernel.shape[1] + 1):
107 |
108 | if phase == 0.5:
109 | di = abs(i + 0.5 - center) / factor
110 | dj = abs(j + 0.5 - center) / factor
111 | else:
112 | di = abs(i - center) / factor
113 | dj = abs(j - center) / factor
114 |
115 |
116 | pi_sq = np.pi * np.pi
117 |
118 | val = 1
119 | if di != 0:
120 | val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support)
121 | val = val / (np.pi * np.pi * di * di)
122 |
123 | if dj != 0:
124 | val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support)
125 | val = val / (np.pi * np.pi * dj * dj)
126 |
127 | kernel[i - 1][j - 1] = val
128 |
129 |
130 | else:
131 | assert False, 'wrong method name'
132 |
133 | kernel /= kernel.sum()
134 |
135 | return kernel
136 |
137 | #a = Downsampler(n_planes=3, factor=2, kernel_type='lanczos2', phase='1', preserve_size=True)
138 |
139 |
140 |
141 |
142 |
143 |
144 | #################
145 | # Learnable downsampler
146 |
147 | # KS = 32
148 | # dow = nn.Sequential(nn.ReplicationPad2d(int((KS - factor) / 2.)), nn.Conv2d(1,1,KS,factor))
149 |
150 | # class Apply(nn.Module):
151 | # def __init__(self, what, dim, *args):
152 | # super(Apply, self).__init__()
153 | # self.dim = dim
154 |
155 | # self.what = what
156 |
157 | # def forward(self, input):
158 | # inputs = []
159 | # for i in range(input.size(self.dim)):
160 | # inputs.append(self.what(input.narrow(self.dim, i, 1)))
161 |
162 | # return torch.cat(inputs, dim=self.dim)
163 |
164 | # def __len__(self):
165 | # return len(self._modules)
166 |
167 | # downs = Apply(dow, 1)
168 | # downs.type(dtype)(net_input.type(dtype)).size()
169 |
--------------------------------------------------------------------------------
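A quick usage sketch of the `Downsampler` above, mirroring how `DGP` instantiates it for 4x super-resolution (run on CPU here; the printed shape follows from the 4x stride with `preserve_size=True`):

```python
import torch
from models.downsampler import Downsampler

down = Downsampler(n_planes=3, factor=4, kernel_type='lanczos2',
                   phase=0.5, preserve_size=True)
x = torch.rand(1, 3, 128, 128)   # image in [0, 1]
lr = down(x)
print(lr.shape)                  # torch.Size([1, 3, 32, 32])
```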
/models/layers.py:
--------------------------------------------------------------------------------
1 | ''' Layers
2 | This file contains various layers for the BigGAN models.
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.nn import Parameter as P
8 |
9 |
10 | # Projection of x onto y
11 | def proj(x, y):
12 | return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
13 |
14 |
15 | # Orthogonalize x wrt list of vectors ys
16 | def gram_schmidt(x, ys):
17 | for y in ys:
18 | x = x - proj(x, y)
19 | return x
20 |
21 |
22 | # One step of the power method per singular vector, used to estimate the top N
23 | # singular values; the caller (SN.W_) applies it num_itrs times.
23 | def power_iteration(W, u_, update=True, eps=1e-12):
24 | # Lists holding singular vectors and values
25 | us, vs, svs = [], [], []
26 | for i, u in enumerate(u_):
27 | # Run one step of the power iteration
28 | with torch.no_grad():
29 | v = torch.matmul(u, W)
30 | # Run Gram-Schmidt to subtract components of all other singular vectors
31 | v = F.normalize(gram_schmidt(v, vs), eps=eps)
32 | # Add to the list
33 | vs += [v]
34 | # Update the other singular vector
35 | u = torch.matmul(v, W.t())
36 | # Run Gram-Schmidt to subtract components of all other singular vectors
37 | u = F.normalize(gram_schmidt(u, us), eps=eps)
38 | # Add to the list
39 | us += [u]
40 | if update:
41 | u_[i][:] = u
42 | # Compute this singular value and add it to the list
43 | svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
44 | # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
45 | return svs, us, vs
46 |
47 |
48 | # Convenience passthrough function
49 | class identity(nn.Module):
50 |
51 | def forward(self, input):
52 | return input
53 |
54 |
55 | # Spectral normalization base class
56 | class SN(object):
57 |
58 | def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
59 | # Number of power iterations per step
60 | self.num_itrs = num_itrs
61 | # Number of singular values
62 | self.num_svs = num_svs
63 | # Transposed?
64 | self.transpose = transpose
65 | # Epsilon value for avoiding divide-by-0
66 | self.eps = eps
67 | # Register a singular vector for each sv
68 | for i in range(self.num_svs):
69 | self.register_buffer('u%d' % i, torch.randn(1, num_outputs))
70 | self.register_buffer('sv%d' % i, torch.ones(1))
71 |
72 | # Singular vectors (u side)
73 | @property
74 | def u(self):
75 | return [getattr(self, 'u%d' % i) for i in range(self.num_svs)]
76 |
77 | # Singular values;
78 | # note that these buffers are just for logging and are not used in training.
79 | @property
80 | def sv(self):
81 | return [getattr(self, 'sv%d' % i) for i in range(self.num_svs)]
82 |
83 | # Compute the spectrally-normalized weight
84 | def W_(self):
85 | W_mat = self.weight.view(self.weight.size(0), -1)
86 | if self.transpose:
87 | W_mat = W_mat.t()
88 | # Apply num_itrs power iterations
89 | for _ in range(self.num_itrs):
90 | svs, us, vs = power_iteration(W_mat, self.u, update=self.training, eps=self.eps)
91 | # Update the svs
92 | if self.training:
93 |             # Make sure to do this in a no_grad() context or you'll get memory leaks!
94 |             with torch.no_grad():
95 | for i, sv in enumerate(svs):
96 | self.sv[i][:] = sv
97 | return self.weight / svs[0]
98 |
99 |
100 | # 2D Conv layer with spectral norm
101 | class SNConv2d(nn.Conv2d, SN):
102 |
103 | def __init__(self,
104 | in_channels,
105 | out_channels,
106 | kernel_size,
107 | stride=1,
108 | padding=0,
109 | dilation=1,
110 | groups=1,
111 | bias=True,
112 | num_svs=1,
113 | num_itrs=1,
114 | eps=1e-12):
115 | nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation,
116 | groups, bias)
117 | SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
118 |
119 | def forward(self, x):
120 | return F.conv2d(x, self.W_(), self.bias, self.stride, self.padding, self.dilation,
121 | self.groups)
122 |
123 |
124 | # Linear layer with spectral norm
125 | class SNLinear(nn.Linear, SN):
126 |
127 | def __init__(self, in_features, out_features, bias=True, num_svs=1, num_itrs=1, eps=1e-12):
128 | nn.Linear.__init__(self, in_features, out_features, bias)
129 | SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
130 |
131 | def forward(self, x):
132 | return F.linear(x, self.W_(), self.bias)
133 |
134 |
135 | # Embedding layer with spectral norm
136 | # We use num_embeddings as the dim instead of embedding_dim here
137 | # for convenience's sake
138 | class SNEmbedding(nn.Embedding, SN):
139 |
140 | def __init__(self,
141 | num_embeddings,
142 | embedding_dim,
143 | padding_idx=None,
144 | max_norm=None,
145 | norm_type=2,
146 | scale_grad_by_freq=False,
147 | sparse=False,
148 | _weight=None,
149 | num_svs=1,
150 | num_itrs=1,
151 | eps=1e-12):
152 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, padding_idx, max_norm, norm_type,
153 | scale_grad_by_freq, sparse, _weight)
154 | SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
155 |
156 | def forward(self, x):
157 | return F.embedding(x, self.W_())
158 |
159 |
160 | # A non-local block as used in SA-GAN
161 | # Note that the implementation as described in the paper is largely incorrect;
162 | # refer to the released code for the actual implementation.
163 | class Attention(nn.Module):
164 |
165 | def __init__(self, ch, which_conv=SNConv2d, name='attention'):
166 | super(Attention, self).__init__()
168 |         # Number of input channels
168 | self.ch = ch
169 | self.which_conv = which_conv
170 | self.theta = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
171 | self.phi = self.which_conv(self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False)
172 | self.g = self.which_conv(self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False)
173 | self.o = self.which_conv(self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False)
174 | # Learnable gain parameter
175 | self.gamma = P(torch.tensor(0.), requires_grad=True)
176 |
177 | def forward(self, x, y=None, use_in=False):
178 | # Apply convs
179 | theta = self.theta(x)
180 | phi = F.max_pool2d(self.phi(x), [2, 2])
181 | g = F.max_pool2d(self.g(x), [2, 2])
182 | # Perform reshapes
183 | theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
184 | phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
185 | g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
186 | # Matmul and softmax to get attention maps
187 | beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
188 | # Attention map times g path
189 | o = self.o(
190 | torch.bmm(g, beta.transpose(1, 2)).view(-1, self.ch // 2, x.shape[2], x.shape[3]))
191 | output = self.gamma * o + x
192 | return output
193 |
194 |
195 | # Fused batchnorm op
196 | def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
197 | # Apply scale and shift--if gain and bias are provided, fuse them here
198 | # Prepare scale
199 | scale = torch.rsqrt(var + eps)
200 | # If a gain is provided, use it
201 | if gain is not None:
202 | scale = scale * gain
203 | # Prepare shift
204 | shift = mean * scale
205 | # If bias is provided, use it
206 | if bias is not None:
207 | shift = shift - bias
208 | return x * scale - shift
209 | # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
210 |
211 |
212 | # Manual BN
213 | # Calculate means and variances using mean-of-squares minus mean-squared
214 | def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
215 | # Cast x to float32 if necessary
216 | float_x = x.float()
217 | # Calculate expected value of x (m) and expected value of x**2 (m2)
218 | # Mean of x
219 | m = torch.mean(float_x, [0, 2, 3], keepdim=True)
220 | # Mean of x squared
221 | m2 = torch.mean(float_x**2, [0, 2, 3], keepdim=True)
222 | # Calculate variance as mean of squared minus mean squared.
223 | var = (m2 - m**2)
224 | # Cast back to float 16 if necessary
225 | var = var.type(x.type())
226 | m = m.type(x.type())
227 | # Return mean and variance for updating stored mean/var if requested
228 | if return_mean_var:
229 | return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
230 | else:
231 | return fused_bn(x, m, var, gain, bias, eps)
232 |
233 |
234 | # My batchnorm, supports standing stats
235 | class myBN(nn.Module):
236 |
237 | def __init__(self, num_channels, eps=1e-5, momentum=0.1):
238 | super(myBN, self).__init__()
239 | # momentum for updating running stats
240 | self.momentum = momentum
241 | # epsilon to avoid dividing by 0
242 | self.eps = eps
243 | # Momentum
244 | self.momentum = momentum
245 | # Register buffers
246 | self.register_buffer('stored_mean', torch.zeros(num_channels))
247 | self.register_buffer('stored_var', torch.ones(num_channels))
248 | self.register_buffer('accumulation_counter', torch.zeros(1))
249 | # Accumulate running means and vars
250 | self.accumulate_standing = False
251 |
252 | # reset standing stats
253 | def reset_stats(self):
254 | self.stored_mean[:] = 0
255 | self.stored_var[:] = 0
256 | self.accumulation_counter[:] = 0
257 |
258 | def forward(self, x, gain, bias):
259 | if self.training:
260 | out, mean, var = manual_bn(x, gain, bias, return_mean_var=True, eps=self.eps)
261 | # If accumulating standing stats, increment them
262 | if self.accumulate_standing:
263 | self.stored_mean[:] = self.stored_mean + mean.data
264 | self.stored_var[:] = self.stored_var + var.data
265 | self.accumulation_counter += 1.0
266 | # If not accumulating standing stats, take running averages
267 | else:
268 | self.stored_mean[:] = self.stored_mean * (1 - self.momentum) + mean * self.momentum
269 | self.stored_var[:] = self.stored_var * (1 - self.momentum) + var * self.momentum
270 | return out
271 | # If not in training mode, use the stored statistics
272 | else:
273 | mean = self.stored_mean.view(1, -1, 1, 1)
274 | var = self.stored_var.view(1, -1, 1, 1)
275 | # If using standing stats, divide them by the accumulation counter
276 | if self.accumulate_standing:
277 | mean = mean / self.accumulation_counter
278 | var = var / self.accumulation_counter
279 | return fused_bn(x, mean, var, gain, bias, self.eps)
280 |
281 |
282 | # Class-conditional bn
283 | # output size is the number of channels, input size is for the linear layers
284 | # Andy's Note: this class feels messy but I'm not really sure how to clean it up
285 | # Suggestions welcome! (By which I mean, refactor this and make a pull request
286 | # if you want to make this more readable/usable).
287 | class ccbn(nn.Module):
288 |
289 | def __init__(
290 | self,
291 | output_size,
292 | input_size,
293 | which_linear,
294 | eps=1e-5,
295 | momentum=0.1,
296 | cross_replica=False,
297 | mybn=False,
298 | norm_style='bn',
299 | ):
300 | super(ccbn, self).__init__()
301 | self.output_size, self.input_size = output_size, input_size
302 | # Prepare gain and bias layers
303 | self.gain = which_linear(input_size, output_size)
304 | self.bias = which_linear(input_size, output_size)
305 | # epsilon to avoid dividing by 0
306 | self.eps = eps
307 | # Momentum
308 | self.momentum = momentum
309 | # Use cross-replica batchnorm?
310 | self.cross_replica = cross_replica
311 | # Use my batchnorm?
312 | self.mybn = mybn
313 | # Norm style?
314 | self.norm_style = norm_style
315 |
316 | if self.cross_replica:
317 | raise NotImplementedError
318 | elif self.mybn:
319 | self.bn = myBN(output_size, self.eps, self.momentum)
320 | elif self.norm_style in ['bn', 'in']:
321 | self.register_buffer('stored_mean', torch.zeros(output_size))
322 | self.register_buffer('stored_var', torch.ones(output_size))
323 |
324 | def forward(self, x, y):
325 | # Calculate class-conditional gains and biases
326 | if y is not None:
327 | gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
328 | bias = self.bias(y).view(y.size(0), -1, 1, 1)
329 | # If using my batchnorm
330 | if self.mybn:
331 | return self.bn(x, gain=gain, bias=bias)
332 | elif self.cross_replica:
333 | return self.bn(x) * gain + bias
334 | # else:
335 | else:
336 | if self.norm_style == 'bn':
337 | out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None, self.training,
338 | 0.1, self.eps)
339 | elif self.norm_style == 'in':
340 | out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
341 | self.training, 0.1, self.eps)
342 | elif self.norm_style == 'nonorm':
343 | out = x
344 | if y is not None:
345 | out = out * gain + bias
346 | return out
347 |
348 | def extra_repr(self):
349 | s = 'out: {output_size}, in: {input_size},'
350 | s += ' cross_replica={cross_replica}'
351 | return s.format(**self.__dict__)
352 |
353 |
354 | # Normal, non-class-conditional BN
355 | class bn(nn.Module):
356 |
357 | def __init__(self, output_size, eps=1e-5, momentum=0.1, cross_replica=False, mybn=False):
358 | super(bn, self).__init__()
359 | self.output_size = output_size
360 | # Prepare gain and bias layers
361 | self.gain = P(torch.ones(output_size), requires_grad=True)
362 | self.bias = P(torch.zeros(output_size), requires_grad=True)
363 | # epsilon to avoid dividing by 0
364 | self.eps = eps
365 | # Momentum
366 | self.momentum = momentum
367 | # Use cross-replica batchnorm?
368 | self.cross_replica = cross_replica
369 | # Use my batchnorm?
370 | self.mybn = mybn
371 |
372 | if self.cross_replica:
373 | raise NotImplementedError
374 | elif mybn:
375 | self.bn = myBN(output_size, self.eps, self.momentum)
376 | # Register buffers if neither of the above
377 | else:
378 | self.register_buffer('stored_mean', torch.zeros(output_size))
379 | self.register_buffer('stored_var', torch.ones(output_size))
380 |
381 | def forward(self, x, y=None):
382 | if self.cross_replica or self.mybn:
383 | gain = self.gain.view(1, -1, 1, 1)
384 | bias = self.bias.view(1, -1, 1, 1)
385 | if self.mybn:
386 | return self.bn(x, gain=gain, bias=bias)
387 | elif self.cross_replica:
388 | return self.bn(x) * gain + bias
389 | else:
390 | return F.batch_norm(x, self.stored_mean, self.stored_var, self.gain, self.bias,
391 | self.training, self.momentum, self.eps)
392 |
393 |
394 | # Generator blocks
395 | # Note that this class assumes the kernel size and padding (and any other
396 | # settings) have been selected in the main generator module and passed in
397 | # through the which_conv arg. Similar rules apply with which_bn (the input
398 | # size [which is actually the number of channels of the conditional info] must
399 | # be preselected)
400 | class GBlock(nn.Module):
401 |
402 | def __init__(self,
403 | in_channels,
404 | out_channels,
405 | which_conv=nn.Conv2d,
406 | which_bn=bn,
407 | activation=None,
408 | upsample=None):
409 | super(GBlock, self).__init__()
410 |
411 | self.in_channels, self.out_channels = in_channels, out_channels
412 | self.which_conv, self.which_bn = which_conv, which_bn
413 | self.activation = activation
414 | self.upsample = upsample
415 | # Conv layers
416 | self.conv1 = self.which_conv(self.in_channels, self.out_channels)
417 | self.conv2 = self.which_conv(self.out_channels, self.out_channels)
418 | self.learnable_sc = in_channels != out_channels or upsample
419 | if self.learnable_sc:
420 | self.conv_sc = self.which_conv(in_channels, out_channels, kernel_size=1, padding=0)
421 | # Batchnorm layers
422 | self.bn1 = self.which_bn(in_channels)
423 | self.bn2 = self.which_bn(out_channels)
424 | # upsample layers
425 | self.upsample = upsample
426 |
427 | # instance norm layers
428 | self.in_initialized = False
429 | self.in1 = nn.InstanceNorm2d(in_channels, affine=True)
430 | self.in2 = nn.InstanceNorm2d(out_channels, affine=True)
431 | self.in1.weight.requires_grad = False
432 | self.in1.bias.requires_grad = False
433 | self.in2.weight.requires_grad = False
434 | self.in2.bias.requires_grad = False
435 |
436 | def reset_in_init(self):
437 | self.in_initialized = False
438 |
439 | def init_in(self, which_bn, which_in, x, y):
440 | # carefully initialize IN's weights such that the output does not change
441 | with torch.no_grad():
442 | h = which_bn(x, y)
443 | mean = torch.mean(h, (2, 3)).squeeze(0)
444 | std = torch.std(h.view(h.size(0), h.size(1), -1), 2).squeeze(0)
445 | which_in.weight.copy_(std)
446 | which_in.bias.copy_(mean)
447 |
448 | def forward(self, x, y, use_in):
449 | if use_in:
450 | if not self.in_initialized:
451 | self.init_in(self.bn1, self.in1, x, y)
452 | h = self.in1(x)
453 | else:
454 | h = self.bn1(x, y)
455 | self.in_initialized = False
456 |
457 | h = self.activation(h)
458 | if self.upsample:
459 | h = self.upsample(h)
460 | x = self.upsample(x)
461 | h = self.conv1(h)
462 |
463 | if use_in:
464 | if not self.in_initialized:
465 | self.init_in(self.bn2, self.in2, h, y)
466 | self.in_initialized = True
467 | h = self.in2(h)
468 | else:
469 | h = self.bn2(h, y)
470 |
471 | h = self.activation(h)
472 | h = self.conv2(h)
473 | if self.learnable_sc:
474 | x = self.conv_sc(x)
475 | return h + x
476 |
477 |
478 | # Residual block for the discriminator
479 | class DBlock(nn.Module):
480 |
481 | def __init__(
482 | self,
483 | in_channels,
484 | out_channels,
485 | which_conv=SNConv2d,
486 | wide=True,
487 | preactivation=False,
488 | activation=None,
489 | downsample=None,
490 | ):
491 | super(DBlock, self).__init__()
492 | self.in_channels, self.out_channels = in_channels, out_channels
493 | # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
494 | self.hidden_channels = self.out_channels if wide else self.in_channels
495 | self.which_conv = which_conv
496 | self.preactivation = preactivation
497 | self.activation = activation
498 | self.downsample = downsample
499 |
500 | # Conv layers
501 | self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
502 | self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
503 | self.learnable_sc = True if (in_channels != out_channels) or downsample else False
504 | if self.learnable_sc:
505 | self.conv_sc = self.which_conv(in_channels, out_channels, kernel_size=1, padding=0)
506 |
507 | def shortcut(self, x):
508 | if self.preactivation:
509 | if self.learnable_sc:
510 | x = self.conv_sc(x)
511 | if self.downsample:
512 | x = self.downsample(x)
513 | else:
514 | if self.downsample:
515 | x = self.downsample(x)
516 | if self.learnable_sc:
517 | x = self.conv_sc(x)
518 | return x
519 |
520 | def forward(self, x):
521 | if self.preactivation:
522 | # h = self.activation(x) # NOT TODAY SATAN
523 | # Andy's note: This line *must* be an out-of-place ReLU or it
524 | # will negatively affect the shortcut connection.
525 | h = F.relu(x)
526 | else:
527 | h = x
528 | h = self.conv1(h)
529 | h = self.conv2(self.activation(h))
530 | if self.downsample:
531 | h = self.downsample(h)
532 |
533 | return h + self.shortcut(x)
534 |
--------------------------------------------------------------------------------
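As a sanity check on the power-iteration spectral norm used throughout this file, repeated calls to `power_iteration` converge to the largest singular value of the weight matrix, which is what `SN.W_` divides by. A standalone sketch (the matrix size and the 50 iterations are arbitrary choices):

```python
import torch
from models.layers import power_iteration

W = torch.randn(64, 128)
u = [torch.randn(1, 64)]             # one persistent singular vector, as with num_svs=1
for _ in range(50):                  # run the power method until it settles
    svs, _, _ = power_iteration(W, u, update=True)
print(svs[0].item())                 # estimated top singular value
print(torch.svd(W).S[0].item())      # reference value; the two should agree closely
```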
/models/nethook.py:
--------------------------------------------------------------------------------
1 | '''
2 | Utilities for instrumenting a torch model.
3 |
4 | InstrumentedModel will wrap a pytorch model and allow hooking
5 | arbitrary layers to monitor or modify their output directly.
6 | '''
7 |
8 | import torch, numpy, types, copy
9 | from collections import OrderedDict, defaultdict
10 |
11 |
12 | class InstrumentedModel(torch.nn.Module):
13 | '''
14 | A wrapper for hooking, probing and intervening in pytorch Modules.
15 | Example usage:
16 |
17 | ```
18 | model = load_my_model()
19 |     with InstrumentedModel(model) as inst:
20 | inst.retain_layer(layername)
21 | inst.edit_layer(layername, ablation=0.5, replacement=target_features)
22 | inst(inputs)
23 | original_features = inst.retained_layer(layername)
24 | ```
25 | '''
26 |
27 | def __init__(self, model):
28 | super().__init__()
29 | self.model = model
30 | self._retained = OrderedDict()
31 | self._detach_retained = {}
32 | self._editargs = defaultdict(dict)
33 | self._editrule = {}
34 | self._hooked_layer = {}
35 | self._old_forward = {}
36 | if isinstance(model, torch.nn.Sequential):
37 | self._hook_sequential()
38 |
39 | def __enter__(self):
40 | return self
41 |
42 | def __exit__(self, type, value, traceback):
43 | self.close()
44 |
45 | def forward(self, *inputs, **kwargs):
46 | return self.model(*inputs, **kwargs)
47 |
48 | def retain_layer(self, layername, detach=True):
49 | '''
50 |         Pass a fully-qualified layer name (e.g., module.submodule.conv3)
51 | to hook that layer and retain its output each time the model is run.
52 | A pair (layername, aka) can be provided, and the aka will be used
53 | as the key for the retained value instead of the layername.
54 | '''
55 | self.retain_layers([layername], detach=detach)
56 |
57 | def retain_layers(self, layernames, detach=True):
58 | '''
59 |         Retains a list of layers at once.
60 | '''
61 | self.add_hooks(layernames)
62 | for layername in layernames:
63 | aka = layername
64 | if not isinstance(aka, str):
65 | layername, aka = layername
66 | if aka not in self._retained:
67 | self._retained[aka] = None
68 | self._detach_retained[aka] = detach
69 |
70 | def stop_retaining_layers(self, layernames):
71 | '''
72 | Removes a list of layers from the set retained.
73 | '''
74 | self.add_hooks(layernames)
75 | for layername in layernames:
76 | aka = layername
77 | if not isinstance(aka, str):
78 | layername, aka = layername
79 | if aka in self._retained:
80 | del self._retained[aka]
81 | del self._detach_retained[aka]
82 |
83 | def retained_features(self, clear=False):
84 | '''
85 | Returns a dict of all currently retained features.
86 | '''
87 | result = OrderedDict(self._retained)
88 | if clear:
89 | for k in result:
90 | self._retained[k] = None
91 | return result
92 |
93 | def retained_layer(self, aka=None, clear=False):
94 | '''
95 | Retrieve retained data that was previously hooked by retain_layer.
96 | Call this after the model is run. If clear is set, then the
97 |         retained value will be returned and also cleared.
98 | '''
99 | if aka is None:
100 | # Default to the first retained layer.
101 | aka = next(self._retained.keys().__iter__())
102 | result = self._retained[aka]
103 | if clear:
104 | self._retained[aka] = None
105 | return result
106 |
107 | def edit_layer(self, layername, rule=None, **kwargs):
108 | '''
109 |         Pass a fully-qualified layer name (e.g., module.submodule.conv3)
110 | to hook that layer and modify its output each time the model is run.
111 | The output of the layer will be modified to be a convex combination
112 | of the replacement and x interpolated according to the ablation, i.e.:
113 | `output = x * (1 - a) + (r * a)`.
114 | '''
115 | if not isinstance(layername, str):
116 | layername, aka = layername
117 | else:
118 | aka = layername
119 |
120 | # The default editing rule is apply_ablation_replacement
121 | if rule is None:
122 | rule = apply_ablation_replacement
123 |
124 | self.add_hooks([(layername, aka)])
125 | self._editargs[aka].update(kwargs)
126 | self._editrule[aka] = rule
127 |
128 | def remove_edits(self, layername=None):
129 | '''
130 | Removes edits at the specified layer, or removes edits at all layers
131 | if no layer name is specified.
132 | '''
133 | if layername is None:
134 | self._editargs.clear()
135 | self._editrule.clear()
136 | return
137 |
138 | if not isinstance(layername, str):
139 | layername, aka = layername
140 | else:
141 | aka = layername
142 | if aka in self._editargs:
143 | del self._editargs[aka]
144 | if aka in self._editrule:
145 | del self._editrule[aka]
146 |
147 | def add_hooks(self, layernames):
148 | '''
149 | Sets up a set of layers to be hooked.
150 |
151 | Usually not called directly: use edit_layer or retain_layer instead.
152 | '''
153 | needed = set()
154 | aka_map = {}
155 | for name in layernames:
156 | aka = name
157 | if not isinstance(aka, str):
158 | name, aka = name
159 | if self._hooked_layer.get(aka, None) != name:
160 | aka_map[name] = aka
161 | needed.add(name)
162 | if not needed:
163 | return
164 | for name, layer in self.model.named_modules():
165 | if name in aka_map:
166 | needed.remove(name)
167 | aka = aka_map[name]
168 | self._hook_layer(layer, name, aka)
169 | for name in needed:
170 | raise ValueError('Layer %s not found in model' % name)
171 |
172 | def _hook_layer(self, layer, layername, aka):
173 | '''
174 | Internal method to replace a forward method with a closure that
175 | intercepts the call, and tracks the hook so that it can be reverted.
176 | '''
177 | if aka in self._hooked_layer:
178 | raise ValueError('Layer %s already hooked' % aka)
179 | if layername in self._old_forward:
180 | raise ValueError('Layer %s already hooked' % layername)
181 | self._hooked_layer[aka] = layername
182 | self._old_forward[layername] = (layer, aka, layer.__dict__.get('forward', None))
183 | editor = self
184 | original_forward = layer.forward
185 |
186 | def new_forward(self, *inputs, **kwargs):
187 | original_x = original_forward(*inputs, **kwargs)
188 | x = editor._postprocess_forward(original_x, aka)
189 | return x
190 |
191 | layer.forward = types.MethodType(new_forward, layer)
192 |
193 | def _unhook_layer(self, aka):
194 | '''
195 | Internal method to remove a hook, restoring the original forward method.
196 | '''
197 | if aka not in self._hooked_layer:
198 | return
199 | layername = self._hooked_layer[aka]
200 | # Remove any retained data and any edit rules
201 | if aka in self._retained:
202 | del self._retained[aka]
203 | del self._detach_retained[aka]
204 | self.remove_edits(aka)
205 | # Restore the unhooked method for the layer
206 | layer, check, old_forward = self._old_forward[layername]
207 | assert check == aka
208 | if old_forward is None:
209 | if 'forward' in layer.__dict__:
210 | del layer.__dict__['forward']
211 | else:
212 | layer.forward = old_forward
213 | del self._old_forward[layername]
214 | del self._hooked_layer[aka]
215 |
216 | def _postprocess_forward(self, x, aka):
217 | '''
218 | The internal method called by the hooked layers after they are run.
219 | '''
220 | # Retain output before edits, if desired.
221 | if aka in self._retained:
222 | if self._detach_retained[aka]:
223 | self._retained[aka] = x.detach()
224 | else:
225 | self._retained[aka] = x
226 | # Apply any edits requested.
227 | rule = self._editrule.get(aka, None)
228 | if rule is not None:
229 | x = rule(x, self, **(self._editargs[aka]))
230 | return x
231 |
232 | def _hook_sequential(self):
233 | '''
234 | Replaces 'forward' of sequential with a version that takes
235 | additional keyword arguments: layer allows a single layer to be run;
236 | first_layer and last_layer allow a subsequence of layers to be run.
237 | '''
238 | model = self.model
239 | self._hooked_layer['.'] = '.'
240 | self._old_forward['.'] = (model, '.', model.__dict__.get('forward', None))
241 |
242 | def new_forward(this, x, layer=None, first_layer=None, last_layer=None):
243 | assert layer is None or (first_layer is None and last_layer is None)
244 | first_layer, last_layer = [
245 | str(layer) if layer is not None else str(d) if d is not None else None
246 | for d in [first_layer, last_layer]
247 | ]
248 | including_children = (first_layer is None)
249 | for name, layer in this._modules.items():
250 | if name == first_layer:
251 | first_layer = None
252 | including_children = True
253 | if including_children:
254 | x = layer(x)
255 | if name == last_layer:
256 | last_layer = None
257 | including_children = False
258 | assert first_layer is None, '%s not found' % first_layer
259 | assert last_layer is None, '%s not found' % last_layer
260 | return x
261 |
262 | model.forward = types.MethodType(new_forward, model)
263 |
264 | def close(self):
265 | '''
266 | Unhooks all hooked layers in the model.
267 | '''
268 | for aka in list(self._old_forward.keys()):
269 | self._unhook_layer(aka)
270 | assert len(self._old_forward) == 0
271 |
272 |
273 | def apply_ablation_replacement(x, imodel, **buffers):
274 | if buffers is not None:
275 | # Apply any edits requested.
276 | a = make_matching_tensor(buffers, 'ablation', x)
277 | if a is not None:
278 | x = x * (1 - a)
279 | v = make_matching_tensor(buffers, 'replacement', x)
280 | if v is not None:
281 | x += (v * a)
282 | return x
283 |
284 |
285 | def make_matching_tensor(valuedict, name, data):
286 | '''
287 | Converts `valuedict[name]` to be a tensor with the same dtype, device,
288 | and dimension count as `data`, and caches the converted tensor.
289 | '''
290 | v = valuedict.get(name, None)
291 | if v is None:
292 | return None
293 | if not isinstance(v, torch.Tensor):
294 | # Accept non-torch data.
295 | v = torch.from_numpy(numpy.array(v))
296 | valuedict[name] = v
297 | if not v.device == data.device or not v.dtype == data.dtype:
298 | # Ensure device and type matches.
299 | assert not v.requires_grad, '%s wrong device or type' % (name)
300 | v = v.to(device=data.device, dtype=data.dtype)
301 | valuedict[name] = v
302 | if len(v.shape) < len(data.shape):
303 | # Ensure dimensions are unsqueezed as needed.
304 | assert not v.requires_grad, '%s wrong dimensions' % (name)
305 | v = v.view((1, ) + tuple(v.shape) + (1, ) * (len(data.shape) - len(v.shape) - 1))
306 | valuedict[name] = v
307 | return v
308 |
309 |
310 | def subsequence(sequential,
311 | first_layer=None,
312 | last_layer=None,
313 | upto_layer=None,
314 | single_layer=None,
315 | share_weights=False):
316 | '''
317 | Creates a subsequence of a pytorch Sequential model, copying over
318 | modules together with parameters for the subsequence. Only
319 | modules from first_layer to last_layer (inclusive) are included.
320 |
321 | If share_weights is True, then references the original modules
322 | and their parameters without copying them. Otherwise, by default,
323 | makes a separate brand-new copy.
324 | '''
325 | assert ((single_layer is None) or (first_layer is last_layer is upto_layer is None))
326 | assert (last_layer is None) or (upto_layer is None)
327 | if single_layer is not None:
328 | first_layer = single_layer
329 | last_layer = single_layer
330 | included_children = OrderedDict()
331 | including_children = (first_layer is None)
332 | for name, layer in sequential._modules.items():
333 | if name == first_layer:
334 | first_layer = None
335 | including_children = True
336 | if name == upto_layer:
337 | upto_layer = None
338 | including_children = False
339 | if including_children:
340 | included_children[name] = layer if share_weights else (copy.deepcopy(layer))
341 | if name == last_layer:
342 | last_layer = None
343 | including_children = False
344 | if first_layer is not None:
345 | raise ValueError('Layer %s not found' % first_layer)
346 | if last_layer is not None:
347 | raise ValueError('Layer %s not found' % last_layer)
348 | if upto_layer is not None:
349 | raise ValueError('Layer %s not found' % upto_layer)
350 | # if not len(included_children):
351 | # raise ValueError('Empty subsequence')
352 | return torch.nn.Sequential(OrderedDict(included_children))
353 |
354 |
355 | def set_requires_grad(requires_grad, *models):
356 | for model in models:
357 | if isinstance(model, torch.nn.Module):
358 | for param in model.parameters():
359 | param.requires_grad = requires_grad
360 | elif isinstance(model, (torch.nn.Parameter, torch.Tensor)):
361 | model.requires_grad = requires_grad
362 | else:
363 | assert False, 'unknown type %r' % type(model)
364 |
--------------------------------------------------------------------------------
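`subsequence` above is how `dgp.py` builds its perceptual feature extractor: it slices `vgg16().features` (an `nn.Sequential` whose children are named '0', '1', ...) up to a chosen layer. A minimal sketch of the same pattern; cutting at layer '20' follows the DGP default, and the input size is arbitrary:

```python
import torch
import torchvision
from models.nethook import subsequence, set_requires_grad

vgg = torchvision.models.vgg16(pretrained=True).eval()
ftr_net = subsequence(vgg.features, last_layer='20')   # keep children '0'..'20' inclusive
set_requires_grad(False, ftr_net)                      # freeze the feature extractor

x = torch.rand(1, 3, 256, 256)
with torch.no_grad():
    feats = ftr_net(x)
print(feats.shape)   # a 512-channel feature map at 1/8 the input resolution
```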
/pretrained/.gitignore:
--------------------------------------------------------------------------------
1 | # Ignore everything in this directory
2 | *
3 | # Except this file
4 | !.gitignore
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scikit-image
3 | scipy
4 | pillow<=6.2.2
5 | torchvision
6 | tensorboardX
7 |
--------------------------------------------------------------------------------
/scripts/imagenet_val_1k.txt:
--------------------------------------------------------------------------------
1 | ILSVRC2012_val_00000293.JPEG 0
2 | ILSVRC2012_val_00000236.JPEG 1
3 | ILSVRC2012_val_00002338.JPEG 2
4 | ILSVRC2012_val_00002922.JPEG 3
5 | ILSVRC2012_val_00001676.JPEG 4
6 | ILSVRC2012_val_00000921.JPEG 5
7 | ILSVRC2012_val_00001935.JPEG 6
8 | ILSVRC2012_val_00000329.JPEG 7
9 | ILSVRC2012_val_00001114.JPEG 8
10 | ILSVRC2012_val_00001031.JPEG 9
11 | ILSVRC2012_val_00000651.JPEG 10
12 | ILSVRC2012_val_00000570.JPEG 11
13 | ILSVRC2012_val_00000873.JPEG 12
14 | ILSVRC2012_val_00000247.JPEG 13
15 | ILSVRC2012_val_00000414.JPEG 14
16 | ILSVRC2012_val_00001598.JPEG 15
17 | ILSVRC2012_val_00000198.JPEG 16
18 | ILSVRC2012_val_00000880.JPEG 17
19 | ILSVRC2012_val_00000476.JPEG 18
20 | ILSVRC2012_val_00000837.JPEG 19
21 | ILSVRC2012_val_00000962.JPEG 20
22 | ILSVRC2012_val_00000073.JPEG 21
23 | ILSVRC2012_val_00000128.JPEG 22
24 | ILSVRC2012_val_00000018.JPEG 23
25 | ILSVRC2012_val_00000318.JPEG 24
26 | ILSVRC2012_val_00000225.JPEG 25
27 | ILSVRC2012_val_00000498.JPEG 26
28 | ILSVRC2012_val_00000258.JPEG 27
29 | ILSVRC2012_val_00000088.JPEG 28
30 | ILSVRC2012_val_00000052.JPEG 29
31 | ILSVRC2012_val_00001652.JPEG 30
32 | ILSVRC2012_val_00000944.JPEG 31
33 | ILSVRC2012_val_00000037.JPEG 32
34 | ILSVRC2012_val_00000404.JPEG 33
35 | ILSVRC2012_val_00000791.JPEG 34
36 | ILSVRC2012_val_00000229.JPEG 35
37 | ILSVRC2012_val_00002832.JPEG 36
38 | ILSVRC2012_val_00000362.JPEG 37
39 | ILSVRC2012_val_00003904.JPEG 38
40 | ILSVRC2012_val_00000153.JPEG 39
41 | ILSVRC2012_val_00001931.JPEG 40
42 | ILSVRC2012_val_00000908.JPEG 41
43 | ILSVRC2012_val_00003821.JPEG 42
44 | ILSVRC2012_val_00000297.JPEG 43
45 | ILSVRC2012_val_00001022.JPEG 44
46 | ILSVRC2012_val_00001475.JPEG 45
47 | ILSVRC2012_val_00000064.JPEG 46
48 | ILSVRC2012_val_00001375.JPEG 47
49 | ILSVRC2012_val_00000541.JPEG 48
50 | ILSVRC2012_val_00001573.JPEG 49
51 | ILSVRC2012_val_00000090.JPEG 50
52 | ILSVRC2012_val_00002051.JPEG 51
53 | ILSVRC2012_val_00003569.JPEG 52
54 | ILSVRC2012_val_00001042.JPEG 53
55 | ILSVRC2012_val_00002568.JPEG 54
56 | ILSVRC2012_val_00000029.JPEG 55
57 | ILSVRC2012_val_00000512.JPEG 56
58 | ILSVRC2012_val_00000006.JPEG 57
59 | ILSVRC2012_val_00000084.JPEG 58
60 | ILSVRC2012_val_00001108.JPEG 59
61 | ILSVRC2012_val_00000337.JPEG 60
62 | ILSVRC2012_val_00000786.JPEG 61
63 | ILSVRC2012_val_00000688.JPEG 62
64 | ILSVRC2012_val_00003143.JPEG 63
65 | ILSVRC2012_val_00000298.JPEG 64
66 | ILSVRC2012_val_00000001.JPEG 65
67 | ILSVRC2012_val_00001582.JPEG 66
68 | ILSVRC2012_val_00000749.JPEG 67
69 | ILSVRC2012_val_00001706.JPEG 68
70 | ILSVRC2012_val_00000299.JPEG 69
71 | ILSVRC2012_val_00000107.JPEG 70
72 | ILSVRC2012_val_00001007.JPEG 71
73 | ILSVRC2012_val_00001037.JPEG 72
74 | ILSVRC2012_val_00002688.JPEG 73
75 | ILSVRC2012_val_00000040.JPEG 74
76 | ILSVRC2012_val_00001150.JPEG 75
77 | ILSVRC2012_val_00000396.JPEG 76
78 | ILSVRC2012_val_00000218.JPEG 77
79 | ILSVRC2012_val_00000679.JPEG 78
80 | ILSVRC2012_val_00003981.JPEG 79
81 | ILSVRC2012_val_00000075.JPEG 80
82 | ILSVRC2012_val_00000860.JPEG 81
83 | ILSVRC2012_val_00002539.JPEG 82
84 | ILSVRC2012_val_00001461.JPEG 83
85 | ILSVRC2012_val_00001146.JPEG 84
86 | ILSVRC2012_val_00003015.JPEG 85
87 | ILSVRC2012_val_00000576.JPEG 86
88 | ILSVRC2012_val_00001629.JPEG 87
89 | ILSVRC2012_val_00001163.JPEG 88
90 | ILSVRC2012_val_00001288.JPEG 89
91 | ILSVRC2012_val_00001749.JPEG 90
92 | ILSVRC2012_val_00000370.JPEG 91
93 | ILSVRC2012_val_00000051.JPEG 92
94 | ILSVRC2012_val_00002100.JPEG 93
95 | ILSVRC2012_val_00002974.JPEG 94
96 | ILSVRC2012_val_00002326.JPEG 95
97 | ILSVRC2012_val_00000952.JPEG 96
98 | ILSVRC2012_val_00000415.JPEG 97
99 | ILSVRC2012_val_00002309.JPEG 98
100 | ILSVRC2012_val_00000547.JPEG 99
101 | ILSVRC2012_val_00000466.JPEG 100
102 | ILSVRC2012_val_00000067.JPEG 101
103 | ILSVRC2012_val_00000319.JPEG 102
104 | ILSVRC2012_val_00000189.JPEG 103
105 | ILSVRC2012_val_00000655.JPEG 104
106 | ILSVRC2012_val_00000970.JPEG 105
107 | ILSVRC2012_val_00001933.JPEG 106
108 | ILSVRC2012_val_00000063.JPEG 107
109 | ILSVRC2012_val_00000017.JPEG 108
110 | ILSVRC2012_val_00000011.JPEG 109
111 | ILSVRC2012_val_00003054.JPEG 110
112 | ILSVRC2012_val_00000836.JPEG 111
113 | ILSVRC2012_val_00000781.JPEG 112
114 | ILSVRC2012_val_00001268.JPEG 113
115 | ILSVRC2012_val_00000546.JPEG 114
116 | ILSVRC2012_val_00000154.JPEG 115
117 | ILSVRC2012_val_00000689.JPEG 116
118 | ILSVRC2012_val_00001356.JPEG 117
119 | ILSVRC2012_val_00000195.JPEG 118
120 | ILSVRC2012_val_00002684.JPEG 119
121 | ILSVRC2012_val_00000818.JPEG 120
122 | ILSVRC2012_val_00000238.JPEG 121
123 | ILSVRC2012_val_00000111.JPEG 122
124 | ILSVRC2012_val_00002152.JPEG 123
125 | ILSVRC2012_val_00001381.JPEG 124
126 | ILSVRC2012_val_00000231.JPEG 125
127 | ILSVRC2012_val_00000669.JPEG 126
128 | ILSVRC2012_val_00000690.JPEG 127
129 | ILSVRC2012_val_00001068.JPEG 128
130 | ILSVRC2012_val_00000043.JPEG 129
131 | ILSVRC2012_val_00000356.JPEG 130
132 | ILSVRC2012_val_00000794.JPEG 131
133 | ILSVRC2012_val_00001067.JPEG 132
134 | ILSVRC2012_val_00000363.JPEG 133
135 | ILSVRC2012_val_00000389.JPEG 134
136 | ILSVRC2012_val_00002126.JPEG 135
137 | ILSVRC2012_val_00003399.JPEG 136
138 | ILSVRC2012_val_00000825.JPEG 137
139 | ILSVRC2012_val_00001073.JPEG 138
140 | ILSVRC2012_val_00001800.JPEG 139
141 | ILSVRC2012_val_00000177.JPEG 140
142 | ILSVRC2012_val_00000470.JPEG 141
143 | ILSVRC2012_val_00000098.JPEG 142
144 | ILSVRC2012_val_00000548.JPEG 143
145 | ILSVRC2012_val_00003147.JPEG 144
146 | ILSVRC2012_val_00001056.JPEG 145
147 | ILSVRC2012_val_00000207.JPEG 146
148 | ILSVRC2012_val_00000016.JPEG 147
149 | ILSVRC2012_val_00000545.JPEG 148
150 | ILSVRC2012_val_00000072.JPEG 149
151 | ILSVRC2012_val_00000033.JPEG 150
152 | ILSVRC2012_val_00001049.JPEG 151
153 | ILSVRC2012_val_00000462.JPEG 152
154 | ILSVRC2012_val_00000457.JPEG 153
155 | ILSVRC2012_val_00001069.JPEG 154
156 | ILSVRC2012_val_00000907.JPEG 155
157 | ILSVRC2012_val_00000613.JPEG 156
158 | ILSVRC2012_val_00001389.JPEG 157
159 | ILSVRC2012_val_00000170.JPEG 158
160 | ILSVRC2012_val_00000077.JPEG 159
161 | ILSVRC2012_val_00000082.JPEG 160
162 | ILSVRC2012_val_00000838.JPEG 161
163 | ILSVRC2012_val_00000599.JPEG 162
164 | ILSVRC2012_val_00000141.JPEG 163
165 | ILSVRC2012_val_00001097.JPEG 164
166 | ILSVRC2012_val_00000327.JPEG 165
167 | ILSVRC2012_val_00000511.JPEG 166
168 | ILSVRC2012_val_00000028.JPEG 167
169 | ILSVRC2012_val_00001267.JPEG 168
170 | ILSVRC2012_val_00001626.JPEG 169
171 | ILSVRC2012_val_00001765.JPEG 170
172 | ILSVRC2012_val_00001825.JPEG 171
173 | ILSVRC2012_val_00001846.JPEG 172
174 | ILSVRC2012_val_00000022.JPEG 173
175 | ILSVRC2012_val_00002328.JPEG 174
176 | ILSVRC2012_val_00000079.JPEG 175
177 | ILSVRC2012_val_00000263.JPEG 176
178 | ILSVRC2012_val_00000405.JPEG 177
179 | ILSVRC2012_val_00001385.JPEG 178
180 | ILSVRC2012_val_00000248.JPEG 179
181 | ILSVRC2012_val_00000191.JPEG 180
182 | ILSVRC2012_val_00000883.JPEG 181
183 | ILSVRC2012_val_00000832.JPEG 182
184 | ILSVRC2012_val_00000062.JPEG 183
185 | ILSVRC2012_val_00000588.JPEG 184
186 | ILSVRC2012_val_00000391.JPEG 185
187 | ILSVRC2012_val_00000612.JPEG 186
188 | ILSVRC2012_val_00000696.JPEG 187
189 | ILSVRC2012_val_00001001.JPEG 188
190 | ILSVRC2012_val_00002194.JPEG 189
191 | ILSVRC2012_val_00000474.JPEG 190
192 | ILSVRC2012_val_00003002.JPEG 191
193 | ILSVRC2012_val_00000379.JPEG 192
194 | ILSVRC2012_val_00001564.JPEG 193
195 | ILSVRC2012_val_00000279.JPEG 194
196 | ILSVRC2012_val_00000350.JPEG 195
197 | ILSVRC2012_val_00000176.JPEG 196
198 | ILSVRC2012_val_00002573.JPEG 197
199 | ILSVRC2012_val_00000044.JPEG 198
200 | ILSVRC2012_val_00000411.JPEG 199
201 | ILSVRC2012_val_00000802.JPEG 200
202 | ILSVRC2012_val_00000105.JPEG 201
203 | ILSVRC2012_val_00000748.JPEG 202
204 | ILSVRC2012_val_00002784.JPEG 203
205 | ILSVRC2012_val_00001481.JPEG 204
206 | ILSVRC2012_val_00000824.JPEG 205
207 | ILSVRC2012_val_00001434.JPEG 206
208 | ILSVRC2012_val_00001112.JPEG 207
209 | ILSVRC2012_val_00000375.JPEG 208
210 | ILSVRC2012_val_00000668.JPEG 209
211 | ILSVRC2012_val_00001070.JPEG 210
212 | ILSVRC2012_val_00002323.JPEG 211
213 | ILSVRC2012_val_00000172.JPEG 212
214 | ILSVRC2012_val_00000161.JPEG 213
215 | ILSVRC2012_val_00000708.JPEG 214
216 | ILSVRC2012_val_00000227.JPEG 215
217 | ILSVRC2012_val_00000643.JPEG 216
218 | ILSVRC2012_val_00000665.JPEG 217
219 | ILSVRC2012_val_00002298.JPEG 218
220 | ILSVRC2012_val_00000524.JPEG 219
221 | ILSVRC2012_val_00000539.JPEG 220
222 | ILSVRC2012_val_00001693.JPEG 221
223 | ILSVRC2012_val_00000507.JPEG 222
224 | ILSVRC2012_val_00000533.JPEG 223
225 | ILSVRC2012_val_00001109.JPEG 224
226 | ILSVRC2012_val_00003176.JPEG 225
227 | ILSVRC2012_val_00004938.JPEG 226
228 | ILSVRC2012_val_00000113.JPEG 227
229 | ILSVRC2012_val_00000383.JPEG 228
230 | ILSVRC2012_val_00002247.JPEG 229
231 | ILSVRC2012_val_00000003.JPEG 230
232 | ILSVRC2012_val_00003480.JPEG 231
233 | ILSVRC2012_val_00000672.JPEG 232
234 | ILSVRC2012_val_00000203.JPEG 233
235 | ILSVRC2012_val_00003623.JPEG 234
236 | ILSVRC2012_val_00000946.JPEG 235
237 | ILSVRC2012_val_00000134.JPEG 236
238 | ILSVRC2012_val_00004386.JPEG 237
239 | ILSVRC2012_val_00000531.JPEG 238
240 | ILSVRC2012_val_00002708.JPEG 239
241 | ILSVRC2012_val_00000594.JPEG 240
242 | ILSVRC2012_val_00001005.JPEG 241
243 | ILSVRC2012_val_00000635.JPEG 242
244 | ILSVRC2012_val_00000975.JPEG 243
245 | ILSVRC2012_val_00004506.JPEG 244
246 | ILSVRC2012_val_00001313.JPEG 245
247 | ILSVRC2012_val_00001408.JPEG 246
248 | ILSVRC2012_val_00001937.JPEG 247
249 | ILSVRC2012_val_00000178.JPEG 248
250 | ILSVRC2012_val_00000729.JPEG 249
251 | ILSVRC2012_val_00000269.JPEG 250
252 | ILSVRC2012_val_00001702.JPEG 251
253 | ILSVRC2012_val_00002267.JPEG 252
254 | ILSVRC2012_val_00001344.JPEG 253
255 | ILSVRC2012_val_00000220.JPEG 254
256 | ILSVRC2012_val_00001203.JPEG 255
257 | ILSVRC2012_val_00000045.JPEG 256
258 | ILSVRC2012_val_00000174.JPEG 257
259 | ILSVRC2012_val_00000590.JPEG 258
260 | ILSVRC2012_val_00000057.JPEG 259
261 | ILSVRC2012_val_00000355.JPEG 260
262 | ILSVRC2012_val_00000658.JPEG 261
263 | ILSVRC2012_val_00001568.JPEG 262
264 | ILSVRC2012_val_00001096.JPEG 263
265 | ILSVRC2012_val_00000997.JPEG 264
266 | ILSVRC2012_val_00000147.JPEG 265
267 | ILSVRC2012_val_00000867.JPEG 266
268 | ILSVRC2012_val_00000399.JPEG 267
269 | ILSVRC2012_val_00000719.JPEG 268
270 | ILSVRC2012_val_00000530.JPEG 269
271 | ILSVRC2012_val_00000027.JPEG 270
272 | ILSVRC2012_val_00003172.JPEG 271
273 | ILSVRC2012_val_00000156.JPEG 272
274 | ILSVRC2012_val_00000604.JPEG 273
275 | ILSVRC2012_val_00001148.JPEG 274
276 | ILSVRC2012_val_00000078.JPEG 275
277 | ILSVRC2012_val_00000133.JPEG 276
278 | ILSVRC2012_val_00000157.JPEG 277
279 | ILSVRC2012_val_00000756.JPEG 278
280 | ILSVRC2012_val_00001335.JPEG 279
281 | ILSVRC2012_val_00000323.JPEG 280
282 | ILSVRC2012_val_00001341.JPEG 281
283 | ILSVRC2012_val_00000796.JPEG 282
284 | ILSVRC2012_val_00000130.JPEG 283
285 | ILSVRC2012_val_00000211.JPEG 284
286 | ILSVRC2012_val_00000709.JPEG 285
287 | ILSVRC2012_val_00000012.JPEG 286
288 | ILSVRC2012_val_00000898.JPEG 287
289 | ILSVRC2012_val_00000600.JPEG 288
290 | ILSVRC2012_val_00000186.JPEG 289
291 | ILSVRC2012_val_00000610.JPEG 290
292 | ILSVRC2012_val_00001359.JPEG 291
293 | ILSVRC2012_val_00000303.JPEG 292
294 | ILSVRC2012_val_00000348.JPEG 293
295 | ILSVRC2012_val_00000871.JPEG 294
296 | ILSVRC2012_val_00000592.JPEG 295
297 | ILSVRC2012_val_00001498.JPEG 296
298 | ILSVRC2012_val_00001745.JPEG 297
299 | ILSVRC2012_val_00000518.JPEG 298
300 | ILSVRC2012_val_00000779.JPEG 299
301 | ILSVRC2012_val_00000361.JPEG 300
302 | ILSVRC2012_val_00000625.JPEG 301
303 | ILSVRC2012_val_00000956.JPEG 302
304 | ILSVRC2012_val_00000926.JPEG 303
305 | ILSVRC2012_val_00000510.JPEG 304
306 | ILSVRC2012_val_00000342.JPEG 305
307 | ILSVRC2012_val_00000692.JPEG 306
308 | ILSVRC2012_val_00001570.JPEG 307
309 | ILSVRC2012_val_00001763.JPEG 308
310 | ILSVRC2012_val_00001876.JPEG 309
311 | ILSVRC2012_val_00000407.JPEG 310
312 | ILSVRC2012_val_00000125.JPEG 311
313 | ILSVRC2012_val_00001873.JPEG 312
314 | ILSVRC2012_val_00001585.JPEG 313
315 | ILSVRC2012_val_00001082.JPEG 314
316 | ILSVRC2012_val_00000772.JPEG 315
317 | ILSVRC2012_val_00002649.JPEG 316
318 | ILSVRC2012_val_00003032.JPEG 317
319 | ILSVRC2012_val_00001466.JPEG 318
320 | ILSVRC2012_val_00000418.JPEG 319
321 | ILSVRC2012_val_00000639.JPEG 320
322 | ILSVRC2012_val_00001187.JPEG 321
323 | ILSVRC2012_val_00000810.JPEG 322
324 | ILSVRC2012_val_00001290.JPEG 323
325 | ILSVRC2012_val_00000031.JPEG 324
326 | ILSVRC2012_val_00000581.JPEG 325
327 | ILSVRC2012_val_00001536.JPEG 326
328 | ILSVRC2012_val_00001161.JPEG 327
329 | ILSVRC2012_val_00000138.JPEG 328
330 | ILSVRC2012_val_00000301.JPEG 329
331 | ILSVRC2012_val_00000097.JPEG 330
332 | ILSVRC2012_val_00000644.JPEG 331
333 | ILSVRC2012_val_00000010.JPEG 332
334 | ILSVRC2012_val_00000617.JPEG 333
335 | ILSVRC2012_val_00000007.JPEG 334
336 | ILSVRC2012_val_00000502.JPEG 335
337 | ILSVRC2012_val_00000675.JPEG 336
338 | ILSVRC2012_val_00003202.JPEG 337
339 | ILSVRC2012_val_00000565.JPEG 338
340 | ILSVRC2012_val_00000179.JPEG 339
341 | ILSVRC2012_val_00002070.JPEG 340
342 | ILSVRC2012_val_00000478.JPEG 341
343 | ILSVRC2012_val_00002006.JPEG 342
344 | ILSVRC2012_val_00000136.JPEG 343
345 | ILSVRC2012_val_00001066.JPEG 344
346 | ILSVRC2012_val_00001280.JPEG 345
347 | ILSVRC2012_val_00000306.JPEG 346
348 | ILSVRC2012_val_00002433.JPEG 347
349 | ILSVRC2012_val_00001027.JPEG 348
350 | ILSVRC2012_val_00000100.JPEG 349
351 | ILSVRC2012_val_00000444.JPEG 350
352 | ILSVRC2012_val_00001232.JPEG 351
353 | ILSVRC2012_val_00000127.JPEG 352
354 | ILSVRC2012_val_00003125.JPEG 353
355 | ILSVRC2012_val_00002271.JPEG 354
356 | ILSVRC2012_val_00000343.JPEG 355
357 | ILSVRC2012_val_00000723.JPEG 356
358 | ILSVRC2012_val_00002028.JPEG 357
359 | ILSVRC2012_val_00000055.JPEG 358
360 | ILSVRC2012_val_00000326.JPEG 359
361 | ILSVRC2012_val_00001407.JPEG 360
362 | ILSVRC2012_val_00000181.JPEG 361
363 | ILSVRC2012_val_00000726.JPEG 362
364 | ILSVRC2012_val_00002448.JPEG 363
365 | ILSVRC2012_val_00001442.JPEG 364
366 | ILSVRC2012_val_00001535.JPEG 365
367 | ILSVRC2012_val_00000093.JPEG 366
368 | ILSVRC2012_val_00000553.JPEG 367
369 | ILSVRC2012_val_00003388.JPEG 368
370 | ILSVRC2012_val_00000087.JPEG 369
371 | ILSVRC2012_val_00000013.JPEG 370
372 | ILSVRC2012_val_00000964.JPEG 371
373 | ILSVRC2012_val_00000486.JPEG 372
374 | ILSVRC2012_val_00000095.JPEG 373
375 | ILSVRC2012_val_00001254.JPEG 374
376 | ILSVRC2012_val_00001095.JPEG 375
377 | ILSVRC2012_val_00000206.JPEG 376
378 | ILSVRC2012_val_00000520.JPEG 377
379 | ILSVRC2012_val_00002450.JPEG 378
380 | ILSVRC2012_val_00003118.JPEG 379
381 | ILSVRC2012_val_00000340.JPEG 380
382 | ILSVRC2012_val_00003812.JPEG 381
383 | ILSVRC2012_val_00002011.JPEG 382
384 | ILSVRC2012_val_00000092.JPEG 383
385 | ILSVRC2012_val_00001025.JPEG 384
386 | ILSVRC2012_val_00000597.JPEG 385
387 | ILSVRC2012_val_00001177.JPEG 386
388 | ILSVRC2012_val_00002863.JPEG 387
389 | ILSVRC2012_val_00000481.JPEG 388
390 | ILSVRC2012_val_00000312.JPEG 389
391 | ILSVRC2012_val_00000066.JPEG 390
392 | ILSVRC2012_val_00002364.JPEG 391
393 | ILSVRC2012_val_00000124.JPEG 392
394 | ILSVRC2012_val_00001176.JPEG 393
395 | ILSVRC2012_val_00000050.JPEG 394
396 | ILSVRC2012_val_00001975.JPEG 395
397 | ILSVRC2012_val_00000196.JPEG 396
398 | ILSVRC2012_val_00000699.JPEG 397
399 | ILSVRC2012_val_00000038.JPEG 398
400 | ILSVRC2012_val_00001285.JPEG 399
401 | ILSVRC2012_val_00000851.JPEG 400
402 | ILSVRC2012_val_00000204.JPEG 401
403 | ILSVRC2012_val_00001416.JPEG 402
404 | ILSVRC2012_val_00000289.JPEG 403
405 | ILSVRC2012_val_00000631.JPEG 404
406 | ILSVRC2012_val_00001602.JPEG 405
407 | ILSVRC2012_val_00000932.JPEG 406
408 | ILSVRC2012_val_00001910.JPEG 407
409 | ILSVRC2012_val_00000998.JPEG 408
410 | ILSVRC2012_val_00001011.JPEG 409
411 | ILSVRC2012_val_00002391.JPEG 410
412 | ILSVRC2012_val_00000771.JPEG 411
413 | ILSVRC2012_val_00000422.JPEG 412
414 | ILSVRC2012_val_00000268.JPEG 413
415 | ILSVRC2012_val_00000459.JPEG 414
416 | ILSVRC2012_val_00000008.JPEG 415
417 | ILSVRC2012_val_00000494.JPEG 416
418 | ILSVRC2012_val_00001731.JPEG 417
419 | ILSVRC2012_val_00000721.JPEG 418
420 | ILSVRC2012_val_00000400.JPEG 419
421 | ILSVRC2012_val_00000315.JPEG 420
422 | ILSVRC2012_val_00000272.JPEG 421
423 | ILSVRC2012_val_00004406.JPEG 422
424 | ILSVRC2012_val_00000759.JPEG 423
425 | ILSVRC2012_val_00000076.JPEG 424
426 | ILSVRC2012_val_00000234.JPEG 425
427 | ILSVRC2012_val_00000287.JPEG 426
428 | ILSVRC2012_val_00000687.JPEG 427
429 | ILSVRC2012_val_00002509.JPEG 428
430 | ILSVRC2012_val_00001468.JPEG 429
431 | ILSVRC2012_val_00001517.JPEG 430
432 | ILSVRC2012_val_00000354.JPEG 431
433 | ILSVRC2012_val_00001090.JPEG 432
434 | ILSVRC2012_val_00001013.JPEG 433
435 | ILSVRC2012_val_00001384.JPEG 434
436 | ILSVRC2012_val_00002563.JPEG 435
437 | ILSVRC2012_val_00000175.JPEG 436
438 | ILSVRC2012_val_00000483.JPEG 437
439 | ILSVRC2012_val_00000253.JPEG 438
440 | ILSVRC2012_val_00001048.JPEG 439
441 | ILSVRC2012_val_00000654.JPEG 440
442 | ILSVRC2012_val_00003782.JPEG 441
443 | ILSVRC2012_val_00001103.JPEG 442
444 | ILSVRC2012_val_00001252.JPEG 443
445 | ILSVRC2012_val_00000953.JPEG 444
446 | ILSVRC2012_val_00000239.JPEG 445
447 | ILSVRC2012_val_00003446.JPEG 446
448 | ILSVRC2012_val_00000917.JPEG 447
449 | ILSVRC2012_val_00000514.JPEG 448
450 | ILSVRC2012_val_00000193.JPEG 449
451 | ILSVRC2012_val_00000601.JPEG 450
452 | ILSVRC2012_val_00000763.JPEG 451
453 | ILSVRC2012_val_00000906.JPEG 452
454 | ILSVRC2012_val_00003274.JPEG 453
455 | ILSVRC2012_val_00000879.JPEG 454
456 | ILSVRC2012_val_00000634.JPEG 455
457 | ILSVRC2012_val_00001354.JPEG 456
458 | ILSVRC2012_val_00000126.JPEG 457
459 | ILSVRC2012_val_00001145.JPEG 458
460 | ILSVRC2012_val_00000808.JPEG 459
461 | ILSVRC2012_val_00000358.JPEG 460
462 | ILSVRC2012_val_00000080.JPEG 461
463 | ILSVRC2012_val_00002937.JPEG 462
464 | ILSVRC2012_val_00000226.JPEG 463
465 | ILSVRC2012_val_00000558.JPEG 464
466 | ILSVRC2012_val_00000366.JPEG 465
467 | ILSVRC2012_val_00000562.JPEG 466
468 | ILSVRC2012_val_00000071.JPEG 467
469 | ILSVRC2012_val_00000056.JPEG 468
470 | ILSVRC2012_val_00000302.JPEG 469
471 | ILSVRC2012_val_00002596.JPEG 470
472 | ILSVRC2012_val_00000392.JPEG 471
473 | ILSVRC2012_val_00000347.JPEG 472
474 | ILSVRC2012_val_00000101.JPEG 473
475 | ILSVRC2012_val_00000061.JPEG 474
476 | ILSVRC2012_val_00000213.JPEG 475
477 | ILSVRC2012_val_00000074.JPEG 476
478 | ILSVRC2012_val_00001041.JPEG 477
479 | ILSVRC2012_val_00000019.JPEG 478
480 | ILSVRC2012_val_00000085.JPEG 479
481 | ILSVRC2012_val_00000580.JPEG 480
482 | ILSVRC2012_val_00000353.JPEG 481
483 | ILSVRC2012_val_00000557.JPEG 482
484 | ILSVRC2012_val_00000122.JPEG 483
485 | ILSVRC2012_val_00000402.JPEG 484
486 | ILSVRC2012_val_00003491.JPEG 485
487 | ILSVRC2012_val_00000108.JPEG 486
488 | ILSVRC2012_val_00000089.JPEG 487
489 | ILSVRC2012_val_00000376.JPEG 488
490 | ILSVRC2012_val_00001094.JPEG 489
491 | ILSVRC2012_val_00002280.JPEG 490
492 | ILSVRC2012_val_00000537.JPEG 491
493 | ILSVRC2012_val_00001168.JPEG 492
494 | ILSVRC2012_val_00000162.JPEG 493
495 | ILSVRC2012_val_00001601.JPEG 494
496 | ILSVRC2012_val_00000346.JPEG 495
497 | ILSVRC2012_val_00001908.JPEG 496
498 | ILSVRC2012_val_00003351.JPEG 497
499 | ILSVRC2012_val_00000086.JPEG 498
500 | ILSVRC2012_val_00000564.JPEG 499
501 | ILSVRC2012_val_00001426.JPEG 500
502 | ILSVRC2012_val_00000750.JPEG 501
503 | ILSVRC2012_val_00001500.JPEG 502
504 | ILSVRC2012_val_00000344.JPEG 503
505 | ILSVRC2012_val_00000164.JPEG 504
506 | ILSVRC2012_val_00001940.JPEG 505
507 | ILSVRC2012_val_00000915.JPEG 506
508 | ILSVRC2012_val_00004971.JPEG 507
509 | ILSVRC2012_val_00000321.JPEG 508
510 | ILSVRC2012_val_00001190.JPEG 509
511 | ILSVRC2012_val_00000403.JPEG 510
512 | ILSVRC2012_val_00001201.JPEG 511
513 | ILSVRC2012_val_00000678.JPEG 512
514 | ILSVRC2012_val_00002171.JPEG 513
515 | ILSVRC2012_val_00000766.JPEG 514
516 | ILSVRC2012_val_00001875.JPEG 515
517 | ILSVRC2012_val_00000005.JPEG 516
518 | ILSVRC2012_val_00000020.JPEG 517
519 | ILSVRC2012_val_00001689.JPEG 518
520 | ILSVRC2012_val_00001411.JPEG 519
521 | ILSVRC2012_val_00003036.JPEG 520
522 | ILSVRC2012_val_00000534.JPEG 521
523 | ILSVRC2012_val_00000249.JPEG 522
524 | ILSVRC2012_val_00001337.JPEG 523
525 | ILSVRC2012_val_00000730.JPEG 524
526 | ILSVRC2012_val_00000330.JPEG 525
527 | ILSVRC2012_val_00000777.JPEG 526
528 | ILSVRC2012_val_00001246.JPEG 527
529 | ILSVRC2012_val_00000137.JPEG 528
530 | ILSVRC2012_val_00000493.JPEG 529
531 | ILSVRC2012_val_00003643.JPEG 530
532 | ILSVRC2012_val_00000339.JPEG 531
533 | ILSVRC2012_val_00000241.JPEG 532
534 | ILSVRC2012_val_00000256.JPEG 533
535 | ILSVRC2012_val_00000698.JPEG 534
536 | ILSVRC2012_val_00002489.JPEG 535
537 | ILSVRC2012_val_00001729.JPEG 536
538 | ILSVRC2012_val_00002266.JPEG 537
539 | ILSVRC2012_val_00001868.JPEG 538
540 | ILSVRC2012_val_00001893.JPEG 539
541 | ILSVRC2012_val_00000504.JPEG 540
542 | ILSVRC2012_val_00001609.JPEG 541
543 | ILSVRC2012_val_00000542.JPEG 542
544 | ILSVRC2012_val_00001045.JPEG 543
545 | ILSVRC2012_val_00000182.JPEG 544
546 | ILSVRC2012_val_00000942.JPEG 545
547 | ILSVRC2012_val_00000374.JPEG 546
548 | ILSVRC2012_val_00001699.JPEG 547
549 | ILSVRC2012_val_00000199.JPEG 548
550 | ILSVRC2012_val_00000419.JPEG 549
551 | ILSVRC2012_val_00001545.JPEG 550
552 | ILSVRC2012_val_00000624.JPEG 551
553 | ILSVRC2012_val_00001256.JPEG 552
554 | ILSVRC2012_val_00000132.JPEG 553
555 | ILSVRC2012_val_00000288.JPEG 554
556 | ILSVRC2012_val_00000764.JPEG 555
557 | ILSVRC2012_val_00000743.JPEG 556
558 | ILSVRC2012_val_00000445.JPEG 557
559 | ILSVRC2012_val_00000140.JPEG 558
560 | ILSVRC2012_val_00001616.JPEG 559
561 | ILSVRC2012_val_00002562.JPEG 560
562 | ILSVRC2012_val_00001579.JPEG 561
563 | ILSVRC2012_val_00001018.JPEG 562
564 | ILSVRC2012_val_00002799.JPEG 563
565 | ILSVRC2012_val_00002945.JPEG 564
566 | ILSVRC2012_val_00000047.JPEG 565
567 | ILSVRC2012_val_00000957.JPEG 566
568 | ILSVRC2012_val_00000516.JPEG 567
569 | ILSVRC2012_val_00000641.JPEG 568
570 | ILSVRC2012_val_00001144.JPEG 569
571 | ILSVRC2012_val_00002290.JPEG 570
572 | ILSVRC2012_val_00000732.JPEG 571
573 | ILSVRC2012_val_00000311.JPEG 572
574 | ILSVRC2012_val_00000032.JPEG 573
575 | ILSVRC2012_val_00002314.JPEG 574
576 | ILSVRC2012_val_00001116.JPEG 575
577 | ILSVRC2012_val_00000913.JPEG 576
578 | ILSVRC2012_val_00001052.JPEG 577
579 | ILSVRC2012_val_00000780.JPEG 578
580 | ILSVRC2012_val_00000662.JPEG 579
581 | ILSVRC2012_val_00000295.JPEG 580
582 | ILSVRC2012_val_00000852.JPEG 581
583 | ILSVRC2012_val_00000372.JPEG 582
584 | ILSVRC2012_val_00000169.JPEG 583
585 | ILSVRC2012_val_00000190.JPEG 584
586 | ILSVRC2012_val_00000856.JPEG 585
587 | ILSVRC2012_val_00000035.JPEG 586
588 | ILSVRC2012_val_00000887.JPEG 587
589 | ILSVRC2012_val_00000060.JPEG 588
590 | ILSVRC2012_val_00000876.JPEG 589
591 | ILSVRC2012_val_00000149.JPEG 590
592 | ILSVRC2012_val_00000054.JPEG 591
593 | ILSVRC2012_val_00000367.JPEG 592
594 | ILSVRC2012_val_00001030.JPEG 593
595 | ILSVRC2012_val_00000448.JPEG 594
596 | ILSVRC2012_val_00000015.JPEG 595
597 | ILSVRC2012_val_00002436.JPEG 596
598 | ILSVRC2012_val_00002822.JPEG 597
599 | ILSVRC2012_val_00000387.JPEG 598
600 | ILSVRC2012_val_00001784.JPEG 599
601 | ILSVRC2012_val_00000575.JPEG 600
602 | ILSVRC2012_val_00000776.JPEG 601
603 | ILSVRC2012_val_00001827.JPEG 602
604 | ILSVRC2012_val_00000752.JPEG 603
605 | ILSVRC2012_val_00001943.JPEG 604
606 | ILSVRC2012_val_00000647.JPEG 605
607 | ILSVRC2012_val_00000208.JPEG 606
608 | ILSVRC2012_val_00000308.JPEG 607
609 | ILSVRC2012_val_00000110.JPEG 608
610 | ILSVRC2012_val_00000659.JPEG 609
611 | ILSVRC2012_val_00000219.JPEG 610
612 | ILSVRC2012_val_00000607.JPEG 611
613 | ILSVRC2012_val_00000281.JPEG 612
614 | ILSVRC2012_val_00000969.JPEG 613
615 | ILSVRC2012_val_00000900.JPEG 614
616 | ILSVRC2012_val_00000686.JPEG 615
617 | ILSVRC2012_val_00002377.JPEG 616
618 | ILSVRC2012_val_00002627.JPEG 617
619 | ILSVRC2012_val_00000684.JPEG 618
620 | ILSVRC2012_val_00000427.JPEG 619
621 | ILSVRC2012_val_00000151.JPEG 620
622 | ILSVRC2012_val_00000384.JPEG 621
623 | ILSVRC2012_val_00000359.JPEG 622
624 | ILSVRC2012_val_00001977.JPEG 623
625 | ILSVRC2012_val_00000715.JPEG 624
626 | ILSVRC2012_val_00000556.JPEG 625
627 | ILSVRC2012_val_00003417.JPEG 626
628 | ILSVRC2012_val_00000185.JPEG 627
629 | ILSVRC2012_val_00000664.JPEG 628
630 | ILSVRC2012_val_00000984.JPEG 629
631 | ILSVRC2012_val_00000385.JPEG 630
632 | ILSVRC2012_val_00001814.JPEG 631
633 | ILSVRC2012_val_00000109.JPEG 632
634 | ILSVRC2012_val_00000999.JPEG 633
635 | ILSVRC2012_val_00000760.JPEG 634
636 | ILSVRC2012_val_00001229.JPEG 635
637 | ILSVRC2012_val_00000221.JPEG 636
638 | ILSVRC2012_val_00000152.JPEG 637
639 | ILSVRC2012_val_00000117.JPEG 638
640 | ILSVRC2012_val_00001420.JPEG 639
641 | ILSVRC2012_val_00002223.JPEG 640
642 | ILSVRC2012_val_00000710.JPEG 641
643 | ILSVRC2012_val_00001796.JPEG 642
644 | ILSVRC2012_val_00000608.JPEG 643
645 | ILSVRC2012_val_00000259.JPEG 644
646 | ILSVRC2012_val_00000120.JPEG 645
647 | ILSVRC2012_val_00000118.JPEG 646
648 | ILSVRC2012_val_00000163.JPEG 647
649 | ILSVRC2012_val_00001318.JPEG 648
650 | ILSVRC2012_val_00000382.JPEG 649
651 | ILSVRC2012_val_00001982.JPEG 650
652 | ILSVRC2012_val_00000519.JPEG 651
653 | ILSVRC2012_val_00003199.JPEG 652
654 | ILSVRC2012_val_00001170.JPEG 653
655 | ILSVRC2012_val_00000314.JPEG 654
656 | ILSVRC2012_val_00001192.JPEG 655
657 | ILSVRC2012_val_00001722.JPEG 656
658 | ILSVRC2012_val_00000700.JPEG 657
659 | ILSVRC2012_val_00001574.JPEG 658
660 | ILSVRC2012_val_00000173.JPEG 659
661 | ILSVRC2012_val_00000254.JPEG 660
662 | ILSVRC2012_val_00002010.JPEG 661
663 | ILSVRC2012_val_00000222.JPEG 662
664 | ILSVRC2012_val_00000670.JPEG 663
665 | ILSVRC2012_val_00000119.JPEG 664
666 | ILSVRC2012_val_00000717.JPEG 665
667 | ILSVRC2012_val_00000168.JPEG 666
668 | ILSVRC2012_val_00000923.JPEG 667
669 | ILSVRC2012_val_00000680.JPEG 668
670 | ILSVRC2012_val_00000859.JPEG 669
671 | ILSVRC2012_val_00001892.JPEG 670
672 | ILSVRC2012_val_00001659.JPEG 671
673 | ILSVRC2012_val_00002082.JPEG 672
674 | ILSVRC2012_val_00000803.JPEG 673
675 | ILSVRC2012_val_00000009.JPEG 674
676 | ILSVRC2012_val_00000365.JPEG 675
677 | ILSVRC2012_val_00000277.JPEG 676
678 | ILSVRC2012_val_00000649.JPEG 677
679 | ILSVRC2012_val_00000243.JPEG 678
680 | ILSVRC2012_val_00000850.JPEG 679
681 | ILSVRC2012_val_00000540.JPEG 680
682 | ILSVRC2012_val_00000677.JPEG 681
683 | ILSVRC2012_val_00002278.JPEG 682
684 | ILSVRC2012_val_00001464.JPEG 683
685 | ILSVRC2012_val_00000521.JPEG 684
686 | ILSVRC2012_val_00001380.JPEG 685
687 | ILSVRC2012_val_00000114.JPEG 686
688 | ILSVRC2012_val_00000166.JPEG 687
689 | ILSVRC2012_val_00001016.JPEG 688
690 | ILSVRC2012_val_00001961.JPEG 689
691 | ILSVRC2012_val_00001581.JPEG 690
692 | ILSVRC2012_val_00000549.JPEG 691
693 | ILSVRC2012_val_00002069.JPEG 692
694 | ILSVRC2012_val_00003004.JPEG 693
695 | ILSVRC2012_val_00000264.JPEG 694
696 | ILSVRC2012_val_00000265.JPEG 695
697 | ILSVRC2012_val_00000360.JPEG 696
698 | ILSVRC2012_val_00001424.JPEG 697
699 | ILSVRC2012_val_00003804.JPEG 698
700 | ILSVRC2012_val_00000681.JPEG 699
701 | ILSVRC2012_val_00001833.JPEG 700
702 | ILSVRC2012_val_00001440.JPEG 701
703 | ILSVRC2012_val_00000240.JPEG 702
704 | ILSVRC2012_val_00000192.JPEG 703
705 | ILSVRC2012_val_00000201.JPEG 704
706 | ILSVRC2012_val_00000096.JPEG 705
707 | ILSVRC2012_val_00001665.JPEG 706
708 | ILSVRC2012_val_00000335.JPEG 707
709 | ILSVRC2012_val_00002193.JPEG 708
710 | ILSVRC2012_val_00001265.JPEG 709
711 | ILSVRC2012_val_00004441.JPEG 710
712 | ILSVRC2012_val_00001754.JPEG 711
713 | ILSVRC2012_val_00000465.JPEG 712
714 | ILSVRC2012_val_00000489.JPEG 713
715 | ILSVRC2012_val_00001365.JPEG 714
716 | ILSVRC2012_val_00000834.JPEG 715
717 | ILSVRC2012_val_00000626.JPEG 716
718 | ILSVRC2012_val_00000049.JPEG 717
719 | ILSVRC2012_val_00000121.JPEG 718
720 | ILSVRC2012_val_00001023.JPEG 719
721 | ILSVRC2012_val_00000112.JPEG 720
722 | ILSVRC2012_val_00000438.JPEG 721
723 | ILSVRC2012_val_00000428.JPEG 722
724 | ILSVRC2012_val_00000991.JPEG 723
725 | ILSVRC2012_val_00000966.JPEG 724
726 | ILSVRC2012_val_00000046.JPEG 725
727 | ILSVRC2012_val_00000144.JPEG 726
728 | ILSVRC2012_val_00000024.JPEG 727
729 | ILSVRC2012_val_00000722.JPEG 728
730 | ILSVRC2012_val_00001853.JPEG 729
731 | ILSVRC2012_val_00000292.JPEG 730
732 | ILSVRC2012_val_00002729.JPEG 731
733 | ILSVRC2012_val_00000765.JPEG 732
734 | ILSVRC2012_val_00000334.JPEG 733
735 | ILSVRC2012_val_00000452.JPEG 734
736 | ILSVRC2012_val_00004159.JPEG 735
737 | ILSVRC2012_val_00002385.JPEG 736
738 | ILSVRC2012_val_00001970.JPEG 737
739 | ILSVRC2012_val_00001919.JPEG 738
740 | ILSVRC2012_val_00000458.JPEG 739
741 | ILSVRC2012_val_00000589.JPEG 740
742 | ILSVRC2012_val_00000230.JPEG 741
743 | ILSVRC2012_val_00002985.JPEG 742
744 | ILSVRC2012_val_00002554.JPEG 743
745 | ILSVRC2012_val_00000468.JPEG 744
746 | ILSVRC2012_val_00001253.JPEG 745
747 | ILSVRC2012_val_00000454.JPEG 746
748 | ILSVRC2012_val_00002782.JPEG 747
749 | ILSVRC2012_val_00001343.JPEG 748
750 | ILSVRC2012_val_00001345.JPEG 749
751 | ILSVRC2012_val_00000260.JPEG 750
752 | ILSVRC2012_val_00000135.JPEG 751
753 | ILSVRC2012_val_00001786.JPEG 752
754 | ILSVRC2012_val_00000770.JPEG 753
755 | ILSVRC2012_val_00001346.JPEG 754
756 | ILSVRC2012_val_00002046.JPEG 755
757 | ILSVRC2012_val_00000042.JPEG 756
758 | ILSVRC2012_val_00000014.JPEG 757
759 | ILSVRC2012_val_00000473.JPEG 758
760 | ILSVRC2012_val_00001085.JPEG 759
761 | ILSVRC2012_val_00000331.JPEG 760
762 | ILSVRC2012_val_00000275.JPEG 761
763 | ILSVRC2012_val_00000591.JPEG 762
764 | ILSVRC2012_val_00000158.JPEG 763
765 | ILSVRC2012_val_00000244.JPEG 764
766 | ILSVRC2012_val_00000421.JPEG 765
767 | ILSVRC2012_val_00001625.JPEG 766
768 | ILSVRC2012_val_00000453.JPEG 767
769 | ILSVRC2012_val_00001497.JPEG 768
770 | ILSVRC2012_val_00001210.JPEG 769
771 | ILSVRC2012_val_00001465.JPEG 770
772 | ILSVRC2012_val_00000143.JPEG 771
773 | ILSVRC2012_val_00002068.JPEG 772
774 | ILSVRC2012_val_00000792.JPEG 773
775 | ILSVRC2012_val_00002426.JPEG 774
776 | ILSVRC2012_val_00000442.JPEG 775
777 | ILSVRC2012_val_00000614.JPEG 776
778 | ILSVRC2012_val_00000039.JPEG 777
779 | ILSVRC2012_val_00001091.JPEG 778
780 | ILSVRC2012_val_00001736.JPEG 779
781 | ILSVRC2012_val_00000094.JPEG 780
782 | ILSVRC2012_val_00000167.JPEG 781
783 | ILSVRC2012_val_00001269.JPEG 782
784 | ILSVRC2012_val_00001639.JPEG 783
785 | ILSVRC2012_val_00000310.JPEG 784
786 | ILSVRC2012_val_00001792.JPEG 785
787 | ILSVRC2012_val_00000377.JPEG 786
788 | ILSVRC2012_val_00000232.JPEG 787
789 | ILSVRC2012_val_00000083.JPEG 788
790 | ILSVRC2012_val_00000159.JPEG 789
791 | ILSVRC2012_val_00000629.JPEG 790
792 | ILSVRC2012_val_00001316.JPEG 791
793 | ILSVRC2012_val_00002065.JPEG 792
794 | ILSVRC2012_val_00002394.JPEG 793
795 | ILSVRC2012_val_00000527.JPEG 794
796 | ILSVRC2012_val_00000568.JPEG 795
797 | ILSVRC2012_val_00000550.JPEG 796
798 | ILSVRC2012_val_00001395.JPEG 797
799 | ILSVRC2012_val_00000266.JPEG 798
800 | ILSVRC2012_val_00002286.JPEG 799
801 | ILSVRC2012_val_00000857.JPEG 800
802 | ILSVRC2012_val_00000508.JPEG 801
803 | ILSVRC2012_val_00000131.JPEG 802
804 | ILSVRC2012_val_00000767.JPEG 803
805 | ILSVRC2012_val_00003585.JPEG 804
806 | ILSVRC2012_val_00002349.JPEG 805
807 | ILSVRC2012_val_00000936.JPEG 806
808 | ILSVRC2012_val_00000897.JPEG 807
809 | ILSVRC2012_val_00002335.JPEG 808
810 | ILSVRC2012_val_00000004.JPEG 809
811 | ILSVRC2012_val_00001672.JPEG 810
812 | ILSVRC2012_val_00004382.JPEG 811
813 | ILSVRC2012_val_00000559.JPEG 812
814 | ILSVRC2012_val_00001587.JPEG 813
815 | ILSVRC2012_val_00000901.JPEG 814
816 | ILSVRC2012_val_00002128.JPEG 815
817 | ILSVRC2012_val_00000349.JPEG 816
818 | ILSVRC2012_val_00001247.JPEG 817
819 | ILSVRC2012_val_00000345.JPEG 818
820 | ILSVRC2012_val_00001388.JPEG 819
821 | ILSVRC2012_val_00000544.JPEG 820
822 | ILSVRC2012_val_00000271.JPEG 821
823 | ILSVRC2012_val_00000842.JPEG 822
824 | ILSVRC2012_val_00001183.JPEG 823
825 | ILSVRC2012_val_00001059.JPEG 824
826 | ILSVRC2012_val_00000171.JPEG 825
827 | ILSVRC2012_val_00000290.JPEG 826
828 | ILSVRC2012_val_00000205.JPEG 827
829 | ILSVRC2012_val_00001470.JPEG 828
830 | ILSVRC2012_val_00000714.JPEG 829
831 | ILSVRC2012_val_00000535.JPEG 830
832 | ILSVRC2012_val_00001811.JPEG 831
833 | ILSVRC2012_val_00000522.JPEG 832
834 | ILSVRC2012_val_00000595.JPEG 833
835 | ILSVRC2012_val_00000380.JPEG 834
836 | ILSVRC2012_val_00002015.JPEG 835
837 | ILSVRC2012_val_00000413.JPEG 836
838 | ILSVRC2012_val_00000815.JPEG 837
839 | ILSVRC2012_val_00001102.JPEG 838
840 | ILSVRC2012_val_00000778.JPEG 839
841 | ILSVRC2012_val_00000701.JPEG 840
842 | ILSVRC2012_val_00000070.JPEG 841
843 | ILSVRC2012_val_00000065.JPEG 842
844 | ILSVRC2012_val_00001086.JPEG 843
845 | ILSVRC2012_val_00000053.JPEG 844
846 | ILSVRC2012_val_00002160.JPEG 845
847 | ILSVRC2012_val_00000026.JPEG 846
848 | ILSVRC2012_val_00002120.JPEG 847
849 | ILSVRC2012_val_00000627.JPEG 848
850 | ILSVRC2012_val_00000464.JPEG 849
851 | ILSVRC2012_val_00000602.JPEG 850
852 | ILSVRC2012_val_00001532.JPEG 851
853 | ILSVRC2012_val_00000123.JPEG 852
854 | ILSVRC2012_val_00003312.JPEG 853
855 | ILSVRC2012_val_00002206.JPEG 854
856 | ILSVRC2012_val_00001263.JPEG 855
857 | ILSVRC2012_val_00000416.JPEG 856
858 | ILSVRC2012_val_00001160.JPEG 857
859 | ILSVRC2012_val_00000030.JPEG 858
860 | ILSVRC2012_val_00000737.JPEG 859
861 | ILSVRC2012_val_00000561.JPEG 860
862 | ILSVRC2012_val_00000449.JPEG 861
863 | ILSVRC2012_val_00001008.JPEG 862
864 | ILSVRC2012_val_00000525.JPEG 863
865 | ILSVRC2012_val_00000671.JPEG 864
866 | ILSVRC2012_val_00000261.JPEG 865
867 | ILSVRC2012_val_00000294.JPEG 866
868 | ILSVRC2012_val_00000187.JPEG 867
869 | ILSVRC2012_val_00000513.JPEG 868
870 | ILSVRC2012_val_00000291.JPEG 869
871 | ILSVRC2012_val_00000069.JPEG 870
872 | ILSVRC2012_val_00001519.JPEG 871
873 | ILSVRC2012_val_00000059.JPEG 872
874 | ILSVRC2012_val_00001325.JPEG 873
875 | ILSVRC2012_val_00000573.JPEG 874
876 | ILSVRC2012_val_00000146.JPEG 875
877 | ILSVRC2012_val_00000874.JPEG 876
878 | ILSVRC2012_val_00001750.JPEG 877
879 | ILSVRC2012_val_00000104.JPEG 878
880 | ILSVRC2012_val_00000381.JPEG 879
881 | ILSVRC2012_val_00002599.JPEG 880
882 | ILSVRC2012_val_00000284.JPEG 881
883 | ILSVRC2012_val_00000712.JPEG 882
884 | ILSVRC2012_val_00000895.JPEG 883
885 | ILSVRC2012_val_00001002.JPEG 884
886 | ILSVRC2012_val_00000433.JPEG 885
887 | ILSVRC2012_val_00001881.JPEG 886
888 | ILSVRC2012_val_00000036.JPEG 887
889 | ILSVRC2012_val_00000296.JPEG 888
890 | ILSVRC2012_val_00000212.JPEG 889
891 | ILSVRC2012_val_00001680.JPEG 890
892 | ILSVRC2012_val_00000797.JPEG 891
893 | ILSVRC2012_val_00004307.JPEG 892
894 | ILSVRC2012_val_00000866.JPEG 893
895 | ILSVRC2012_val_00000482.JPEG 894
896 | ILSVRC2012_val_00000369.JPEG 895
897 | ILSVRC2012_val_00000538.JPEG 896
898 | ILSVRC2012_val_00000280.JPEG 897
899 | ILSVRC2012_val_00000745.JPEG 898
900 | ILSVRC2012_val_00000286.JPEG 899
901 | ILSVRC2012_val_00000586.JPEG 900
902 | ILSVRC2012_val_00002867.JPEG 901
903 | ILSVRC2012_val_00000497.JPEG 902
904 | ILSVRC2012_val_00000854.JPEG 903
905 | ILSVRC2012_val_00000351.JPEG 904
906 | ILSVRC2012_val_00000587.JPEG 905
907 | ILSVRC2012_val_00000106.JPEG 906
908 | ILSVRC2012_val_00000805.JPEG 907
909 | ILSVRC2012_val_00001901.JPEG 908
910 | ILSVRC2012_val_00001328.JPEG 909
911 | ILSVRC2012_val_00002685.JPEG 910
912 | ILSVRC2012_val_00001208.JPEG 911
913 | ILSVRC2012_val_00000437.JPEG 912
914 | ILSVRC2012_val_00001034.JPEG 913
915 | ILSVRC2012_val_00003145.JPEG 914
916 | ILSVRC2012_val_00000460.JPEG 915
917 | ILSVRC2012_val_00000830.JPEG 916
918 | ILSVRC2012_val_00000985.JPEG 917
919 | ILSVRC2012_val_00002362.JPEG 918
920 | ILSVRC2012_val_00000903.JPEG 919
921 | ILSVRC2012_val_00000274.JPEG 920
922 | ILSVRC2012_val_00000980.JPEG 921
923 | ILSVRC2012_val_00000209.JPEG 922
924 | ILSVRC2012_val_00001611.JPEG 923
925 | ILSVRC2012_val_00000728.JPEG 924
926 | ILSVRC2012_val_00000267.JPEG 925
927 | ILSVRC2012_val_00001512.JPEG 926
928 | ILSVRC2012_val_00000620.JPEG 927
929 | ILSVRC2012_val_00000574.JPEG 928
930 | ILSVRC2012_val_00002619.JPEG 929
931 | ILSVRC2012_val_00000820.JPEG 930
932 | ILSVRC2012_val_00001405.JPEG 931
933 | ILSVRC2012_val_00000891.JPEG 932
934 | ILSVRC2012_val_00002226.JPEG 933
935 | ILSVRC2012_val_00000129.JPEG 934
936 | ILSVRC2012_val_00000183.JPEG 935
937 | ILSVRC2012_val_00001826.JPEG 936
938 | ILSVRC2012_val_00000155.JPEG 937
939 | ILSVRC2012_val_00000332.JPEG 938
940 | ILSVRC2012_val_00000554.JPEG 939
941 | ILSVRC2012_val_00000461.JPEG 940
942 | ILSVRC2012_val_00001932.JPEG 941
943 | ILSVRC2012_val_00000674.JPEG 942
944 | ILSVRC2012_val_00000463.JPEG 943
945 | ILSVRC2012_val_00000443.JPEG 944
946 | ILSVRC2012_val_00000423.JPEG 945
947 | ILSVRC2012_val_00000881.JPEG 946
948 | ILSVRC2012_val_00002490.JPEG 947
949 | ILSVRC2012_val_00000023.JPEG 948
950 | ILSVRC2012_val_00000099.JPEG 949
951 | ILSVRC2012_val_00001406.JPEG 950
952 | ILSVRC2012_val_00000711.JPEG 951
953 | ILSVRC2012_val_00002124.JPEG 952
954 | ILSVRC2012_val_00001479.JPEG 953
955 | ILSVRC2012_val_00000605.JPEG 954
956 | ILSVRC2012_val_00000691.JPEG 955
957 | ILSVRC2012_val_00000798.JPEG 956
958 | ILSVRC2012_val_00001514.JPEG 957
959 | ILSVRC2012_val_00000816.JPEG 958
960 | ILSVRC2012_val_00000116.JPEG 959
961 | ILSVRC2012_val_00000782.JPEG 960
962 | ILSVRC2012_val_00000543.JPEG 961
963 | ILSVRC2012_val_00001115.JPEG 962
964 | ILSVRC2012_val_00000773.JPEG 963
965 | ILSVRC2012_val_00000491.JPEG 964
966 | ILSVRC2012_val_00000572.JPEG 965
967 | ILSVRC2012_val_00001064.JPEG 966
968 | ILSVRC2012_val_00002934.JPEG 967
969 | ILSVRC2012_val_00001726.JPEG 968
970 | ILSVRC2012_val_00000139.JPEG 969
971 | ILSVRC2012_val_00000002.JPEG 970
972 | ILSVRC2012_val_00000697.JPEG 971
973 | ILSVRC2012_val_00001444.JPEG 972
974 | ILSVRC2012_val_00000235.JPEG 973
975 | ILSVRC2012_val_00000394.JPEG 974
976 | ILSVRC2012_val_00000150.JPEG 975
977 | ILSVRC2012_val_00002476.JPEG 976
978 | ILSVRC2012_val_00000145.JPEG 977
979 | ILSVRC2012_val_00000214.JPEG 978
980 | ILSVRC2012_val_00000313.JPEG 979
981 | ILSVRC2012_val_00002103.JPEG 980
982 | ILSVRC2012_val_00000034.JPEG 981
983 | ILSVRC2012_val_00000341.JPEG 982
984 | ILSVRC2012_val_00000255.JPEG 983
985 | ILSVRC2012_val_00000216.JPEG 984
986 | ILSVRC2012_val_00004514.JPEG 985
987 | ILSVRC2012_val_00000551.JPEG 986
988 | ILSVRC2012_val_00000484.JPEG 987
989 | ILSVRC2012_val_00001058.JPEG 988
990 | ILSVRC2012_val_00000251.JPEG 989
991 | ILSVRC2012_val_00001120.JPEG 990
992 | ILSVRC2012_val_00000304.JPEG 991
993 | ILSVRC2012_val_00000731.JPEG 992
994 | ILSVRC2012_val_00000200.JPEG 993
995 | ILSVRC2012_val_00000058.JPEG 994
996 | ILSVRC2012_val_00000755.JPEG 995
997 | ILSVRC2012_val_00002167.JPEG 996
998 | ILSVRC2012_val_00001362.JPEG 997
999 | ILSVRC2012_val_00000408.JPEG 998
1000 | ILSVRC2012_val_00001079.JPEG 999
1001 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | from collections import OrderedDict
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import torch.optim
8 | import torchvision.utils as vutils
9 | from torch.utils.data import DataLoader
10 |
11 | import models
12 | import utils
13 | from dataset import ImageDataset
14 |
15 |
16 | class Trainer(object):
17 |
18 | def __init__(self, config):
19 | self.rank, self.world_size = 0, 1
20 | if config['dist']:
21 | self.rank = dist.get_rank()
22 | self.world_size = dist.get_world_size()
23 |
24 | self.mode = config['dgp_mode']
25 | assert self.mode in [
26 | 'reconstruct', 'colorization', 'SR', 'hybrid', 'inpainting',
27 | 'morphing', 'defence', 'jitter'
28 | ]
29 |
30 | if self.rank == 0:
31 | # mkdir path
32 | if not os.path.exists('{}/images'.format(config['exp_path'])):
33 | os.makedirs('{}/images'.format(config['exp_path']))
34 | if not os.path.exists('{}/images_sheet'.format(
35 | config['exp_path'])):
36 | os.makedirs('{}/images_sheet'.format(config['exp_path']))
37 | if not os.path.exists('{}/logs'.format(config['exp_path'])):
38 | os.makedirs('{}/logs'.format(config['exp_path']))
39 |
40 | # prepare logger
41 | if not config['no_tb']:
42 | try:
43 | from tensorboardX import SummaryWriter
44 | except ImportError:
45 | raise Exception("Please switch off \"tensorboard\" "
46 | "in your config file if you do not "
47 | "want to use it, otherwise install it.")
48 | self.tb_logger = SummaryWriter('{}'.format(config['exp_path']))
49 | else:
50 | self.tb_logger = None
51 |
52 | self.logger = utils.create_logger(
53 | 'global_logger',
54 | '{}/logs/log_train.txt'.format(config['exp_path']))
55 |
56 | self.model = models.DGP(config)
57 | if self.mode == 'morphing':
58 | self.model2 = models.DGP(config)
59 | self.model_interp = models.DGP(config)
60 |
61 | # Data loader
62 | train_dataset = ImageDataset(
63 | config['root_dir'],
64 | config['list_file'],
65 | image_size=config['resolution'],
66 | normalize=True)
67 | sampler = utils.DistributedSampler(
68 | train_dataset) if config['dist'] else None
69 | self.train_loader = DataLoader(
70 | train_dataset,
71 | batch_size=1,
72 | shuffle=False,
73 | sampler=sampler,
74 | num_workers=1,
75 | pin_memory=False)
76 | self.config = config
77 |
78 | def run(self):
79 | # train
80 | if self.mode == 'morphing':
81 | self.train_morphing()
82 | else:
83 | self.train()
84 |
85 | def train(self):
86 | btime_rec = utils.AverageMeter()
87 | dtime_rec = utils.AverageMeter()
88 | recorder = {}
89 | end = time.time()
90 | for i, (image, category, img_path) in enumerate(self.train_loader):
91 | # measure data loading time
92 | dtime_rec.update(time.time() - end)
93 |
94 | torch.cuda.empty_cache()
95 |
96 | image = image.cuda()
97 | category = category.cuda()
98 | img_path = img_path[0]
99 |
100 | self.model.reset_G()
101 | self.model.set_target(image, category, img_path)
102 | # when the category is unknown (category=-1), it will be selected from samples
103 | self.model.select_z(select_y=True if category.item() < 0 else False)
104 | loss_dict = self.model.run(save_interval=self.config['save_interval'])
105 |
106 | # average loss if distributed
107 | if self.config['dist']:
108 | for k, v in loss_dict.items():
109 | reduced = v.data.clone() / self.world_size
110 | dist.all_reduce_multigpu([reduced])
111 | loss_dict[k] = reduced
112 |
113 | if len(recorder) == 0:
114 | for k in loss_dict.keys():
115 | recorder[k] = utils.AverageMeter()
116 | for k in loss_dict.keys():
117 | recorder[k].update(loss_dict[k].item())
118 |
119 | btime_rec.update(time.time() - end)
120 | end = time.time()
121 |
122 | # logging
123 | loss_str = ""
124 | if self.rank == 0:
125 | for k in recorder.keys():
126 | if self.tb_logger is not None:
127 | self.tb_logger.add_scalar('train_{}'.format(k),
128 | recorder[k].avg, i + 1)
129 | loss_str += '{}: {loss.val:.4g} ({loss.avg:.4g}) '.format(
130 | k, loss=recorder[k])
131 |
132 | self.logger.info(
133 | 'Iter: [{0}/{1}] '.format(i + 1, len(self.train_loader)) +
134 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '.
135 | format(batch_time=btime_rec) +
136 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) '.format(
137 | data_time=dtime_rec) + 'Image {} '.format(img_path) +
138 | loss_str)
139 |
140 | def train_morphing(self):
141 | btime_rec = utils.AverageMeter()
142 | dtime_rec = utils.AverageMeter()
143 | recorder = {}
144 | last_category = -1
145 | end = time.time()
146 | for i, (image, category, img_path) in enumerate(self.train_loader):
147 | # measure data loading time
148 | dtime_rec.update(time.time() - end)
149 |
150 | assert image.shape[0] > 0
151 | image = image.cuda()
152 | category = category.cuda()
153 | img_path = img_path[0]
154 |
155 | self.model.reset_G()
156 | self.model.set_target(image, category, img_path)
157 | self.model.select_z()
158 | loss_dict = self.model.run(
159 | save_interval=self.config['save_interval'])
160 |
161 | # apply image morphing within the same category
162 | if category == last_category:
163 | self.morphing()
164 | torch.cuda.empty_cache()
165 |
166 | with torch.no_grad():
167 | self.model2.G.load_state_dict(self.model.G.state_dict())
168 | self.model2.z.copy_(self.model.z)
169 | self.model2.img_name = self.model.img_name
170 | self.model2.target = self.model.target
171 | self.model2.category = self.model.category
172 |
173 | if category == last_category:
174 | # average loss if distributed
175 | if self.config['dist']:
176 | for k, v in loss_dict.items():
177 | reduced = v.data.clone() / self.world_size
178 | dist.all_reduce_multigpu([reduced])
179 | loss_dict[k] = reduced
180 |
181 | if len(recorder) < len(loss_dict):
182 | for k in loss_dict.keys():
183 | recorder[k] = utils.AverageMeter()
184 | for k in loss_dict.keys():
185 | recorder[k].update(loss_dict[k].item())
186 |
187 | btime_rec.update(time.time() - end)
188 | end = time.time()
189 |
190 | # logging
191 | loss_str = ""
192 | if self.rank == 0:
193 | for k in recorder.keys():
194 | if self.tb_logger is not None:
195 | self.tb_logger.add_scalar('train_{}'.format(k),
196 | recorder[k].avg, i + 1)
197 | loss_str += '{}: {loss.val:.4g} ({loss.avg:.4g}) '.format(
198 | k, loss=recorder[k])
199 |
200 | self.logger.info(
201 | 'Iter: [{0}/{1}] '.format(i, len(self.train_loader)) +
202 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '.
203 | format(batch_time=btime_rec) +
204 | 'Data {data_time.val:.3f} ({data_time.avg:.3f}) '.
205 | format(data_time=dtime_rec) +
206 | 'Image {} '.format(img_path) + loss_str)
207 |
208 | last_category = category
209 |
210 | def morphing(self):
211 | weight1 = self.model.G.state_dict()
212 | weight2 = self.model2.G.state_dict()
213 | weight_interp = OrderedDict()
214 | imgs = []
215 | with torch.no_grad():
216 | for i in range(11):
217 | alpha = i / 10
218 | # interpolate between both latent vector and generator weight
219 | z_interp = alpha * self.model.z + (1 - alpha) * self.model2.z
220 | for k, w1 in weight1.items():
221 | w2 = weight2[k]
222 | weight_interp[k] = alpha * w1 + (1 - alpha) * w2
223 | self.model_interp.G.load_state_dict(weight_interp)
224 | x_interp = self.model_interp.G(
225 | z_interp, self.model_interp.G.shared(self.model.y))
226 | imgs.append(x_interp.cpu())
227 | # save image
228 | save_path = '%s/images/%s_%s' % (self.config['exp_path'],
229 | self.model.img_name,
230 | self.model2.img_name)
231 | if not os.path.exists(save_path):
232 | os.makedirs(save_path)
233 | utils.save_img(x_interp[0], '%s/%03d.jpg' % (save_path, i + 1))
234 | imgs = torch.cat(imgs, 0)
235 | vutils.save_image(
236 | imgs,
237 | '%s/images_sheet/morphing_class%d_%s_%s.jpg' %
238 | (self.config['exp_path'], self.model.category, self.model.img_name,
239 | self.model2.img_name),
240 | nrow=int(imgs.size(0)**0.5),
241 | normalize=True)
242 | del weight_interp
243 |
--------------------------------------------------------------------------------
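Note on the morphing step in trainer.py above: Trainer.morphing() blends two DGP results of the
same category by taking the same convex combination of (a) the two optimized latent vectors and
(b) every fine-tuned generator parameter, then rendering the interpolated generator. A minimal
sketch of that interpolation, using a hypothetical helper name and placeholder tensors (this
helper is not part of the repository):

    from collections import OrderedDict
    import torch

    def interpolate_generator(z1, z2, state1, state2, alpha):
        # convex combination of the two optimized latent codes
        z_interp = alpha * z1 + (1 - alpha) * z2
        # the same combination applied key-by-key to the two generator state dicts
        w_interp = OrderedDict(
            (k, alpha * state1[k] + (1 - alpha) * state2[k]) for k in state1)
        return z_interp, w_interp

    # Trainer.morphing() sweeps alpha over 0.0, 0.1, ..., 1.0, loads the interpolated
    # weights into model_interp.G, and saves G(z_interp, G.shared(y)) at each step.
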
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .biggan_utils import *
2 | from .common_utils import *
3 | from .distributed_utils import *
4 | from .losses import *
5 |
--------------------------------------------------------------------------------
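These wildcard imports are what let trainer.py reach helpers such as utils.AverageMeter,
utils.create_logger, utils.DistributedSampler and utils.save_img through the single utils
namespace. A small usage sketch based only on how trainer.py calls them (the actual
definitions live in the sibling utils submodules, not in this file):

    import utils

    meter = utils.AverageMeter()   # running value/average, read back via .val and .avg
    meter.update(0.25)
    logger = utils.create_logger('global_logger', 'log_train.txt')
    logger.info('loss {:.4g} ({:.4g})'.format(meter.val, meter.avg))
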
/utils/biggan_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | ''' Utilities file
4 | This file contains utility functions for bookkeeping, logging, and data loading.
5 | Methods which directly affect training should either go in layers, the model,
6 | or train_fns.py.
7 | '''
8 |
9 | from __future__ import print_function
10 |
11 | import math
12 | import os
13 | from argparse import ArgumentParser
14 |
15 | import numpy as np
16 | import torch
17 | import torch.nn as nn
18 | import torchvision.transforms as transforms
19 | from torch.optim.optimizer import Optimizer
20 |
21 |
22 | def prepare_parser():
23 | usage = 'Parser for all scripts.'
24 | parser = ArgumentParser(description=usage)
25 |
26 | ### Pipeline stuff ###
27 | parser.add_argument(
28 | '--eval_mode', action='store_true', default=False,
29 | help='Evaluation mode? (do not save logs) '
30 | ' (default: %(default)s)')
31 |
32 | ### Model stuff ###
33 | parser.add_argument(
34 | '--model', type=str, default='BigGAN',
35 | help='Name of the model module (default: %(default)s)')
36 | parser.add_argument(
37 | '--G_param', type=str, default='SN',
38 | help='Parameterization style to use for G, spectral norm (SN) or SVD (SVD)'
39 | ' or None (default: %(default)s)')
40 | parser.add_argument(
41 | '--D_param', type=str, default='SN',
42 | help='Parameterization style to use for D, spectral norm (SN) or SVD (SVD)'
43 | ' or None (default: %(default)s)')
44 | parser.add_argument(
45 | '--G_ch', type=int, default=64,
46 | help='Channel multiplier for G (default: %(default)s)')
47 | parser.add_argument(
48 | '--D_ch', type=int, default=64,
49 | help='Channel multiplier for D (default: %(default)s)')
50 | parser.add_argument(
51 | '--G_depth', type=int, default=1,
52 | help='Number of resblocks per stage in G? (default: %(default)s)')
53 | parser.add_argument(
54 | '--D_depth', type=int, default=1,
55 | help='Number of resblocks per stage in D? (default: %(default)s)')
56 | parser.add_argument(
57 | '--D_thin', action='store_false', dest='D_wide', default=True,
58 | help='Use the SN-GAN channel pattern for D? (default: %(default)s)')
59 | parser.add_argument(
60 | '--G_shared', action='store_true', default=False,
61 | help='Use shared embeddings in G? (default: %(default)s)')
62 | parser.add_argument(
63 | '--shared_dim', type=int, default=0,
64 | help="G's shared embedding dimensionality; if 0, will be equal to dim_z. "
65 | '(default: %(default)s)')
66 | parser.add_argument(
67 | '--dim_z', type=int, default=128,
68 | help='Noise dimensionality (default: %(default)s)')
69 | parser.add_argument(
70 | '--hier', action='store_true', default=False,
71 | help='Use hierarchical z in G? (default: %(default)s)')
72 | parser.add_argument(
73 | '--n_classes', type=int, default=1000,
74 | help='Number of class conditions (default: %(default)s)')
75 | parser.add_argument(
76 | '--cross_replica', action='store_true', default=False,
77 | help='Cross_replica batchnorm in G? (default: %(default)s)')
78 | parser.add_argument(
79 | '--mybn', action='store_true', default=False,
80 | help='Use my batchnorm (which supports standing stats)? (default: %(default)s)')
81 | parser.add_argument(
82 | '--G_nl', type=str, default='inplace_relu',
83 | help='Activation function for G (default: %(default)s)')
84 | parser.add_argument(
85 | '--D_nl', type=str, default='inplace_relu',
86 | help='Activation function for D (default: %(default)s)')
87 | parser.add_argument(
88 | '--G_attn', type=str, default='64',
89 | help='What resolutions to use attention on for G (underscore separated) '
90 | '(default: %(default)s)')
91 | parser.add_argument(
92 | '--D_attn', type=str, default='64',
93 | help='What resolutions to use attention on for D (underscore separated) '
94 | '(default: %(default)s)')
95 | parser.add_argument(
96 | '--norm_style', type=str, default='bn',
97 | help='Normalizer style for G, one of bn [batchnorm], in [instancenorm], '
98 | 'ln [layernorm], gn [groupnorm] (default: %(default)s)')
99 |
100 | ### Model init stuff ###
101 | parser.add_argument(
102 | '--seed', type=int, default=0,
103 | help='Random seed to use; affects both initialization and '
104 | ' dataloading. (default: %(default)s)')
105 | parser.add_argument(
106 | '--G_init', type=str, default='ortho',
107 | help='Init style to use for G (default: %(default)s)')
108 | parser.add_argument(
109 | '--D_init', type=str, default='ortho',
110 | help='Init style to use for D (default: %(default)s)')
111 | parser.add_argument(
112 | '--skip_init', action='store_true', default=False,
113 | help='Skip initialization, ideal for testing when ortho init was used '
114 | '(default: %(default)s)')
115 |
116 | ### Optimizer stuff ###
117 | parser.add_argument(
118 | '--optimizer', type=str, default='Adam',
119 | help='Optimizer, Adam or SGD (default: %(default)s)')
120 | parser.add_argument(
121 | '--G_lr', type=float, default=5e-5,
122 | help='Learning rate to use for Generator (default: %(default)s)')
123 | parser.add_argument(
124 | '--D_lr', type=float, default=2e-4,
125 | help='Learning rate to use for Discriminator (default: %(default)s)')
126 | parser.add_argument(
127 | '--Z_lr_mult', type=float, default=50,
128 | help='Learning rate multiplication to use for Z (default: %(default)s)')
129 | parser.add_argument(
130 | '--G_B1', type=float, default=0.0,
131 | help='Beta1 to use for Generator (default: %(default)s)')
132 | parser.add_argument(
133 | '--D_B1', type=float, default=0.0,
134 | help='Beta1 to use for Discriminator (default: %(default)s)')
135 | parser.add_argument(
136 | '--G_B2', type=float, default=0.999,
137 | help='Beta2 to use for Generator (default: %(default)s)')
138 | parser.add_argument(
139 | '--D_B2', type=float, default=0.999,
140 | help='Beta2 to use for Discriminator (default: %(default)s)')
141 |
142 | ### Batch size, parallel, and precision stuff ###
143 | parser.add_argument(
144 | '--G_fp16', action='store_true', default=False,
145 | help='Train with half-precision in G? (default: %(default)s)')
146 | parser.add_argument(
147 | '--D_fp16', action='store_true', default=False,
148 | help='Train with half-precision in D? (default: %(default)s)')
149 | parser.add_argument(
150 | '--D_mixed_precision', action='store_true', default=False,
151 | help='Train with half-precision activations but fp32 params in D? '
152 | '(default: %(default)s)')
153 | parser.add_argument(
154 | '--G_mixed_precision', action='store_true', default=False,
155 | help='Train with half-precision activations but fp32 params in G? '
156 | '(default: %(default)s)')
157 | parser.add_argument(
158 | '--accumulate_stats', action='store_true', default=False,
159 | help='Accumulate "standing" batchnorm stats? (default: %(default)s)')
160 | parser.add_argument(
161 | '--num_standing_accumulations', type=int, default=16,
162 | help='Number of forward passes to use in accumulating standing stats? '
163 | '(default: %(default)s)')
164 |
165 | ### Bookkeeping stuff ###
166 | parser.add_argument(
167 | '--weights_root', type=str, default='weights',
168 | help='Default location to store weights (default: %(default)s)')
169 |
170 | ### EMA Stuff ###
171 | parser.add_argument(
172 | '--use_ema', action='store_true', default=False,
173 | help='Use the EMA parameters of G for evaluation? (default: %(default)s)')
174 |
175 | ### Numerical precision and SV stuff ###
176 | parser.add_argument(
177 | '--adam_eps', type=float, default=1e-6,
178 | help='epsilon value to use for Adam (default: %(default)s)')
179 | parser.add_argument(
180 | '--BN_eps', type=float, default=1e-5,
181 | help='epsilon value to use for BatchNorm (default: %(default)s)')
182 | parser.add_argument(
183 | '--SN_eps', type=float, default=1e-6,
184 | help='epsilon value to use for Spectral Norm (default: %(default)s)')
185 | parser.add_argument(
186 | '--num_G_SVs', type=int, default=1,
187 | help='Number of SVs to track in G (default: %(default)s)')
188 | parser.add_argument(
189 | '--num_D_SVs', type=int, default=1,
190 | help='Number of SVs to track in D (default: %(default)s)')
191 | parser.add_argument(
192 | '--num_G_SV_itrs', type=int, default=1,
193 | help='Number of SV itrs in G (default: %(default)s)')
194 | parser.add_argument(
195 | '--num_D_SV_itrs', type=int, default=1,
196 | help='Number of SV itrs in D (default: %(default)s)')
197 |
198 | ### Resume training stuff ###
199 | parser.add_argument(
200 | '--load_weights', type=str, default='',
201 | help='Suffix for which weights to load (e.g. best0, copy0) '
202 | '(default: %(default)s)')
203 |
204 | ### Log stuff ###
205 | parser.add_argument(
206 | '--no_tb', action='store_true', default=False,
207 | help='Do not use tensorboard? '
208 | '(default: %(default)s)')
209 | return parser
210 |
211 |
212 | activation_dict = {'inplace_relu': nn.ReLU(inplace=True),
213 | 'relu': nn.ReLU(inplace=False),
214 | 'ir': nn.ReLU(inplace=True)}
215 |
216 |
217 | def dgp_update_config(config):
218 | config['G_activation'] = activation_dict[config['G_nl']]
219 | config['D_activation'] = activation_dict[config['D_nl']]
220 |
221 |
222 | class CenterCropLongEdge(object):
223 | """Crops the given PIL Image on the long edge.
224 | Args:
225 | size (sequence or int): Desired output size of the crop. If size is an
226 | int instead of sequence like (h, w), a square crop (size, size) is
227 | made.
228 | """
229 |
230 | def __call__(self, img):
231 | """
232 | Args:
233 | img (PIL Image): Image to be cropped.
234 | Returns:
235 | PIL Image: Cropped image.
236 | """
237 | return transforms.functional.center_crop(img, min(img.size))
238 |
239 | def __repr__(self):
240 | return self.__class__.__name__
241 |
242 |
243 | class RandomCropLongEdge(object):
244 | """Crops the given PIL Image on the long edge with a random start point.
245 | Args:
246 | size (sequence or int): Desired output size of the crop. If size is an
247 | int instead of sequence like (h, w), a square crop (size, size) is
248 | made.
249 | """
250 |
251 | def __call__(self, img):
252 | """
253 | Args:
254 | img (PIL Image): Image to be cropped.
255 | Returns:
256 | PIL Image: Cropped image.
257 | """
258 | size = (min(img.size), min(img.size))
259 | # Only step forward along this edge if it's the long edge
260 | i = (0 if size[0] == img.size[0] else np.random.randint(
261 | low=0, high=img.size[0] - size[0]))
262 | j = (0 if size[1] == img.size[1] else np.random.randint(
263 | low=0, high=img.size[1] - size[1]))
264 | return transforms.functional.crop(img, i, j, size[0], size[1])
265 |
266 | def __repr__(self):
267 | return self.__class__.__name__
268 |
269 |
270 | # Utility function to seed rngs
271 | def seed_rng(seed):
272 | torch.manual_seed(seed)
273 | torch.cuda.manual_seed(seed)
274 | np.random.seed(seed)
275 |
276 |
277 | # Apply modified ortho reg to a model
278 | # This function is an optimized version that directly computes the gradient,
279 | # instead of computing and then differentiating the loss.
280 | def ortho(model, strength=1e-4, blacklist=[], invconvonly=False):
281 | with torch.no_grad():
282 | for name, param in model.named_parameters():
283 | if invconvonly:
284 | if 'weight' not in name or 'NN' in name:
285 | continue
286 | # Only apply this to parameters with at least 2 axes, and not in the blacklist
287 | if len(param.shape) < 2 or any([param is item for item in blacklist]):
288 | continue
289 | w = param.view(param.shape[0], -1)
290 | grad = (2 * torch.mm(
291 | torch.mm(w, w.t()) *
292 | (1. - torch.eye(w.shape[0], device=w.device)), w))
293 | param.grad.data += strength * grad.view(param.shape)
294 |
295 |
296 | # Default ortho reg
297 | # This function is an optimized version that directly computes the gradient,
298 | # instead of computing and then differentiating the loss.
299 | def default_ortho(model, strength=1e-4, blacklist=[]):
300 | with torch.no_grad():
301 | for param in model.parameters():
302 | # Only apply this to parameters with at least 2 axes & not in blacklist
303 | if len(param.shape) < 2 or param in blacklist:
304 | continue
305 | w = param.view(param.shape[0], -1)
306 | grad = (2 * torch.mm(
307 | torch.mm(w, w.t()) - torch.eye(w.shape[0], device=w.device),
308 | w))
309 | param.grad.data += strength * grad.view(param.shape)
310 |
311 |
312 | # Convenience utility to switch off requires_grad
313 | def toggle_grad(model, on_or_off):
314 | for param in model.parameters():
315 | param.requires_grad = on_or_off
316 |
317 |
318 | # Function to join strings or ignore them
319 | # Base string is the string to link "strings," while strings
320 | # is a list of strings or Nones.
321 | def join_strings(base_string, strings):
322 | return base_string.join([item for item in strings if item])
323 |
324 |
325 | # Load a model's weights
326 | def load_weights(G,
327 | D,
328 | weights_root,
329 | name_suffix=None,
330 | G_ema=None,
331 | strict=False):
332 | def map_func(storage, location):
333 | return storage.cuda()
334 |
335 | if name_suffix:
336 | print('Loading %s weights from %s...' % (name_suffix, weights_root))
337 | else:
338 | print('Loading weights from %s...' % weights_root)
339 | if G is not None:
340 | G.load_state_dict(
341 | torch.load(
342 | '%s/%s.pth' %
343 | (weights_root, join_strings('_', ['G', name_suffix])),
344 | map_location=map_func),
345 | strict=strict)
346 | if D is not None:
347 | D.load_state_dict(
348 | torch.load(
349 | '%s/%s.pth' %
350 | (weights_root, join_strings('_', ['D', name_suffix])),
351 | map_location=map_func),
352 | strict=strict)
353 | if G_ema is not None:
354 | print('Loading ema generator...')
355 | G_ema.load_state_dict(
356 | torch.load(
357 | '%s/%s.pth' %
358 | (weights_root, join_strings('_', ['G_ema', name_suffix])),
359 | map_location=map_func),
360 | strict=strict)
361 |
362 |
363 | # Get singular values to log. This will use the state dict to find them
364 | # and substitute underscores for dots.
365 | def get_SVs(net, prefix):
366 | d = net.state_dict()
367 | return {('%s_%s' % (prefix, key)).replace('.', '_'): float(d[key].item())
368 | for key in d if 'sv' in key}
369 |
370 |
371 | # Name an experiment based on its config
372 | def name_from_config(config):
373 | name = '_'.join([
374 | item for item in [
375 | 'Big%s' % config['which_train_fn'],
376 | config['dataset'],
377 | config['model'] if config['model'] != 'BigGAN' else None,
378 | 'seed%d' % config['seed'],
379 | 'Gch%d' % config['G_ch'],
380 | 'Dch%d' % config['D_ch'],
381 | 'Gd%d' % config['G_depth'] if config['G_depth'] > 1 else None,
382 | 'Dd%d' % config['D_depth'] if config['D_depth'] > 1 else None,
383 | 'bs%d' % config['batch_size'],
384 | 'Gfp16' if config['G_fp16'] else None,
385 | 'Dfp16' if config['D_fp16'] else None,
386 | 'nDs%d' % config['num_D_steps'] if config['num_D_steps'] > 1 else None,
387 | 'nDa%d' % config['num_D_accumulations'] if config['num_D_accumulations'] > 1 else None,
388 | 'nGa%d' % config['num_G_accumulations'] if config['num_G_accumulations'] > 1 else None,
389 | 'Glr%2.1e' % config['G_lr'],
390 | 'Dlr%2.1e' % config['D_lr'],
391 | 'GB%3.3f' % config['G_B1'] if config['G_B1'] != 0.0 else None,
392 | 'GBB%3.3f' % config['G_B2'] if config['G_B2'] != 0.999 else None,
393 | 'DB%3.3f' % config['D_B1'] if config['D_B1'] != 0.0 else None,
394 | 'DBB%3.3f' % config['D_B2'] if config['D_B2'] != 0.999 else None,
395 | 'Gnl%s' % config['G_nl'],
396 | 'Dnl%s' % config['D_nl'],
397 | 'Ginit%s' % config['G_init'],
398 | 'Dinit%s' % config['D_init'],
399 | 'G%s' % config['G_param'] if config['G_param'] != 'SN' else None,
400 | 'D%s' % config['D_param'] if config['D_param'] != 'SN' else None,
401 | 'Gattn%s' % config['G_attn'] if config['G_attn'] != '0' else None,
402 | 'Dattn%s' % config['D_attn'] if config['D_attn'] != '0' else None,
403 | 'Gortho%2.1e' % config['G_ortho'] if config['G_ortho'] > 0.0 else None,
404 | 'Dortho%2.1e' % config['D_ortho'] if config['D_ortho'] > 0.0 else None,
405 | config['norm_style'] if config['norm_style'] != 'bn' else None,
406 | 'cr' if config['cross_replica'] else None,
407 | 'Gshared' if config['G_shared'] else None,
408 | 'hier' if config['hier'] else None,
409 | 'ema' if config['ema'] else None,
410 | 'Glow' if config['Glow'] else None,
411 | 'nlevels%d' % config['n_levels'] if config['n_levels'] else None,
412 | 'depth%d' % config['depth'] if config['depth'] else None,
413 | config['name_suffix'] if config['name_suffix'] else None,
414 | ] if item is not None
415 | ])
416 | return name
417 |
418 |
419 | # Get GPU memory, -i is the index
420 | def query_gpu(indices):
421 | os.system('nvidia-smi -i %s --query-gpu=memory.free --format=csv' % indices)
422 |
423 |
424 | # Convenience function to count the number of parameters in a module
425 | def count_parameters(module):
426 | print('Number of parameters: {}'.format(
427 | sum([p.data.nelement() for p in module.parameters()])))
428 |
429 |
430 | # Convenience function to sample an index, not actually a 1-hot
431 | def sample_1hot(batch_size, num_classes, device='cuda'):
432 | return torch.randint(
433 | low=0,
434 | high=num_classes,
435 | size=(batch_size, ),
436 | device=device,
437 | dtype=torch.int64,
438 | requires_grad=False)
439 |
440 |
441 | # A highly simplified convenience class for sampling from distributions
442 | # One could also use PyTorch's inbuilt distributions package.
443 | # Note that this class requires initialization to proceed as
444 | # x = Distribution(torch.randn(size))
445 | # x.init_distribution(dist_type, **dist_kwargs)
446 | # x = x.to(device,dtype)
447 | # This is partially based on https://discuss.pytorch.org/t/subclassing-torch-tensor/23754/2
448 | class Distribution(torch.Tensor):
449 | # Init the params of the distribution
450 | def init_distribution(self, dist_type, **kwargs):
451 | self.dist_type = dist_type
452 | self.dist_kwargs = kwargs
453 | if self.dist_type == 'normal':
454 | self.mean, self.var = kwargs['mean'], kwargs['var']
455 | elif self.dist_type == 'categorical':
456 | self.num_categories = kwargs['num_categories']
457 |
458 | def sample_(self):
459 | if self.dist_type == 'normal':
460 | self.normal_(self.mean, self.var)  # note: normal_ takes (mean, std); 'var' is used directly as the std
461 | elif self.dist_type == 'categorical':
462 | self.random_(0, self.num_categories)
463 |
464 | # update the mean and covariance for multi-variate distribution
465 | def update(self, mean, cov):
466 | self.mean, self.cov = mean, cov
467 |
468 | # Silly hack: overwrite the to() method to wrap the new object
469 | # in a distribution as well
470 | def to(self, *args, **kwargs):
471 | new_obj = Distribution(self)
472 | new_obj.init_distribution(self.dist_type, **self.dist_kwargs)
473 | new_obj.data = super().to(*args, **kwargs)
474 | return new_obj
475 |
476 |
477 | # Convenience function to prepare a z and y vector
478 | def prepare_z_y(G_batch_size,
479 | dim_z,
480 | nclasses,
481 | device='cuda',
482 | fp16=False,
483 | z_var=1.0):
484 | z_ = Distribution(torch.randn(G_batch_size, dim_z, requires_grad=False))
485 | z_.init_distribution('normal', mean=0, var=z_var)
486 | z_ = z_.to(device, torch.float16 if fp16 else torch.float32)
487 |
488 | if fp16:
489 | z_ = z_.half()
490 |
491 | y_ = Distribution(torch.zeros(G_batch_size, requires_grad=False))
492 | y_.init_distribution('categorical', num_categories=nclasses)
493 | y_ = y_.to(device, torch.int64)
494 | return z_, y_
495 |
496 |
497 | def initiate_standing_stats(net):
498 | for module in net.modules():
499 | if hasattr(module, 'accumulate_standing'):
500 | module.reset_stats()
501 | module.accumulate_standing = True
502 |
503 |
504 | def accumulate_standing_stats(net, z, y, nclasses, num_accumulations=16):
505 | initiate_standing_stats(net)
506 | net.train()
507 | for i in range(num_accumulations):
508 | with torch.no_grad():
509 | z.sample_()
510 | y.random_(0, nclasses)
511 | x = net(z, net.shared(y)) # No need to parallelize here unless using syncbn
512 | # Set to eval mode
513 | net.eval()
514 |
515 |
516 | # This version of Adam keeps an fp32 copy of the parameters and
517 | # does all of the parameter updates in fp32, while still doing the
518 | # forwards and backwards passes using fp16 (i.e. fp16 copies of the
519 | # parameters and fp16 activations).
520 | #
521 | # Note that this calls .float().cuda() on the params.
522 |
523 |
524 | class Adam16(Optimizer):
525 |
526 | def __init__(self,
527 | params,
528 | lr=1e-3,
529 | betas=(0.9, 0.999),
530 | eps=1e-8,
531 | weight_decay=0):
532 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
533 | params = list(params)
534 | super(Adam16, self).__init__(params, defaults)
535 |
536 | # Safety modification to make sure we floatify our state
537 | def load_state_dict(self, state_dict):
538 | super(Adam16, self).load_state_dict(state_dict)
539 | for group in self.param_groups:
540 | for p in group['params']:
541 | self.state[p]['exp_avg'] = self.state[p]['exp_avg'].float()
542 | self.state[p]['exp_avg_sq'] = self.state[p][
543 | 'exp_avg_sq'].float()
544 | self.state[p]['fp32_p'] = self.state[p]['fp32_p'].float()
545 |
546 | def step(self, closure=None):
547 | """Performs a single optimization step.
548 | Arguments:
549 | closure (callable, optional): A closure that reevaluates the model
550 | and returns the loss.
551 | """
552 | loss = None
553 | if closure is not None:
554 | loss = closure()
555 |
556 | for group in self.param_groups:
557 | for p in group['params']:
558 | if p.grad is None:
559 | continue
560 |
561 | grad = p.grad.data.float()
562 | state = self.state[p]
563 |
564 | # State initialization
565 | if len(state) == 0:
566 | state['step'] = 0
567 | # Exponential moving average of gradient values
568 | state['exp_avg'] = torch.zeros_like(grad)
569 | # Exponential moving average of squared gradient values
570 | state['exp_avg_sq'] = torch.zeros_like(grad)
571 | # Fp32 copy of the weights
572 | state['fp32_p'] = p.data.float()
573 |
574 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
575 | beta1, beta2 = group['betas']
576 |
577 | state['step'] += 1
578 |
579 | if group['weight_decay'] != 0:
580 | grad = grad.add(state['fp32_p'], alpha=group['weight_decay'])
581 |
582 | # Decay the first and second moment running average coefficient
583 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
584 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
585 |
586 | denom = exp_avg_sq.sqrt().add_(group['eps'])
587 |
588 | bias_correction1 = 1 - beta1**state['step']
589 | bias_correction2 = 1 - beta2**state['step']
590 | step_size = group['lr'] * math.sqrt(
591 | bias_correction2) / bias_correction1
592 |
593 | state['fp32_p'].addcdiv_(exp_avg, denom, value=-step_size)
594 | p.data = state['fp32_p'].half()
595 |
596 | return loss
597 |
--------------------------------------------------------------------------------
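A minimal usage sketch for the `Distribution` / `prepare_z_y` pair above, following the initialization protocol documented in the class comment. The generator call is left commented out; a BigGAN-style `G` with a `shared` class embedding, `dim_z=120`, and the batch size are illustrative assumptions, not values taken from this repository:

```python
import torch

# prepare_z_y builds a Gaussian latent sampler and a categorical label sampler.
z_, y_ = prepare_z_y(G_batch_size=8, dim_z=120, nclasses=1000, device='cuda')

with torch.no_grad():
    z_.sample_()   # refill z_ in place from N(0, z_var)
    y_.sample_()   # refill y_ in place with random class indices
    # fake = G(z_, G.shared(y_))  # hypothetical BigGAN-style generator call
```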
/utils/common_utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 |
4 | import numpy as np
5 | import torch
6 | import torchvision.transforms as transforms
7 | from PIL import Image
8 |
9 | import dataset
10 |
11 | from .biggan_utils import CenterCropLongEdge
12 |
13 |
14 | # Arguments for DGP
15 | def add_dgp_parser(parser):
16 | parser.add_argument(
17 | '--dist', action='store_true', default=False,
18 | help='Train with distributed implementation (default: %(default)s)')
19 | parser.add_argument(
20 | '--port', type=str, default='12345',
21 | help='Port id for distributed training (default: %(default)s)')
22 | parser.add_argument(
23 | '--exp_path', type=str, default='',
24 | help='Experiment path (default: %(default)s)')
25 | parser.add_argument(
26 | '--root_dir', type=str, default='',
27 | help='Root path of dataset (default: %(default)s)')
28 | parser.add_argument(
29 | '--list_file', type=str, default='',
30 | help='List file of the dataset (default: %(default)s)')
31 | parser.add_argument(
32 | '--resolution', type=int, default=256,
33 | help='Resolution to resize the input image (default: %(default)s)')
34 | parser.add_argument(
35 | '--dgp_mode', type=str, default='reconstruct',
36 | help='DGP mode (default: %(default)s)')
37 | parser.add_argument(
38 | '--random_G', action='store_true', default=False,
39 | help='Use randomly initialized generator? (default: %(default)s)')
40 | parser.add_argument(
41 | '--update_G', action='store_true', default=False,
42 | help='Finetune Generator? (default: %(default)s)')
43 | parser.add_argument(
44 | '--update_embed', action='store_true', default=False,
45 | help='Finetune class embedding? (default: %(default)s)')
46 | parser.add_argument(
47 | '--save_G', action='store_true', default=False,
48 | help='Save fine-tuned generator and latent vector? (default: %(default)s)')
49 | parser.add_argument(
50 | '--ftr_type', type=str, default='Discriminator',
51 | choices=['Discriminator', 'VGG'],
52 | help='Feature loss type, choose from Discriminator and VGG (default: %(default)s)')
53 | parser.add_argument(
54 | '--ftr_num', type=int, default=[3], nargs='+',
55 | help='Number of features used to compute the feature loss (default: %(default)s)')
56 | parser.add_argument(
57 | '--ft_num', type=int, default=[2], nargs='+',
58 | help='Number of parameter groups to finetune (default: %(default)s)')
59 | parser.add_argument(
60 | '--print_interval', type=int, default=100, nargs='+',
61 | help='Number of iterations to print training loss (default: %(default)s)')
62 | parser.add_argument(
63 | '--save_interval', type=int, default=None, nargs='+',
64 | help='Interval (in iterations) at which to save images')
65 | parser.add_argument(
66 | '--lr_ratio', type=float, default=[1.0, 1.0, 1.0, 1.0], nargs='+',
67 | help='Decreasing ratio for learning rate in blocks (default: %(default)s)')
68 | parser.add_argument(
69 | '--w_D_loss', type=float, default=[0.1], nargs='+',
70 | help='Discriminator feature loss weight (default: %(default)s)')
71 | parser.add_argument(
72 | '--w_nll', type=float, default=0.001,
73 | help='Weight for the negative log-likelihood loss (default: %(default)s)')
74 | parser.add_argument(
75 | '--w_mse', type=float, default=[0.1], nargs='+',
76 | help='MSE loss weight (default: %(default)s)')
77 | parser.add_argument(
78 | '--select_num', type=int, default=500,
79 | help='Size of the image pool to select from (default: %(default)s)')
80 | parser.add_argument(
81 | '--sample_std', type=float, default=1.0,
82 | help='Std of the gaussian distribution used for sampling (default: %(default)s)')
83 | parser.add_argument(
84 | '--iterations', type=int, default=[200, 200, 200, 200], nargs='+',
85 | help='Training iterations for all stages')
86 | parser.add_argument(
87 | '--G_lrs', type=float, default=[1e-6, 2e-5, 1e-5, 1e-6], nargs='+',
88 | help='Learning rate steps of Generator')
89 | parser.add_argument(
90 | '--z_lrs', type=float, default=[1e-1, 1e-3, 1e-5, 1e-6], nargs='+',
91 | help='Learning rate steps of latent code z')
92 | parser.add_argument(
93 | '--warm_up', type=int, default=0,
94 | help='Number of warmup iterations (default: %(default)s)')
95 | parser.add_argument(
96 | '--use_in', type=str2bool, default=[False, False, False, False], nargs='+',
97 | help='Whether to use instance normalization in generator')
98 | parser.add_argument(
99 | '--stop_mse', type=float, default=0.0,
100 | help='MSE threshold for stopping training (default: %(default)s)')
101 | parser.add_argument(
102 | '--stop_ftr', type=float, default=0.0,
103 | help='Feature loss threshold for stopping training (default: %(default)s)')
104 | return parser
105 |
106 |
107 | def str2bool(v):
108 | if v.lower() in ('yes', 'true', 't', 'y', '1'):
109 | return True
110 | elif v.lower() in ('no', 'false', 'f', 'n', '0'):
111 | return False
112 | else:
113 | raise argparse.ArgumentTypeError('Boolean value expected.')
114 |
115 |
116 | def get_img(img_path, resolution):
117 | img = dataset.default_loader(img_path)
118 | norm_mean = [0.5, 0.5, 0.5]
119 | norm_std = [0.5, 0.5, 0.5]
120 | transform = transforms.Compose([
121 | CenterCropLongEdge(),
122 | transforms.Resize(resolution),
123 | transforms.ToTensor(),
124 | transforms.Normalize(norm_mean, norm_std)
125 | ])
126 | img = transform(img)
127 | return img.unsqueeze(0)
128 |
129 |
130 | def save_img(image, path):
131 | image = np.uint8(255 * (image.cpu().detach().numpy() + 1) / 2.)
132 | image = np.transpose(image, (1, 2, 0))
133 | image = Image.fromarray(image)
134 | image.save(path)
135 |
136 |
137 | class AverageMeter(object):
138 | """Computes and stores the average and current value"""
139 |
140 | def __init__(self):
141 | self.reset()
142 |
143 | def reset(self):
144 | self.val = 0
145 | self.avg = 0
146 | self.sum = 0
147 | self.count = 0
148 |
149 | def update(self, val, n=1):
150 | self.val = val
151 | self.sum += val * n
152 | self.count += n
153 | self.avg = self.sum / self.count
154 |
155 |
156 | def size_splits(tensor, split_sizes, dim=0):
157 | """Splits the tensor according to chunks of split_sizes.
158 |
159 | Arguments:
160 | tensor (Tensor): tensor to split.
161 | split_sizes (list(int)): sizes of chunks
162 | dim (int): dimension along which to split the tensor.
163 | """
164 | if dim < 0:
165 | dim += tensor.dim()
166 |
167 | dim_size = tensor.size(dim)
168 | if dim_size != torch.sum(torch.Tensor(split_sizes)):
169 | raise ValueError("Sum of split_sizes does not match tensor size along dim")
170 |
171 | splits = torch.cumsum(torch.Tensor([0] + split_sizes), dim=0)[:-1]
172 |
173 | return tuple(
174 | tensor.narrow(int(dim), int(start), int(length))
175 | for start, length in zip(splits, split_sizes))
176 |
177 |
178 | class LRScheduler(object):
179 |
180 | def __init__(self, optimizer, warm_up):
181 | super(LRScheduler, self).__init__()
182 | self.optimizer = optimizer
183 | self.warm_up = warm_up
184 |
185 | def update(self, iteration, learning_rate, num_group=1000, ratio=1):
186 | if iteration < self.warm_up:
187 | learning_rate *= iteration / self.warm_up
188 | for i, param_group in enumerate(self.optimizer.param_groups):
189 | if i >= num_group:
190 | param_group['lr'] = 0
191 | else:
192 | param_group['lr'] = learning_rate * ratio**i
193 |
194 |
195 | def create_logger(name, log_file, level=logging.INFO):
196 | l = logging.getLogger(name)
197 | formatter = logging.Formatter('[%(asctime)s] %(message)s')
198 | fh = logging.FileHandler(log_file)
199 | fh.setFormatter(formatter)
200 | sh = logging.StreamHandler()
201 | sh.setFormatter(formatter)
202 | l.setLevel(level)
203 | l.addHandler(fh)
204 | l.addHandler(sh)
205 | return l
206 |
207 |
208 | def map_func(storage, location):
209 | return storage.cuda()
210 |
211 |
212 | def pil_to_np(img_PIL):
213 | '''Converts image in PIL format to np.array.
214 |
215 | From H x W x C [0...255] to C x H x W [0..1]
216 | '''
217 | ar = np.array(img_PIL)
218 |
219 | if len(ar.shape) == 3:
220 | ar = ar.transpose(2, 0, 1)
221 | else:
222 | ar = ar[None, ...]
223 |
224 | return ar.astype(np.float32) / 255.
225 |
226 |
227 | def np_to_pil(img_np):
228 | '''Converts image in np.array format to PIL image.
229 |
230 | From C x H x W [0..1] to H x W x C [0...255]
231 | '''
232 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
233 |
234 | if img_np.shape[0] == 1:
235 | ar = ar[0]
236 | else:
237 | ar = ar.transpose(1, 2, 0)
238 |
239 | return Image.fromarray(ar)
240 |
241 |
242 | def np_to_torch(img_np):
243 | '''Converts image in numpy.array to torch.Tensor.
244 |
245 | From C x H x W [0..1] to 1 x C x H x W [0..1] (adds a batch dimension)
246 | '''
247 | return torch.from_numpy(img_np)[None, :]
248 |
249 |
250 | def torch_to_np(img_var):
251 | '''Converts an image in torch.Tensor format to np.array.
252 |
253 | From 1 x C x H x W [0..1] to C x H x W [0..1]
254 | '''
255 | return img_var.detach().cpu().numpy()[0]
256 |
257 |
258 | def bicubic_torch(img, size):
259 | img_pil = np_to_pil(torch_to_np(img))
260 | img_bicubic_pil = img_pil.resize(size, Image.BICUBIC)
261 | img_bicubic_pth = np_to_torch(pil_to_np(img_bicubic_pil))
262 | return img_bicubic_pth
263 |
--------------------------------------------------------------------------------
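For context, a sketch of how `add_dgp_parser` and `LRScheduler` could be wired together to drive the staged schedules (`G_lrs`, `lr_ratio`, `warm_up`). The toy parameter groups, optimizer, and the single-stage loop below are illustrative assumptions, not the repository's actual trainer:

```python
import argparse
import torch

parser = add_dgp_parser(argparse.ArgumentParser())
args = parser.parse_args(['--warm_up', '100', '--G_lrs', '1e-6', '2e-5'])

# Toy parameter groups standing in for the generator's blocks.
groups = [torch.nn.Parameter(torch.randn(4, 4)) for _ in range(4)]
opt = torch.optim.Adam([{'params': [p]} for p in groups], lr=args.G_lrs[0])
sched = LRScheduler(opt, warm_up=args.warm_up)

for it in range(200):
    # Stage 0: only the first 2 groups are trained (the rest get lr=0);
    # group i gets the stage learning rate scaled by ratio**i, with warm-up.
    sched.update(it, args.G_lrs[0], num_group=2, ratio=args.lr_ratio[0])
    # ... forward / backward / opt.step() would go here ...
```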
/utils/distributed_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import multiprocessing as mp
3 | import os
4 |
5 | import torch
6 | import torch.distributed as dist
7 | from torch.nn import Module
8 | from torch.utils.data import Sampler
9 |
10 |
11 | class DistModule(Module):
12 |
13 | def __init__(self, module):
14 | super(DistModule, self).__init__()
15 | self.module = module
16 | broadcast_params(self.module)
17 |
18 | def forward(self, *inputs, **kwargs):
19 | return self.module(*inputs, **kwargs)
20 |
21 | def train(self, mode=True):
22 | super(DistModule, self).train(mode)
23 | self.module.train(mode)
24 |
25 |
26 | def average_gradients(model):
27 | """ average gradients """
28 | for param in model.parameters():
29 | if param.requires_grad and param.grad is not None:
30 | dist.all_reduce(param.grad.data)
31 |
32 |
33 | def broadcast_params(model):
34 | """ broadcast model parameters """
35 | for p in model.state_dict().values():
36 | dist.broadcast(p, 0)
37 |
38 |
39 | def average_params(model):
40 | """ broadcast model parameters """
41 | worldsize = dist.get_world_size()
42 | for p in model.state_dict().values():
43 | dist.all_reduce(p)
44 | p /= worldsize
45 |
46 |
47 | def dist_init(port):
48 | if mp.get_start_method(allow_none=True) != 'spawn':
49 | mp.set_start_method('spawn')
50 | proc_id = int(os.environ['SLURM_PROCID'])
51 | ntasks = int(os.environ['SLURM_NTASKS'])
52 | node_list = os.environ['SLURM_NODELIST']
53 | num_gpus = torch.cuda.device_count()
54 | torch.cuda.set_device(proc_id % num_gpus)
55 |
56 | if '[' in node_list:
57 | beg = node_list.find('[')
58 | pos1 = node_list.find('-', beg)
59 | if pos1 < 0:
60 | pos1 = 1000
61 | pos2 = node_list.find(',', beg)
62 | if pos2 < 0:
63 | pos2 = 1000
64 | node_list = node_list[:min(pos1, pos2)].replace('[', '')
65 | addr = node_list[8:].replace('-', '.')
66 | print(addr)
67 |
68 | os.environ['MASTER_PORT'] = port
69 | os.environ['MASTER_ADDR'] = addr
70 | os.environ['WORLD_SIZE'] = str(ntasks)
71 | os.environ['RANK'] = str(proc_id)
72 | dist.init_process_group(backend='nccl')
73 |
74 | rank = dist.get_rank()
75 | world_size = dist.get_world_size()
76 | return rank, world_size
77 |
78 |
79 | class DistributedSampler(Sampler):
80 | """Sampler that restricts data loading to a subset of the dataset.
81 |
82 | It is especially useful in conjunction with
83 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
84 | process can pass a DistributedSampler instance as a DataLoader sampler,
85 | and load a subset of the original dataset that is exclusive to it.
86 |
87 | .. note::
88 | Dataset is assumed to be of constant size.
89 |
90 | Arguments:
91 | dataset: Dataset used for sampling.
92 | num_replicas (optional): Number of processes participating in
93 | distributed training.
94 | rank (optional): Rank of the current process within num_replicas.
95 | """
96 |
97 | def __init__(self, dataset, num_replicas=None, rank=None):
98 | if num_replicas is None:
99 | if not dist.is_available():
100 | raise RuntimeError(
101 | "Requires distributed package to be available")
102 | num_replicas = dist.get_world_size()
103 | if rank is None:
104 | if not dist.is_available():
105 | raise RuntimeError(
106 | "Requires distributed package to be available")
107 | rank = dist.get_rank()
108 | self.dataset = dataset
109 | self.num_replicas = num_replicas
110 | self.rank = rank
111 | self.epoch = 0
112 | self.num_samples = int(
113 | math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
114 | self.total_size = self.num_samples * self.num_replicas
115 |
116 | def __iter__(self):
117 | # indices are sequential (no shuffling), so iteration order is deterministic
118 | indices = [i for i in range(len(self.dataset))]
119 |
120 | # add extra samples to make it evenly divisible
121 | indices += indices[:(self.total_size - len(indices))]
122 | assert len(indices) == self.total_size
123 |
124 | # subsample
125 | indices = indices[self.rank * self.num_samples:(self.rank + 1) *
126 | self.num_samples]
127 | assert len(indices) == self.num_samples
128 |
129 | return iter(indices)
130 |
131 | def __len__(self):
132 | return self.num_samples
133 |
134 | def set_epoch(self, epoch):
135 | self.epoch = epoch
136 |
--------------------------------------------------------------------------------
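These utilities assume a SLURM launch: `dist_init` reads `SLURM_PROCID`, `SLURM_NTASKS`, and `SLURM_NODELIST`, then initializes NCCL. A schematic sketch of how the pieces could fit together in a per-GPU worker; the linear model, random tensors, and loss are placeholders, not the repository's training code:

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

rank, world_size = dist_init(port='23456')   # requires SLURM_* env variables

model = DistModule(torch.nn.Linear(10, 1).cuda())  # broadcasts params from rank 0
data = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
sampler = DistributedSampler(data, num_replicas=world_size, rank=rank)
loader = DataLoader(data, batch_size=8, sampler=sampler)

opt = torch.optim.SGD(model.parameters(), lr=0.01)
for x, y in loader:
    loss = F.mse_loss(model(x.cuda()), y.cuda())
    opt.zero_grad()
    loss.backward()
    average_gradients(model)   # all-reduce (sum) of gradients across workers
    opt.step()
```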
/utils/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class PerceptLoss(object):
7 |
8 | def __init__(self):
9 | pass
10 |
11 | def __call__(self, LossNet, fake_img, real_img):
12 | with torch.no_grad():
13 | real_feature = LossNet(real_img.detach())
14 | fake_feature = LossNet(fake_img)
15 | perceptual_penalty = F.mse_loss(fake_feature, real_feature)
16 | return perceptual_penalty
17 |
18 | def set_ftr_num(self, ftr_num):
19 | pass
20 |
21 |
22 | class DiscriminatorLoss(object):
23 |
24 | def __init__(self, ftr_num=4, data_parallel=False):
25 | self.data_parallel = data_parallel
26 | self.ftr_num = ftr_num
27 |
28 | def __call__(self, D, fake_img, real_img):
29 | if self.data_parallel:
30 | with torch.no_grad():
31 | d, real_feature = nn.parallel.data_parallel(
32 | D, real_img.detach())
33 | d, fake_feature = nn.parallel.data_parallel(D, fake_img)
34 | else:
35 | with torch.no_grad():
36 | d, real_feature = D(real_img.detach())
37 | d, fake_feature = D(fake_img)
38 | D_penalty = 0
39 | for i in range(self.ftr_num):
40 | f_id = -i - 1
41 | D_penalty = D_penalty + F.l1_loss(fake_feature[f_id],
42 | real_feature[f_id])
43 | return D_penalty
44 |
45 | def set_ftr_num(self, ftr_num):
46 | self.ftr_num = ftr_num
47 |
--------------------------------------------------------------------------------
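`DiscriminatorLoss` assumes the discriminator returns a `(logits, feature_list)` pair and sums an L1 distance over its last `ftr_num` feature maps. A small sketch of combining it with a pixel-wise MSE term in the spirit of the `--w_mse` / `--w_D_loss` options; the weights and the helper function name are illustrative, not the repository's trainer:

```python
import torch.nn.functional as F

ftr_loss = DiscriminatorLoss(ftr_num=3)

def reconstruction_loss(D, fake_img, real_img, w_mse=0.1, w_d=0.1):
    # Pixel-level term plus discriminator-feature term, both weighted.
    mse = F.mse_loss(fake_img, real_img)
    d_ftr = ftr_loss(D, fake_img, real_img)  # L1 over the last 3 D feature maps
    return w_mse * mse + w_d * d_ftr
```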