├── .idea
│   ├── .gitignore
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── shelf
│   │   ├── Uncommitted_changes_before_rebase_[Changes]
│   │   │   └── shelved.patch
│   │   └── Uncommitted_changes_before_rebase__Changes_.xml
│   ├── vcs.xml
│   ├── webServers.xml
│   └── workspace.xml
├── AutoGAN
│   ├── cfg.py
│   ├── datasets.py
│   ├── exps
│   │   ├── autogan_cifar10_a.sh
│   │   ├── autogan_cifar10_a2stl10.sh
│   │   ├── autogan_cifar10_b.sh
│   │   ├── autogan_cifar10_c.sh
│   │   ├── autogan_search.sh
│   │   └── derive.sh
│   ├── functions.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── autogan_cifar10_a.py
│   │   ├── autogan_cifar10_b.py
│   │   ├── autogan_cifar10_c.py
│   │   └── building_blocks.py
│   ├── models_search
│   │   ├── __init__.py
│   │   ├── building_blocks_search.py
│   │   ├── controller.py
│   │   └── shared_gan.py
│   ├── regan.py
│   ├── search.py
│   ├── test.py
│   ├── train.py
│   ├── train_derived.py
│   └── utils
│       ├── __init__.py
│       ├── cal_fid_stat.py
│       ├── fid_score.py
│       ├── inception_score.py
│       └── utils.py
├── ProGAN
│   ├── main.py
│   ├── model.py
│   ├── regan.py
│   └── utils.py
├── SNGAN
│   ├── main.py
│   ├── model.py
│   └── regan.py
├── StyleGAN2
│   ├── apply_factor.py
│   ├── calc_inception.py
│   ├── closed_form_factorization.py
│   ├── convert_weight.py
│   ├── dataset.py
│   ├── diffaug.py
│   ├── distributed.py
│   ├── generate.py
│   ├── inception.py
│   ├── lpips
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── dist_model.py
│   │   ├── networks_basic.py
│   │   ├── pretrained_networks.py
│   │   └── weights
│   │       ├── v0.0
│   │       │   ├── alex.pth
│   │       │   ├── squeeze.pth
│   │       │   └── vgg.pth
│   │       └── v0.1
│   │           ├── alex.pth
│   │           ├── squeeze.pth
│   │           └── vgg.pth
│   ├── model.py
│   ├── non_leaking.py
│   ├── op
│   │   ├── __init__.py
│   │   ├── conv2d_gradfix.py
│   │   ├── fused_act.py
│   │   ├── fused_bias_act.cpp
│   │   ├── fused_bias_act_kernel.cu
│   │   ├── upfirdn2d.cpp
│   │   ├── upfirdn2d.py
│   │   └── upfirdn2d_kernel.cu
│   ├── ppl.py
│   ├── prepare_data.py
│   ├── projector.py
│   ├── regan.py
│   ├── swagan.py
│   └── train.py
├── figures
│   ├── Table5.png
│   └── main_figure.png
└── readme.md
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | ##ignore this file##
2 | /target/
3 | /.idea/
4 | /.settings/
5 | /.vscode/
6 | /bin/
7 |
8 | .classpath
9 | .project
10 | .settings
11 | .idea
12 | ## filter database and sln files ##
13 | *.mdb
14 | *.ldb
15 | *.sln
16 | ##class file##
17 | *.com
18 | *.class
19 | *.dll
20 | *.exe
21 | *.o
22 | *.so
23 | # compression file
24 | *.7z
25 | *.dmg
26 | *.gz
27 | *.iso
28 | *.jar
29 | *.rar
30 | *.tar
31 | *.zip
32 | *.via
33 | *.tmp
34 | *.err
35 | *.log
36 | *.iml
37 | # OS generated files #
38 | .DS_Store
39 | .DS_Store?
40 | ._*
41 | .Spotlight-V100
42 | .Trashes
43 | Icon?
44 | ehthumbs.db
45 | Thumbs.db
46 | .factorypath
47 | /.mvn/
48 | /mvnw.cmd
49 | /mvnw
50 |
--------------------------------------------------------------------------------
/AutoGAN/cfg.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 | # Modified by Jiahao Xu (jiahxu@polyu.edu.hk)
7 |
8 | import argparse
9 |
10 |
11 | def str2bool(v):
12 | if v.lower() in ("yes", "true", "t", "y", "1"):
13 | return True
14 | elif v.lower() in ("no", "false", "f", "n", "0"):
15 | return False
16 | else:
17 | raise argparse.ArgumentTypeError("Boolean value expected.")
18 |
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument(
23 | "--max_epoch", type=int, default=200, help="number of epochs of training"
24 | )
25 | parser.add_argument(
26 | "--max_iter", type=int, default=None, help="set the max iteration number"
27 | )
28 | parser.add_argument(
29 | "-gen_bs", "--gen_batch_size", type=int, default=64, help="size of the batches"
30 | )
31 | parser.add_argument(
32 | "-dis_bs", "--dis_batch_size", type=int, default=64, help="size of the batches"
33 | )
34 | parser.add_argument(
35 | "--g_lr", type=float, default=0.0002, help="adam: gen learning rate"
36 | )
37 | parser.add_argument(
38 | "--d_lr", type=float, default=0.0002, help="adam: disc learning rate"
39 | )
40 | parser.add_argument(
41 | "--ctrl_lr", type=float, default=3.5e-4, help="adam: ctrl learning rate"
42 | )
43 | parser.add_argument(
44 | "--lr_decay", action="store_true", help="learning rate decay or not"
45 | )
46 | parser.add_argument(
47 | "--beta1",
48 | type=float,
49 | default=0.0,
50 | help="adam: decay of first order momentum of gradient",
51 | )
52 | parser.add_argument(
53 | "--beta2",
54 | type=float,
55 | default=0.9,
56 | help="adam: decay of second order momentum of gradient",
57 | )
58 | parser.add_argument(
59 | "--num_workers",
60 | type=int,
61 | default=8,
62 | help="number of cpu threads to use during batch generation",
63 | )
64 | parser.add_argument(
65 | "--latent_dim", type=int, default=128, help="dimensionality of the latent space"
66 | )
67 | parser.add_argument(
68 | "--img_size", type=int, default=32, help="size of each image dimension"
69 | )
70 | parser.add_argument(
71 | "--channels", type=int, default=3, help="number of image channels"
72 | )
73 | parser.add_argument(
74 | "--n_critic",
75 | type=int,
76 | default=1,
77 | help="number of training steps for discriminator per iter",
78 | )
79 | parser.add_argument(
80 | "--val_freq", type=int, default=20, help="interval between each validation"
81 | )
82 | parser.add_argument(
83 | "--print_freq", type=int, default=100, help="interval between each verbose"
84 | )
85 | parser.add_argument("--load_path", type=str, help="The reload model path")
86 | parser.add_argument("--exp_name", type=str, help="The name of exp")
87 | parser.add_argument(
88 | "--d_spectral_norm",
89 | type=str2bool,
90 | default=False,
91 | help="add spectral_norm on discriminator?",
92 | )
93 | parser.add_argument(
94 | "--g_spectral_norm",
95 | type=str2bool,
96 | default=False,
97 | help="add spectral_norm on generator?",
98 | )
99 | parser.add_argument("--dataset", type=str, default="cifar10", help="dataset type")
100 | parser.add_argument(
101 | "--data_path", type=str, default="./data", help="The path of data set"
102 | )
103 | parser.add_argument(
104 | "--init_type",
105 | type=str,
106 | default="normal",
107 | choices=["normal", "orth", "xavier_uniform", "false"],
108 | help="The init type",
109 | )
110 | parser.add_argument(
111 | "--gf_dim", type=int, default=64, help="The base channel num of gen"
112 | )
113 | parser.add_argument(
114 | "--df_dim", type=int, default=64, help="The base channel num of disc"
115 | )
116 | parser.add_argument(
117 | "--gen_model", type=str, default="shared_gan", help="path of gen model"
118 | )
119 | parser.add_argument(
120 | "--dis_model", type=str, default="shared_gan", help="path of dis model"
121 | )
122 | parser.add_argument(
123 | "--controller", type=str, default="controller", help="path of controller"
124 | )
125 | parser.add_argument("--eval_batch_size", type=int, default=100)
126 | parser.add_argument("--num_eval_imgs", type=int, default=10000)
127 | parser.add_argument(
128 | "--bottom_width", type=int, default=4, help="the base resolution of the GAN"
129 | )
130 | parser.add_argument("--random_seed", type=int, default=12345)
131 |
132 | # search
133 | parser.add_argument(
134 | "--shared_epoch",
135 | type=int,
136 | default=15,
137 | help="the number of epoch to train the shared gan at each search iteration",
138 | )
139 | parser.add_argument(
140 | "--grow_step1",
141 | type=int,
142 | default=25,
143 | help="which iteration to grow the image size from 8 to 16",
144 | )
145 | parser.add_argument(
146 | "--grow_step2",
147 | type=int,
148 | default=55,
149 | help="which iteration to grow the image size from 16 to 32",
150 | )
151 | parser.add_argument(
152 | "--max_search_iter",
153 | type=int,
154 | default=90,
155 | help="max search iterations of this algorithm",
156 | )
157 | parser.add_argument(
158 | "--ctrl_step",
159 | type=int,
160 | default=30,
161 | help="number of steps to train the controller at each search iteration",
162 | )
163 | parser.add_argument(
164 | "--ctrl_sample_batch",
165 | type=int,
166 | default=1,
167 | help="sample size of controller of each step",
168 | )
169 | parser.add_argument(
170 | "--hid_size", type=int, default=100, help="the size of hidden vector"
171 | )
172 | parser.add_argument(
173 | "--baseline_decay", type=float, default=0.9, help="baseline decay rate in RL"
174 | )
175 | parser.add_argument(
176 | "--rl_num_eval_img",
177 | type=int,
178 | default=5000,
179 | help="number of images to be sampled in order to get the reward",
180 | )
181 | parser.add_argument(
182 | "--num_candidate",
183 | type=int,
184 | default=10,
185 | help="number of candidate architectures to be sampled",
186 | )
187 | parser.add_argument(
188 | "--topk",
189 | type=int,
190 | default=5,
191 | help="preserve topk models architectures after each stage",
192 | )
193 | parser.add_argument(
194 | "--entropy_coeff", type=float, default=1e-3, help="to encourage the exploration"
195 | )
196 | parser.add_argument(
197 | "--dynamic_reset_threshold", type=float, default=1e-3, help="var threshold"
198 | )
199 | parser.add_argument(
200 | "--dynamic_reset_window", type=int, default=500, help="the window size"
201 | )
202 | parser.add_argument(
203 | "--arch", nargs="+", type=int, help="the vector of a discovered architecture"
204 | )
205 | parser.add_argument('--regan', action="store_true", help="enable ReGAN sparse/dense training")
206 | parser.add_argument('--sparsity', type=float, default=0.3, help="fraction of weights pruned in the sparse phase")
207 | parser.add_argument('--g', type=int, default=5, help="epoch interval between sparse/dense mode switches")
208 | parser.add_argument('--warmup_epoch', type=int, default=100, help="number of dense warm-up epochs before switching")
209 |
210 |
211 | opt = parser.parse_args()
212 |
213 | return opt
214 |
--------------------------------------------------------------------------------
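
The exps/*.sh scripts below pass booleans as literal strings (e.g. `--g_spectral_norm False`), which is why cfg.py routes them through str2bool: with argparse's plain `type=bool`, any non-empty string, including "False", parses as True. A minimal sketch (not part of the repo) demonstrating the difference:

import argparse

def str2bool(v):
    # same logic as cfg.py
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")

p = argparse.ArgumentParser()
p.add_argument("--good", type=str2bool, default=False)
p.add_argument("--bad", type=bool, default=False)

print(p.parse_args(["--good", "False"]).good)  # False, as intended
print(p.parse_args(["--bad", "False"]).bad)    # True: bool("False") is truthy
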
/AutoGAN/datasets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | import torch
8 | import torchvision.datasets as datasets
9 | import torchvision.transforms as transforms
10 | from torch.utils.data import Dataset
11 |
12 |
13 | class ImageDataset(object):
14 | def __init__(self, args, cur_img_size=None):
15 | img_size = cur_img_size if cur_img_size else args.img_size
16 | if args.dataset.lower() == "cifar10":
17 | Dt = datasets.CIFAR10
18 | transform = transforms.Compose(
19 | [
20 | transforms.Resize(img_size),
21 | transforms.ToTensor(),
22 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
23 | ]
24 | )
25 | args.n_classes = 10
26 | elif args.dataset.lower() == "stl10":
27 | Dt = datasets.STL10
28 | transform = transforms.Compose(
29 | [
30 | transforms.Resize(img_size),
31 | transforms.ToTensor(),
32 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
33 | ]
34 | )
35 | else:
36 | raise NotImplementedError("Unknown dataset: {}".format(args.dataset))
37 |
38 | if args.dataset.lower() == "stl10":
39 | self.train = torch.utils.data.DataLoader(
40 | Dt(
41 | root=args.data_path,
42 | split="train+unlabeled",
43 | transform=transform,
44 | download=True,
45 | ),
46 | batch_size=args.dis_batch_size,
47 | shuffle=True,
48 | num_workers=args.num_workers,
49 | pin_memory=True,
50 | )
51 |
52 | self.valid = torch.utils.data.DataLoader(
53 | Dt(root=args.data_path, split="test", transform=transform),
54 | batch_size=args.dis_batch_size,
55 | shuffle=False,
56 | num_workers=args.num_workers,
57 | pin_memory=True,
58 | )
59 |
60 | self.test = self.valid
61 | else:
62 | # import numpy as np
63 | self.train = torch.utils.data.DataLoader(Dt(root=args.data_path, train=True, transform=transform, download=True),
64 | batch_size=args.dis_batch_size,
65 | shuffle=True,
66 | num_workers=args.num_workers,
67 | pin_memory=True,
68 | )
69 | # subset = torch.utils.data.Subset(dataset, )
70 |
71 | self.valid = torch.utils.data.DataLoader(
72 | Dt(root=args.data_path, train=False, transform=transform),
73 | batch_size=args.dis_batch_size,
74 | shuffle=False,
75 | num_workers=args.num_workers,
76 | pin_memory=True,
77 | )
78 |
79 | self.test = self.valid
80 |
--------------------------------------------------------------------------------
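
A hypothetical usage sketch for ImageDataset (run from AutoGAN/, so that `from datasets import ImageDataset` resolves); the SimpleNamespace stands in for the parsed cfg.py arguments:

from types import SimpleNamespace
from datasets import ImageDataset

args = SimpleNamespace(dataset="cifar10", img_size=32, data_path="./data",
                       dis_batch_size=64, num_workers=8)
dataset = ImageDataset(args)              # downloads CIFAR-10 on first use
imgs, labels = next(iter(dataset.train))
print(imgs.shape)                         # torch.Size([64, 3, 32, 32]), values in [-1, 1]
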
/AutoGAN/exps/autogan_cifar10_a.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python train.py \
4 | -gen_bs 128 \
5 | -dis_bs 64 \
6 | --dataset cifar10 \
7 | --bottom_width 4 \
8 | --img_size 32 \
9 | --max_iter 50000 \
10 | --gen_model autogan_cifar10_a \
11 | --dis_model autogan_cifar10_a \
12 | --latent_dim 128 \
13 | --gf_dim 256 \
14 | --df_dim 128 \
15 | --g_spectral_norm False \
16 | --d_spectral_norm True \
17 | --g_lr 0.0002 \
18 | --d_lr 0.0002 \
19 | --beta1 0.0 \
20 | --beta2 0.9 \
21 | --init_type xavier_uniform \
22 | --n_critic 5 \
23 | --val_freq 20 \
24 | --exp_name autogan_cifar10_a \
25 | --max_epoch 20
26 | #--dsd True \
27 | #--gap 200 \
28 | #--sparsity 0.1 \
29 | #--load_epoch 200 \
30 | #--load_path ./logs/autogan_cifar10_a_2022_11_01_18_44_37
--------------------------------------------------------------------------------
/AutoGAN/exps/autogan_cifar10_a2stl10.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python train.py \
4 | -gen_bs 128 \
5 | -dis_bs 64 \
6 | --dataset stl10 \
7 | --bottom_width 6 \
8 | --img_size 48 \
9 | --max_iter 50000 \
10 | --gen_model autogan_cifar10_a \
11 | --dis_model autogan_cifar10_a \
12 | --latent_dim 128 \
13 | --gf_dim 256 \
14 | --df_dim 128 \
15 | --g_spectral_norm False \
16 | --d_spectral_norm True \
17 | --g_lr 0.0002 \
18 | --d_lr 0.0002 \
19 | --beta1 0.0 \
20 | --beta2 0.9 \
21 | --init_type xavier_uniform \
22 | --n_critic 5 \
23 | --val_freq 10 \
24 | --exp_name autogan_cifar10_a2stl10
--------------------------------------------------------------------------------
/AutoGAN/exps/autogan_cifar10_b.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python train.py \
4 | -gen_bs 128 \
5 | -dis_bs 64 \
6 | --dataset cifar10 \
7 | --bottom_width 4 \
8 | --img_size 32 \
9 | --max_iter 50000 \
10 | --gen_model autogan_cifar10_b \
11 | --dis_model autogan_cifar10_b \
12 | --latent_dim 128 \
13 | --gf_dim 256 \
14 | --df_dim 128 \
15 | --g_spectral_norm False \
16 | --d_spectral_norm True \
17 | --g_lr 0.0002 \
18 | --d_lr 0.0002 \
19 | --beta1 0.0 \
20 | --beta2 0.9 \
21 | --init_type xavier_uniform \
22 | --n_critic 5 \
23 | --val_freq 20 \
24 | --exp_name autogan_cifar10_b
--------------------------------------------------------------------------------
/AutoGAN/exps/autogan_cifar10_c.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python train.py \
4 | -gen_bs 128 \
5 | -dis_bs 64 \
6 | --dataset cifar10 \
7 | --bottom_width 4 \
8 | --img_size 32 \
9 | --max_iter 50000 \
10 | --gen_model autogan_cifar10_c \
11 | --dis_model autogan_cifar10_c \
12 | --latent_dim 128 \
13 | --gf_dim 256 \
14 | --df_dim 128 \
15 | --g_spectral_norm False \
16 | --d_spectral_norm True \
17 | --g_lr 0.0002 \
18 | --d_lr 0.0002 \
19 | --beta1 0.0 \
20 | --beta2 0.9 \
21 | --init_type xavier_uniform \
22 | --n_critic 5 \
23 | --val_freq 20 \
24 | --exp_name autogan_cifar10_c
--------------------------------------------------------------------------------
/AutoGAN/exps/autogan_search.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python search.py \
4 | -gen_bs 128 \
5 | -dis_bs 64 \
6 | --dataset cifar10 \
7 | --bottom_width 4 \
8 | --img_size 32 \
9 | --gen_model shared_gan \
10 | --dis_model shared_gan \
11 | --controller controller \
12 | --latent_dim 128 \
13 | --gf_dim 128 \
14 | --df_dim 64 \
15 | --g_spectral_norm False \
16 | --d_spectral_norm True \
17 | --g_lr 0.0002 \
18 | --d_lr 0.0002 \
19 | --beta1 0.0 \
20 | --beta2 0.9 \
21 | --init_type xavier_uniform \
22 | --n_critic 5 \
23 | --val_freq 20 \
24 | --ctrl_sample_batch 1 \
25 | --num_candidate 10 \
26 | --topk 5 \
27 | --shared_epoch 15 \
28 | --grow_step1 15 \
29 | --grow_step2 35 \
30 | --max_search_iter 65 \
31 | --ctrl_step 30 \
32 | --exp_name autogan_search
--------------------------------------------------------------------------------
/AutoGAN/exps/derive.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python train_derived.py \
4 | -gen_bs 128 \
5 | -dis_bs 64 \
6 | --dataset cifar10 \
7 | --bottom_width 4 \
8 | --img_size 32 \
9 | --max_iter 50000 \
10 | --gen_model shared_gan \
11 | --dis_model shared_gan \
12 | --latent_dim 128 \
13 | --gf_dim 256 \
14 | --df_dim 128 \
15 | --g_spectral_norm False \
16 | --d_spectral_norm True \
17 | --g_lr 0.0002 \
18 | --d_lr 0.0002 \
19 | --beta1 0.0 \
20 | --beta2 0.9 \
21 | --init_type xavier_uniform \
22 | --n_critic 5 \
23 | --val_freq 20 \
24 | --arch 1 0 1 1 1 0 0 1 1 1 0 1 0 3 \
25 | --exp_name derive
--------------------------------------------------------------------------------
/AutoGAN/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import models.autogan_cifar10_a
10 | import models.autogan_cifar10_b
11 | import models.autogan_cifar10_c
12 |
--------------------------------------------------------------------------------
/AutoGAN/models/autogan_cifar10_a.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-31
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | from torch import nn
8 |
9 | from models.building_blocks import Cell, DisBlock, OptimizedDisBlock
10 |
11 |
12 | class Generator(nn.Module):
13 | def __init__(self, args):
14 | super(Generator, self).__init__()
15 | self.args = args
16 | self.ch = args.gf_dim
17 | self.bottom_width = args.bottom_width
18 | self.l1 = nn.Linear(args.latent_dim, (self.bottom_width ** 2) * args.gf_dim)
19 | self.cell1 = Cell(
20 | args.gf_dim, args.gf_dim, "nearest", num_skip_in=0, short_cut=True
21 | )
22 | self.cell2 = Cell(
23 | args.gf_dim, args.gf_dim, "bilinear", num_skip_in=1, short_cut=True
24 | )
25 | self.cell3 = Cell(
26 | args.gf_dim, args.gf_dim, "nearest", num_skip_in=2, short_cut=False
27 | )
28 | self.to_rgb = nn.Sequential(
29 | nn.BatchNorm2d(args.gf_dim),
30 | nn.ReLU(),
31 | nn.Conv2d(args.gf_dim, 3, 3, 1, 1),
32 | nn.Tanh(),
33 | )
34 |
35 | def forward(self, z):
36 | h = self.l1(z).view(-1, self.ch, self.bottom_width, self.bottom_width)
37 | h1_skip_out, h1 = self.cell1(h)
38 | h2_skip_out, h2 = self.cell2(h1, (h1_skip_out,))
39 | _, h3 = self.cell3(h2, (h1_skip_out, h2_skip_out))
40 | output = self.to_rgb(h3)
41 |
42 | return output
43 |
44 |
45 | class Discriminator(nn.Module):
46 | def __init__(self, args, activation=nn.ReLU()):
47 | super(Discriminator, self).__init__()
48 | self.ch = args.df_dim
49 | self.activation = activation
50 | self.block1 = OptimizedDisBlock(args, 3, self.ch)
51 | self.block2 = DisBlock(
52 | args, self.ch, self.ch, activation=activation, downsample=True
53 | )
54 | self.block3 = DisBlock(
55 | args, self.ch, self.ch, activation=activation, downsample=False
56 | )
57 | self.block4 = DisBlock(
58 | args, self.ch, self.ch, activation=activation, downsample=False
59 | )
60 | self.l5 = nn.Linear(self.ch, 1, bias=False)
61 | if args.d_spectral_norm:
62 | self.l5 = nn.utils.spectral_norm(self.l5)
63 |
64 | def forward(self, x):
65 | h = x
66 | layers = [self.block1, self.block2, self.block3]
67 | model = nn.Sequential(*layers)
68 | h = model(h)
69 | h = self.block4(h)
70 | h = self.activation(h)
71 | # Global average pooling
72 | h = h.sum(2).sum(2)
73 | output = self.l5(h)
74 |
75 | return output
76 |
--------------------------------------------------------------------------------
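
A quick shape check (my own sketch, run from AutoGAN/) for this generator: the latent vector seeds a 4x4 feature map, and the three cells each upsample 2x, yielding a 32x32 RGB image:

import torch
from types import SimpleNamespace
from models.autogan_cifar10_a import Generator

args = SimpleNamespace(gf_dim=64, bottom_width=4, latent_dim=128)
gen = Generator(args)
z = torch.randn(2, args.latent_dim)
print(gen(z).shape)   # torch.Size([2, 3, 32, 32]), tanh output in [-1, 1]
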
/AutoGAN/models/autogan_cifar10_b.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-31
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | from torch import nn
8 |
9 | from models.building_blocks import Cell, DisBlock, OptimizedDisBlock
10 |
11 |
12 | class Generator(nn.Module):
13 | def __init__(self, args):
14 | super(Generator, self).__init__()
15 | self.args = args
16 | self.ch = args.gf_dim
17 | self.bottom_width = args.bottom_width
18 | self.l1 = nn.Linear(args.latent_dim, (self.bottom_width ** 2) * args.gf_dim)
19 | self.cell1 = Cell(
20 | args.gf_dim, args.gf_dim, "nearest", num_skip_in=0, short_cut=True
21 | )
22 | self.cell2 = Cell(
23 | args.gf_dim, args.gf_dim, "nearest", num_skip_in=1, short_cut=True
24 | )
25 | self.cell3 = Cell(
26 | args.gf_dim, args.gf_dim, "nearest", num_skip_in=2, short_cut=True
27 | )
28 | self.to_rgb = nn.Sequential(
29 | nn.BatchNorm2d(args.gf_dim),
30 | nn.ReLU(),
31 | nn.Conv2d(args.gf_dim, 3, 3, 1, 1),
32 | nn.Tanh(),
33 | )
34 |
35 | def forward(self, z):
36 | h = self.l1(z).view(-1, self.ch, self.bottom_width, self.bottom_width)
37 | h1_skip_out, h1 = self.cell1(h)
38 | h2_skip_out, h2 = self.cell2(h1, (h1_skip_out,))
39 | _, h3 = self.cell3(h2, (h1_skip_out, h2_skip_out))
40 | output = self.to_rgb(h3)
41 |
42 | return output
43 |
44 |
45 | class Discriminator(nn.Module):
46 | def __init__(self, args, activation=nn.ReLU()):
47 | super(Discriminator, self).__init__()
48 | self.ch = args.df_dim
49 | self.activation = activation
50 | self.block1 = OptimizedDisBlock(args, 3, self.ch)
51 | self.block2 = DisBlock(
52 | args, self.ch, self.ch, activation=activation, downsample=True
53 | )
54 | self.block3 = DisBlock(
55 | args, self.ch, self.ch, activation=activation, downsample=False
56 | )
57 | self.block4 = DisBlock(
58 | args, self.ch, self.ch, activation=activation, downsample=False
59 | )
60 | self.l5 = nn.Linear(self.ch, 1, bias=False)
61 | if args.d_spectral_norm:
62 | self.l5 = nn.utils.spectral_norm(self.l5)
63 |
64 | def forward(self, x):
65 | h = x
66 | layers = [self.block1, self.block2, self.block3]
67 | model = nn.Sequential(*layers)
68 | h = model(h)
69 | h = self.block4(h)
70 | h = self.activation(h)
71 | # Global average pooling
72 | h = h.sum(2).sum(2)
73 | output = self.l5(h)
74 |
75 | return output
76 |
--------------------------------------------------------------------------------
/AutoGAN/models/autogan_cifar10_c.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-31
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | from torch import nn
8 |
9 | from models.building_blocks import Cell, DisBlock, OptimizedDisBlock
10 |
11 |
12 | class Generator(nn.Module):
13 | def __init__(self, args):
14 | super(Generator, self).__init__()
15 | self.args = args
16 | self.ch = args.gf_dim
17 | self.bottom_width = args.bottom_width
18 | self.l1 = nn.Linear(args.latent_dim, (self.bottom_width ** 2) * args.gf_dim)
19 | self.cell1 = Cell(
20 | args.gf_dim,
21 | args.gf_dim,
22 | "nearest",
23 | num_skip_in=0,
24 | short_cut=True,
25 | norm="bn",
26 | )
27 | self.cell2 = Cell(
28 | args.gf_dim, args.gf_dim, "bilinear", num_skip_in=1, short_cut=True
29 | )
30 | self.cell3 = Cell(
31 | args.gf_dim, args.gf_dim, "nearest", num_skip_in=2, short_cut=False
32 | )
33 | self.to_rgb = nn.Sequential(
34 | nn.BatchNorm2d(args.gf_dim),
35 | nn.ReLU(),
36 | nn.Conv2d(args.gf_dim, 3, 3, 1, 1),
37 | nn.Tanh(),
38 | )
39 |
40 | def forward(self, z):
41 | h = self.l1(z).view(-1, self.ch, self.bottom_width, self.bottom_width)
42 | h1_skip_out, h1 = self.cell1(h)
43 | h2_skip_out, h2 = self.cell2(h1, (h1_skip_out,))
44 | _, h3 = self.cell3(h2, (h1_skip_out, h2_skip_out))
45 | output = self.to_rgb(h3)
46 |
47 | return output
48 |
49 |
50 | class Discriminator(nn.Module):
51 | def __init__(self, args, activation=nn.ReLU()):
52 | super(Discriminator, self).__init__()
53 | self.ch = args.df_dim
54 | self.activation = activation
55 | self.block1 = OptimizedDisBlock(args, 3, self.ch)
56 | self.block2 = DisBlock(
57 | args, self.ch, self.ch, activation=activation, downsample=True
58 | )
59 | self.block3 = DisBlock(
60 | args, self.ch, self.ch, activation=activation, downsample=False
61 | )
62 | self.block4 = DisBlock(
63 | args, self.ch, self.ch, activation=activation, downsample=False
64 | )
65 | self.l5 = nn.Linear(self.ch, 1, bias=False)
66 | if args.d_spectral_norm:
67 | self.l5 = nn.utils.spectral_norm(self.l5)
68 |
69 | def forward(self, x):
70 | h = x
71 | layers = [self.block1, self.block2, self.block3]
72 | model = nn.Sequential(*layers)
73 | h = model(h)
74 | h = self.block4(h)
75 | h = self.activation(h)
76 | # Global average pooling
77 | h = h.sum(2).sum(2)
78 | output = self.l5(h)
79 |
80 | return output
81 |
--------------------------------------------------------------------------------
/AutoGAN/models/building_blocks.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-08-02
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | import torch.nn.functional as F
8 | from torch import nn
9 |
10 | UP_MODES = ["nearest", "bilinear"]
11 | NORMS = ["in", "bn"]
12 |
13 |
14 | class Cell(nn.Module):
15 | def __init__(
16 | self,
17 | in_channels,
18 | out_channels,
19 | up_mode,
20 | ksize=3,
21 | num_skip_in=0,
22 | short_cut=False,
23 | norm=None,
24 | ):
25 | super(Cell, self).__init__()
26 | self.c1 = nn.Conv2d(in_channels, out_channels, ksize, padding=ksize // 2)
27 | self.c2 = nn.Conv2d(out_channels, out_channels, ksize, padding=ksize // 2)
28 | assert up_mode in UP_MODES
29 | self.up_mode = up_mode
30 | self.norm = norm
31 | if norm:
32 | assert norm in NORMS
33 | if norm == "bn":
34 | self.n1 = nn.BatchNorm2d(in_channels)
35 | self.n2 = nn.BatchNorm2d(out_channels)
36 | elif norm == "in":
37 | self.n1 = nn.InstanceNorm2d(in_channels)
38 | self.n2 = nn.InstanceNorm2d(out_channels)
39 | else:
40 | raise NotImplementedError(norm)
41 |
42 | # inner shortcut
43 | self.c_sc = None
44 | if short_cut:
45 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1)
46 |
47 | # cross scale skip
48 | self.skip_in_ops = None
49 | if num_skip_in:
50 | self.skip_in_ops = nn.ModuleList(
51 | [
52 | nn.Conv2d(out_channels, out_channels, kernel_size=1)
53 | for _ in range(num_skip_in)
54 | ]
55 | )
56 |
57 | def forward(self, x, skip_ft=None):
58 | residual = x
59 |
60 | # first conv
61 | if self.norm:
62 | residual = self.n1(residual)
63 | h = nn.ReLU()(residual)
64 | h = F.interpolate(h, scale_factor=2, mode=self.up_mode)
65 | _, _, ht, wt = h.size()
66 | h = self.c1(h)
67 | h_skip_out = h
68 |
69 | # second conv
70 | if self.skip_in_ops:
71 | assert len(self.skip_in_ops) == len(skip_ft)
72 | for ft, skip_in_op in zip(skip_ft, self.skip_in_ops):
73 | h += skip_in_op(F.interpolate(ft, size=(ht, wt), mode=self.up_mode))
74 | if self.norm:
75 | h = self.n2(h)
76 | h = nn.ReLU()(h)
77 | final_out = self.c2(h)
78 |
79 | # shortcut
80 | if self.c_sc:
81 | final_out += self.c_sc(F.interpolate(x, scale_factor=2, mode=self.up_mode))
82 |
83 | return h_skip_out, final_out
84 |
85 |
86 | def _downsample(x):
87 | # Downsample (Mean Avg Pooling with 2x2 kernel)
88 | return nn.AvgPool2d(kernel_size=2)(x)
89 |
90 |
91 | class OptimizedDisBlock(nn.Module):
92 | def __init__(
93 | self, args, in_channels, out_channels, ksize=3, pad=1, activation=nn.ReLU()
94 | ):
95 | super(OptimizedDisBlock, self).__init__()
96 | self.activation = activation
97 |
98 | self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, padding=pad)
99 | self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=ksize, padding=pad)
100 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
101 | if args.d_spectral_norm:
102 | self.c1 = nn.utils.spectral_norm(self.c1)
103 | self.c2 = nn.utils.spectral_norm(self.c2)
104 | self.c_sc = nn.utils.spectral_norm(self.c_sc)
105 |
106 | def residual(self, x):
107 | h = x
108 | h = self.c1(h)
109 | h = self.activation(h)
110 | h = self.c2(h)
111 | h = _downsample(h)
112 | return h
113 |
114 | def shortcut(self, x):
115 | return self.c_sc(_downsample(x))
116 |
117 | def forward(self, x):
118 | return self.residual(x) + self.shortcut(x)
119 |
120 |
121 | class DisBlock(nn.Module):
122 | def __init__(
123 | self,
124 | args,
125 | in_channels,
126 | out_channels,
127 | hidden_channels=None,
128 | ksize=3,
129 | pad=1,
130 | activation=nn.ReLU(),
131 | downsample=False,
132 | ):
133 | super(DisBlock, self).__init__()
134 | self.activation = activation
135 | self.downsample = downsample
136 | self.learnable_sc = (in_channels != out_channels) or downsample
137 | hidden_channels = in_channels if hidden_channels is None else hidden_channels
138 |
139 | self.c1 = nn.Conv2d(
140 | in_channels, hidden_channels, kernel_size=ksize, padding=pad
141 | )
142 | self.c2 = nn.Conv2d(
143 | hidden_channels, out_channels, kernel_size=ksize, padding=pad
144 | )
145 | if args.d_spectral_norm:
146 | self.c1 = nn.utils.spectral_norm(self.c1)
147 | self.c2 = nn.utils.spectral_norm(self.c2)
148 |
149 | if self.learnable_sc:
150 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
151 | if args.d_spectral_norm:
152 | self.c_sc = nn.utils.spectral_norm(self.c_sc)
153 |
154 | def residual(self, x):
155 | h = x
156 | h = self.activation(h)
157 | h = self.c1(h)
158 | h = self.activation(h)
159 | h = self.c2(h)
160 | if self.downsample:
161 | h = _downsample(h)
162 | return h
163 |
164 | def shortcut(self, x):
165 | if self.learnable_sc:
166 | x = self.c_sc(x)
167 | if self.downsample:
168 | return _downsample(x)
169 | else:
170 | return x
171 | else:
172 | return x
173 |
174 | def forward(self, x):
175 | return self.residual(x) + self.shortcut(x)
176 |
--------------------------------------------------------------------------------
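
A small sketch (run from AutoGAN/) of the Cell contract: forward returns the post-first-conv tensor for cross-scale skips plus the 2x-upsampled cell output, and a downstream cell resizes incoming skip tensors to its own resolution:

import torch
from models.building_blocks import Cell

cell1 = Cell(64, 64, "nearest", num_skip_in=0, short_cut=True)
cell2 = Cell(64, 64, "bilinear", num_skip_in=1, short_cut=True)

h = torch.randn(2, 64, 4, 4)
skip1, h1 = cell1(h)           # 4x4 -> 8x8
_, h2 = cell2(h1, (skip1,))    # 8x8 -> 16x16, consuming cell1's skip tensor
print(skip1.shape, h1.shape, h2.shape)
# torch.Size([2, 64, 8, 8]) torch.Size([2, 64, 8, 8]) torch.Size([2, 64, 16, 16])
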
/AutoGAN/models_search/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-08-15
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | from models_search import controller, shared_gan
8 |
--------------------------------------------------------------------------------
/AutoGAN/models_search/controller.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-09-29
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 |
12 | from models_search.building_blocks_search import (CONV_TYPE, NORM_TYPE, SHORT_CUT_TYPE, SKIP_TYPE, UP_TYPE)
13 |
14 |
15 | class Controller(nn.Module):
16 | def __init__(self, args, cur_stage):
17 | """
18 | init
19 | :param args: parsed command-line arguments
20 | :param cur_stage: current search stage (0, 1 or 2)
21 | """
22 | super(Controller, self).__init__()
23 | self.hid_size = args.hid_size
24 | self.cur_stage = cur_stage
25 | self.lstm = torch.nn.LSTMCell(self.hid_size, self.hid_size)
26 | if cur_stage:
27 | self.tokens = [
28 | len(CONV_TYPE),
29 | len(NORM_TYPE),
30 | len(UP_TYPE),
31 | len(SHORT_CUT_TYPE),
32 | len(SKIP_TYPE) ** cur_stage,
33 | ]
34 | else:
35 | self.tokens = [
36 | len(CONV_TYPE),
37 | len(NORM_TYPE),
38 | len(UP_TYPE),
39 | len(SHORT_CUT_TYPE),
40 | ]
41 | self.encoder = nn.Embedding(sum(self.tokens), self.hid_size)
42 | self.decoders = nn.ModuleList(
43 | [nn.Linear(self.hid_size, token) for token in self.tokens]
44 | )
45 |
46 | def initHidden(self, batch_size):
47 | return torch.zeros(batch_size, self.hid_size, requires_grad=False).cuda()
48 |
49 | def forward(self, x, hidden, index):
50 | if index == 0:
51 | embed = x
52 | else:
53 | embed = self.encoder(x)
54 | hx, cx = self.lstm(embed, hidden)
55 |
56 | # decode
57 | logit = self.decoders[index](hx)
58 |
59 | return logit, (hx, cx)
60 |
61 | def sample(self, batch_size, with_hidden=False, prev_hiddens=None, prev_archs=None):
62 | x = self.initHidden(batch_size)
63 |
64 | if prev_hiddens:
65 | assert prev_archs
66 | prev_hxs, prev_cxs = prev_hiddens
67 | selected_idx = np.random.choice(
68 | len(prev_archs), batch_size
69 | ) # TODO: replace=False
70 | selected_idx = [int(x) for x in selected_idx]
71 |
72 | selected_archs = []
73 | selected_hxs = []
74 | selected_cxs = []
75 |
76 | for s_idx in selected_idx:
77 | selected_archs.append(prev_archs[s_idx].unsqueeze(0))
78 | selected_hxs.append(prev_hxs[s_idx].unsqueeze(0))
79 | selected_cxs.append(prev_cxs[s_idx].unsqueeze(0))
80 | selected_archs = torch.cat(selected_archs, 0)
81 | hidden = (torch.cat(selected_hxs, 0), torch.cat(selected_cxs, 0))
82 | else:
83 | hidden = (self.initHidden(batch_size), self.initHidden(batch_size))
84 | entropies = []
85 | actions = []
86 | selected_log_probs = []
87 | for decode_idx in range(len(self.decoders)):
88 | logit, hidden = self.forward(x, hidden, decode_idx)
89 | prob = F.softmax(logit, dim=-1) # bs * logit_dim
90 | log_prob = F.log_softmax(logit, dim=-1)
91 | entropies.append(-(log_prob * prob).sum(1, keepdim=True)) # bs * 1
92 | action = prob.multinomial(1) # batch_size * 1
93 | actions.append(action)
94 | selected_log_prob = log_prob.gather(1, action.data) # batch_size * 1
95 | selected_log_probs.append(selected_log_prob)
96 |
97 | x = action.view(batch_size) + sum(self.tokens[:decode_idx])
98 | x = x.requires_grad_(False)
99 |
100 | archs = torch.cat(actions, -1) # batch_size * len(self.decoders)
101 | selected_log_probs = torch.cat(
102 | selected_log_probs, -1
103 | ) # batch_size * len(self.decoders)
104 | entropies = torch.cat(entropies, 0) # bs * 1
105 |
106 | if prev_hiddens:
107 | archs = torch.cat([selected_archs, archs], -1)
108 |
109 | if with_hidden:
110 | return archs, selected_log_probs, entropies, hidden
111 |
112 | return archs, selected_log_probs, entropies
113 |
--------------------------------------------------------------------------------
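
A sampling sketch (requires a CUDA device, since initHidden allocates on the GPU; the SimpleNamespace carries the only field Controller reads from args). At stage 0 each sampled architecture is four tokens, one per decoder: conv, norm, upsample and shortcut choices:

import torch
from types import SimpleNamespace
from models_search.controller import Controller

args = SimpleNamespace(hid_size=100)
ctrl = Controller(args, cur_stage=0).cuda()
archs, log_probs, entropies = ctrl.sample(batch_size=2)
print(archs.shape)   # torch.Size([2, 4])
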
/AutoGAN/models_search/shared_gan.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-08-15
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 | import torch.nn as nn
7 |
8 | from models_search.building_blocks_search import Cell
9 |
10 |
11 | class Generator(nn.Module):
12 | def __init__(self, args):
13 | super(Generator, self).__init__()
14 | self.args = args
15 | self.ch = args.gf_dim
16 | self.bottom_width = args.bottom_width
17 | self.l1 = nn.Linear(args.latent_dim, (self.bottom_width ** 2) * args.gf_dim)
18 | self.cell1 = Cell(args.gf_dim, args.gf_dim, num_skip_in=0)
19 | self.cell2 = Cell(args.gf_dim, args.gf_dim, num_skip_in=1)
20 | self.cell3 = Cell(args.gf_dim, args.gf_dim, num_skip_in=2)
21 | self.to_rgb = nn.Sequential(
22 | nn.BatchNorm2d(args.gf_dim),
23 | nn.ReLU(),
24 | nn.Conv2d(args.gf_dim, 3, 3, 1, 1),
25 | nn.Tanh(),
26 | )
27 |
28 | def set_arch(self, arch_id, cur_stage):
29 | if not isinstance(arch_id, list):
30 | arch_id = arch_id.to("cpu").numpy().tolist()
31 | arch_id = [int(x) for x in arch_id]
32 | self.cur_stage = cur_stage
33 | arch_stage1 = arch_id[:4]
34 | self.cell1.set_arch(
35 | conv_id=arch_stage1[0],
36 | norm_id=arch_stage1[1],
37 | up_id=arch_stage1[2],
38 | short_cut_id=arch_stage1[3],
39 | skip_ins=[],
40 | )
41 | if cur_stage >= 1:
42 | arch_stage2 = arch_id[4:9]
43 | self.cell2.set_arch(
44 | conv_id=arch_stage2[0],
45 | norm_id=arch_stage2[1],
46 | up_id=arch_stage2[2],
47 | short_cut_id=arch_stage2[3],
48 | skip_ins=arch_stage2[4],
49 | )
50 |
51 | if cur_stage == 2:
52 | arch_stage3 = arch_id[9:]
53 | self.cell3.set_arch(
54 | conv_id=arch_stage3[0],
55 | norm_id=arch_stage3[1],
56 | up_id=arch_stage3[2],
57 | short_cut_id=arch_stage3[3],
58 | skip_ins=arch_stage3[4],
59 | )
60 |
61 | def forward(self, z):
62 | h = self.l1(z).view(-1, self.ch, self.bottom_width, self.bottom_width)
63 | h1_skip_out, h1 = self.cell1(h)
64 | if self.cur_stage == 0:
65 | return self.to_rgb(h1)
66 | h2_skip_out, h2 = self.cell2(h1, (h1_skip_out,))
67 | if self.cur_stage == 1:
68 | return self.to_rgb(h2)
69 | _, h3 = self.cell3(h2, (h1_skip_out, h2_skip_out))
70 | if self.cur_stage == 2:
71 | return self.to_rgb(h3)
72 |
73 |
74 | def _downsample(x):
75 | # Downsample (Mean Avg Pooling with 2x2 kernel)
76 | return nn.AvgPool2d(kernel_size=2)(x)
77 |
78 |
79 | class OptimizedDisBlock(nn.Module):
80 | def __init__(
81 | self, args, in_channels, out_channels, ksize=3, pad=1, activation=nn.ReLU()
82 | ):
83 | super(OptimizedDisBlock, self).__init__()
84 | self.activation = activation
85 |
86 | self.c1 = nn.Conv2d(in_channels, out_channels, kernel_size=ksize, padding=pad)
87 | self.c2 = nn.Conv2d(out_channels, out_channels, kernel_size=ksize, padding=pad)
88 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
89 | if args.d_spectral_norm:
90 | self.c1 = nn.utils.spectral_norm(self.c1)
91 | self.c2 = nn.utils.spectral_norm(self.c2)
92 | self.c_sc = nn.utils.spectral_norm(self.c_sc)
93 |
94 | def residual(self, x):
95 | h = x
96 | h = self.c1(h)
97 | h = self.activation(h)
98 | h = self.c2(h)
99 | h = _downsample(h)
100 | return h
101 |
102 | def shortcut(self, x):
103 | return self.c_sc(_downsample(x))
104 |
105 | def forward(self, x):
106 | return self.residual(x) + self.shortcut(x)
107 |
108 |
109 | class DisBlock(nn.Module):
110 | def __init__(
111 | self,
112 | args,
113 | in_channels,
114 | out_channels,
115 | hidden_channels=None,
116 | ksize=3,
117 | pad=1,
118 | activation=nn.ReLU(),
119 | downsample=False,
120 | ):
121 | super(DisBlock, self).__init__()
122 | self.activation = activation
123 | self.downsample = downsample
124 | self.learnable_sc = (in_channels != out_channels) or downsample
125 | hidden_channels = in_channels if hidden_channels is None else hidden_channels
126 |
127 | self.c1 = nn.Conv2d(
128 | in_channels, hidden_channels, kernel_size=ksize, padding=pad
129 | )
130 | self.c2 = nn.Conv2d(
131 | hidden_channels, out_channels, kernel_size=ksize, padding=pad
132 | )
133 | if args.d_spectral_norm:
134 | self.c1 = nn.utils.spectral_norm(self.c1)
135 | self.c2 = nn.utils.spectral_norm(self.c2)
136 |
137 | if self.learnable_sc:
138 | self.c_sc = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
139 | if args.d_spectral_norm:
140 | self.c_sc = nn.utils.spectral_norm(self.c_sc)
141 |
142 | def residual(self, x):
143 | h = x
144 | h = self.activation(h)
145 | h = self.c1(h)
146 | h = self.activation(h)
147 | h = self.c2(h)
148 | if self.downsample:
149 | h = _downsample(h)
150 | return h
151 |
152 | def shortcut(self, x):
153 | if self.learnable_sc:
154 | x = self.c_sc(x)
155 | if self.downsample:
156 | return _downsample(x)
157 | else:
158 | return x
159 | else:
160 | return x
161 |
162 | def forward(self, x):
163 | return self.residual(x) + self.shortcut(x)
164 |
165 |
166 | class Discriminator(nn.Module):
167 | def __init__(self, args, activation=nn.ReLU()):
168 | super(Discriminator, self).__init__()
169 | self.ch = args.df_dim
170 | self.activation = activation
171 | self.block1 = OptimizedDisBlock(args, 3, self.ch)
172 | self.block2 = DisBlock(
173 | args, self.ch, self.ch, activation=activation, downsample=True
174 | )
175 | self.block3 = DisBlock(
176 | args, self.ch, self.ch, activation=activation, downsample=False
177 | )
178 | self.block4 = DisBlock(
179 | args, self.ch, self.ch, activation=activation, downsample=False
180 | )
181 | self.l5 = nn.Linear(self.ch, 1, bias=False)
182 | if args.d_spectral_norm:
183 | self.l5 = nn.utils.spectral_norm(self.l5)
184 | self.cur_stage = 0
185 |
186 | def forward(self, x):
187 | h = x
188 | layers = [self.block1, self.block2, self.block3]
189 | variable_model = nn.Sequential(*layers[: (self.cur_stage + 1)])
190 | h = variable_model(h)
191 | h = self.block4(h)
192 | h = self.activation(h)
193 | # Global average pooling
194 | h = h.sum(2).sum(2)
195 | output = self.l5(h)
196 |
197 | return output
198 |
--------------------------------------------------------------------------------
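
set_arch above slices a flat integer vector into per-cell genes, which is exactly the format exps/derive.sh passes via --arch; a short illustration of the split:

arch = [1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 3]   # from exps/derive.sh
stage1, stage2, stage3 = arch[:4], arch[4:9], arch[9:]
# each stage holds (conv_id, norm_id, up_id, short_cut_id); stages 2 and 3
# carry a fifth entry encoding which earlier skip tensors to consume
print(stage1)   # [1, 0, 1, 1]      -> cell1
print(stage2)   # [1, 0, 0, 1, 1]   -> cell2
print(stage3)   # [1, 0, 1, 0, 3]   -> cell3
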
/AutoGAN/regan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Regan_training(nn.Module):
6 |     """Wrap a model with ReGAN-style sparse/dense training.
7 |
8 |     In sparse mode, the smallest `sparsity` fraction of each parameter
9 |     tensor (by magnitude) is zeroed out, together with its gradients.
10 |     """
11 |
12 |     def __init__(self, model, sparsity, train_on_sparse=False):
13 |         super(Regan_training, self).__init__()
14 |
15 |         self.model = model
16 |         self.sparsity = sparsity
17 |         self.train_on_sparse = train_on_sparse
18 |
19 |         # Keep a flat list of references to every parameter tensor.
20 |         self.layers = [w for _, w in self.model.named_parameters()]
21 |         self.masks = []
22 |         self.reset_masks()
23 |
24 |     def reset_masks(self):
25 |         # Placeholder all-True masks; update_masks() computes the real ones.
26 |         self.masks = [torch.ones_like(w, dtype=bool) for w in self.layers]
27 |         return self.masks
28 |
29 |     def update_masks(self):
30 |         # A mask entry is True where a weight falls below the per-tensor
31 |         # magnitude quantile, i.e. for the weights to be pruned.
32 |         for i, w in enumerate(self.layers):
33 |             q_w = torch.quantile(torch.abs(w), q=self.sparsity)
34 |             self.masks[i] = torch.abs(w) < q_w
35 |
36 |     def turn_training_mode(self, mode):
37 |         if mode == 'dense':
38 |             self.train_on_sparse = False
39 |         else:
40 |             self.train_on_sparse = True
41 |             self.update_masks()
42 |
43 |     def apply_masks(self):
44 |         # Zero pruned weights and their gradients; call after backward().
45 |         for w, mask_w in zip(self.layers, self.masks):
46 |             w.data[mask_w] = 0
47 |             w.grad.data[mask_w] = 0
48 |
49 |     def forward(self, x):
50 |         return self.model(x)
51 |
--------------------------------------------------------------------------------
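
train.py (further down) wraps the generator in Regan_training and drives it with cfg.py's --warmup_epoch, --g and --sparsity flags. Below is a minimal self-contained sketch of that schedule with a toy model and hypothetical hyper-parameter values; the exact switching rule is my reading of the API, not code from the repo:

import torch
import torch.nn as nn
from regan import Regan_training

warmup_epoch, g, sparsity, max_epoch = 2, 2, 0.3, 8   # stand-ins for the cfg.py flags

net = Regan_training(nn.Linear(8, 8), sparsity=sparsity)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

for epoch in range(max_epoch):
    # dense warm-up, then flip between sparse and dense every g epochs
    if epoch >= warmup_epoch and (epoch - warmup_epoch) % g == 0:
        net.turn_training_mode('dense' if net.train_on_sparse else 'sparse')
    loss = net(torch.randn(4, 8)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    if net.train_on_sparse:
        net.apply_masks()   # zero the pruned weights and their gradients
    opt.step()
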
/AutoGAN/search.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-09-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import os
10 |
11 | import torch
12 | import torch.nn as nn
13 | from tensorboardX import SummaryWriter
14 | from tqdm import tqdm
15 |
16 | import cfg
17 | import datasets
18 | import models_search # noqa
19 | from functions import get_topk_arch_hidden, train_controller, train_shared
20 | from utils.fid_score import check_or_download_inception, create_inception_graph
21 | from utils.inception_score import _init_inception
22 | from utils.utils import create_logger, RunningStats, save_checkpoint, set_log_dir
23 |
24 | torch.backends.cudnn.enabled = True
25 | torch.backends.cudnn.benchmark = True
26 |
27 |
28 | class GrowCtrler(object):
29 | def __init__(self, grow_step1, grow_step2):
30 | self.grow_step1 = grow_step1
31 | self.grow_step2 = grow_step2
32 |
33 | def cur_stage(self, search_iter):
34 | """
35 | Return the current stage for a given search iteration.
36 | :param search_iter: current search iteration.
37 | :return: current stage (0, 1 or 2)
38 | """
39 | if search_iter < self.grow_step1:
40 | return 0
41 | elif self.grow_step1 <= search_iter < self.grow_step2:
42 | return 1
43 | else:
44 | return 2
45 |
46 |
47 | def create_ctrler(args, cur_stage, weights_init):
48 | controller = eval("models_search." + args.controller + ".Controller")(
49 | args=args, cur_stage=cur_stage
50 | ).cuda()
51 | controller.apply(weights_init)
52 | ctrl_optimizer = torch.optim.Adam(
53 | filter(lambda p: p.requires_grad, controller.parameters()),
54 | args.ctrl_lr,
55 | (args.beta1, args.beta2),
56 | )
57 | return controller, ctrl_optimizer
58 |
59 |
60 | def create_shared_gan(args, weights_init):
61 | gen_net = eval("models_search." + args.gen_model + ".Generator")(args=args).cuda()
62 | dis_net = eval("models_search." + args.dis_model + ".Discriminator")(
63 | args=args
64 | ).cuda()
65 | gen_net.apply(weights_init)
66 | dis_net.apply(weights_init)
67 | gen_optimizer = torch.optim.Adam(
68 | filter(lambda p: p.requires_grad, gen_net.parameters()),
69 | args.g_lr,
70 | (args.beta1, args.beta2),
71 | )
72 | dis_optimizer = torch.optim.Adam(
73 | filter(lambda p: p.requires_grad, dis_net.parameters()),
74 | args.d_lr,
75 | (args.beta1, args.beta2),
76 | )
77 | return gen_net, dis_net, gen_optimizer, dis_optimizer
78 |
79 |
80 | def main():
81 | args = cfg.parse_args()
82 | torch.cuda.manual_seed(args.random_seed)
83 |
84 | # set tf env
85 | _init_inception()
86 | inception_path = check_or_download_inception(None)
87 | create_inception_graph(inception_path)
88 |
89 | # weight init
90 | def weights_init(m):
91 | classname = m.__class__.__name__
92 | if classname.find("Conv2d") != -1:
93 | if args.init_type == "normal":
94 | nn.init.normal_(m.weight.data, 0.0, 0.02)
95 | elif args.init_type == "orth":
96 | nn.init.orthogonal_(m.weight.data)
97 | elif args.init_type == "xavier_uniform":
98 | nn.init.xavier_uniform_(m.weight.data, 1.0)
99 | else:
100 | raise NotImplementedError(
101 | "{} unknown init type".format(args.init_type)
102 | )
103 | elif classname.find("BatchNorm2d") != -1:
104 | nn.init.normal_(m.weight.data, 1.0, 0.02)
105 | nn.init.constant_(m.bias.data, 0.0)
106 |
107 | gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
108 | args, weights_init
109 | )
110 |
111 | # set grow controller
112 | grow_ctrler = GrowCtrler(args.grow_step1, args.grow_step2)
113 |
114 | # initial
115 | start_search_iter = 0
116 |
117 | # set writer
118 | if args.load_path:
119 | print(f"=> resuming from {args.load_path}")
120 | assert os.path.exists(args.load_path)
121 | checkpoint_file = os.path.join(args.load_path, "Model", "checkpoint.pth")
122 | assert os.path.exists(checkpoint_file)
123 | checkpoint = torch.load(checkpoint_file)
124 | # set controller && its optimizer
125 | cur_stage = checkpoint["cur_stage"]
126 | controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)
127 |
128 | start_search_iter = checkpoint["search_iter"]
129 | gen_net.load_state_dict(checkpoint["gen_state_dict"])
130 | dis_net.load_state_dict(checkpoint["dis_state_dict"])
131 | controller.load_state_dict(checkpoint["ctrl_state_dict"])
132 | gen_optimizer.load_state_dict(checkpoint["gen_optimizer"])
133 | dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
134 | ctrl_optimizer.load_state_dict(checkpoint["ctrl_optimizer"])
135 | prev_archs = checkpoint["prev_archs"]
136 | prev_hiddens = checkpoint["prev_hiddens"]
137 |
138 | args.path_helper = checkpoint["path_helper"]
139 | logger = create_logger(args.path_helper["log_path"])
140 | logger.info(
141 | f"=> loaded checkpoint {checkpoint_file} (search iteration {start_search_iter})"
142 | )
143 | else:
144 | # create new log dir
145 | assert args.exp_name
146 | args.path_helper = set_log_dir("logs", args.exp_name)
147 | logger = create_logger(args.path_helper["log_path"])
148 | prev_archs = None
149 | prev_hiddens = None
150 |
151 | # set controller && its optimizer
152 | cur_stage = 0
153 | controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)
154 |
155 | # set up data_loader
156 | dataset = datasets.ImageDataset(args, 2 ** (cur_stage + 3))
157 | train_loader = dataset.train
158 |
159 | logger.info(args)
160 | writer_dict = {
161 | "writer": SummaryWriter(args.path_helper["log_path"]),
162 | "controller_steps": start_search_iter * args.ctrl_step,
163 | }
164 |
165 | g_loss_history = RunningStats(args.dynamic_reset_window)
166 | d_loss_history = RunningStats(args.dynamic_reset_window)
167 |
168 | # train loop
169 | for search_iter in tqdm(
170 | range(int(start_search_iter), int(args.max_search_iter)), desc="search progress"
171 | ):
172 | logger.info(f"")
173 | if search_iter == args.grow_step1 or search_iter == args.grow_step2:
174 |
175 | # save
176 | cur_stage = grow_ctrler.cur_stage(search_iter)
177 | logger.info(f"=> grow to stage {cur_stage}")
178 | prev_archs, prev_hiddens = get_topk_arch_hidden(
179 | args, controller, gen_net, prev_archs, prev_hiddens
180 | )
181 |
182 | # grow section
183 | del controller
184 | del ctrl_optimizer
185 | controller, ctrl_optimizer = create_ctrler(args, cur_stage, weights_init)
186 |
187 | dataset = datasets.ImageDataset(args, 2 ** (cur_stage + 3))
188 | train_loader = dataset.train
189 |
190 | dynamic_reset = train_shared(
191 | args,
192 | gen_net,
193 | dis_net,
194 | g_loss_history,
195 | d_loss_history,
196 | controller,
197 | gen_optimizer,
198 | dis_optimizer,
199 | train_loader,
200 | prev_hiddens=prev_hiddens,
201 | prev_archs=prev_archs,
202 | )
203 | train_controller(
204 | args,
205 | controller,
206 | ctrl_optimizer,
207 | gen_net,
208 | prev_hiddens,
209 | prev_archs,
210 | writer_dict,
211 | )
212 |
213 | if dynamic_reset:
214 | logger.info("re-initialize shared GAN")
215 | del gen_net, dis_net, gen_optimizer, dis_optimizer
216 | gen_net, dis_net, gen_optimizer, dis_optimizer = create_shared_gan(
217 | args, weights_init
218 | )
219 |
220 | save_checkpoint(
221 | {
222 | "cur_stage": cur_stage,
223 | "search_iter": search_iter + 1,
224 | "gen_model": args.gen_model,
225 | "dis_model": args.dis_model,
226 | "controller": args.controller,
227 | "gen_state_dict": gen_net.state_dict(),
228 | "dis_state_dict": dis_net.state_dict(),
229 | "ctrl_state_dict": controller.state_dict(),
230 | "gen_optimizer": gen_optimizer.state_dict(),
231 | "dis_optimizer": dis_optimizer.state_dict(),
232 | "ctrl_optimizer": ctrl_optimizer.state_dict(),
233 | "prev_archs": prev_archs,
234 | "prev_hiddens": prev_hiddens,
235 | "path_helper": args.path_helper,
236 | },
237 | False,
238 | args.path_helper["ckpt_path"],
239 | )
240 |
241 | final_archs, _ = get_topk_arch_hidden(
242 | args, controller, gen_net, prev_archs, prev_hiddens
243 | )
244 | logger.info(f"discovered archs: {final_archs}")
245 |
246 |
247 | if __name__ == "__main__":
248 | main()
249 |
--------------------------------------------------------------------------------
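
GrowCtrler's thresholds come straight from the CLI: with autogan_search.sh's --grow_step1 15 --grow_step2 35, the search runs at 8x8 until iteration 15, 16x16 until 35, then 32x32 (main() sizes the dataset as 2 ** (cur_stage + 3)). A tiny check of that mapping, with the class copied inline so the sketch runs standalone:

# GrowCtrler as defined above in search.py
class GrowCtrler:
    def __init__(self, grow_step1, grow_step2):
        self.grow_step1, self.grow_step2 = grow_step1, grow_step2

    def cur_stage(self, search_iter):
        if search_iter < self.grow_step1:
            return 0
        if self.grow_step1 <= search_iter < self.grow_step2:
            return 1
        return 2

ctrl = GrowCtrler(grow_step1=15, grow_step2=35)
for it in (0, 14, 15, 34, 35, 64):
    stage = ctrl.cur_stage(it)
    print(it, stage, 2 ** (stage + 3))   # stages 0/0/1/1/2/2 -> sizes 8/8/16/16/32/32
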
/AutoGAN/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import os
10 |
11 | import numpy as np
12 | import torch
13 | from tensorboardX import SummaryWriter
14 |
15 | import cfg
16 | import models # noqa
17 | from functions import validate
18 | from utils.fid_score import check_or_download_inception, create_inception_graph
19 | from utils.inception_score import _init_inception
20 | from utils.utils import create_logger, set_log_dir
21 |
22 | torch.backends.cudnn.enabled = True
23 | torch.backends.cudnn.benchmark = True
24 |
25 |
26 | def main():
27 | args = cfg.parse_args()
28 | torch.cuda.manual_seed(args.random_seed)
29 | assert args.exp_name
30 | assert args.load_path.endswith(".pth")
31 | assert os.path.exists(args.load_path)
32 | args.path_helper = set_log_dir("logs_eval", args.exp_name)
33 | logger = create_logger(args.path_helper["log_path"], phase="test")
34 |
35 | # set tf env
36 | _init_inception()
37 | inception_path = check_or_download_inception(None)
38 | create_inception_graph(inception_path)
39 |
40 | # import network
41 | gen_net = eval("models." + args.gen_model + ".Generator")(args=args).cuda()
42 |
43 | # fid stat
44 | if args.dataset.lower() == "cifar10":
45 | fid_stat = "fid_stat/fid_stats_cifar10_train.npz"
46 | elif args.dataset.lower() == "stl10":
47 | fid_stat = "fid_stat/stl10_train_unlabeled_fid_stats_48.npz"
48 | else:
49 | raise NotImplementedError(f"no fid stat for {args.dataset.lower()}")
50 | assert os.path.exists(fid_stat)
51 |
52 | # initial
53 | fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (25, args.latent_dim)))
54 |
55 | # set writer
56 | logger.info(f"=> resuming from {args.load_path}")
57 | checkpoint_file = args.load_path
58 | assert os.path.exists(checkpoint_file)
59 | checkpoint = torch.load(checkpoint_file)
60 |
61 | if "avg_gen_state_dict" in checkpoint:
62 | gen_net.load_state_dict(checkpoint["avg_gen_state_dict"])
63 | epoch = checkpoint["epoch"]
64 | logger.info(f"=> loaded checkpoint {checkpoint_file} (epoch {epoch})")
65 | else:
66 | gen_net.load_state_dict(checkpoint)
67 | logger.info(f"=> loaded checkpoint {checkpoint_file}")
68 |
69 | logger.info(args)
70 | writer_dict = {
71 | "writer": SummaryWriter(args.path_helper["log_path"]),
72 | "valid_global_steps": 0,
73 | }
74 | inception_score, fid_score = validate(
75 | args, fixed_z, fid_stat, gen_net, writer_dict, clean_dir=False
76 | )
77 | logger.info(f"Inception score: {inception_score}, FID score: {fid_score}.")
78 |
79 |
80 | if __name__ == "__main__":
81 | main()
82 |
--------------------------------------------------------------------------------
/AutoGAN/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 | # Modified by Jiahao Xu (jiahxu@polyu.edu.hk)
7 |
8 | from __future__ import absolute_import, division, print_function
9 |
10 | import os
11 | from copy import deepcopy
12 |
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | from tensorboardX import SummaryWriter
17 | from tqdm import tqdm
18 | from regan import Regan_training
19 |
20 | import cfg
21 | import datasets
22 | import models # noqa
23 | from functions import copy_params, LinearLrDecay, load_params, train, validate
24 | # from utils.fid_score import check_or_download_inception, create_inception_graph
25 | # from utils.inception_score import _init_inception
26 | from utils.utils import create_logger, save_checkpoint, set_log_dir
27 | import time
28 |
29 | torch.backends.cudnn.enabled = True
30 | torch.backends.cudnn.benchmark = True
31 |
32 |
33 | def main():
34 | args = cfg.parse_args()
35 | torch.cuda.manual_seed(args.random_seed)
36 |
37 | # # set tf env
38 | # _init_inception()
39 | # inception_path = check_or_download_inception(None)
40 | # create_inception_graph(inception_path)
41 |
42 | # import network
43 | gen_net = Regan_training(eval("models." + args.gen_model + ".Generator")(args=args).cuda(), sparsity=args.sparsity)
44 | dis_net = eval("models." + args.dis_model + ".Discriminator")(args=args).cuda()
45 |
46 | # weight init
47 | def weights_init(m):
48 | classname = m.__class__.__name__
49 | if classname.find("Conv2d") != -1:
50 | if args.init_type == "normal":
51 | nn.init.normal_(m.weight.data, 0.0, 0.02)
52 | elif args.init_type == "orth":
53 | nn.init.orthogonal_(m.weight.data)
54 | elif args.init_type == "xavier_uniform":
55 |                 nn.init.xavier_uniform_(m.weight.data, 1.0)
56 | else:
57 | raise NotImplementedError(
58 |                     "{} unknown initial type".format(args.init_type)
59 | )
60 | elif classname.find("BatchNorm2d") != -1:
61 | nn.init.normal_(m.weight.data, 1.0, 0.02)
62 | nn.init.constant_(m.bias.data, 0.0)
63 |
64 | gen_net.apply(weights_init)
65 | dis_net.apply(weights_init)
66 |
67 | # set optimizer
68 | gen_optimizer = torch.optim.Adam(
69 | filter(lambda p: p.requires_grad, gen_net.parameters()),
70 | args.g_lr,
71 | (args.beta1, args.beta2),
72 | )
73 | dis_optimizer = torch.optim.Adam(
74 | filter(lambda p: p.requires_grad, dis_net.parameters()),
75 | args.d_lr,
76 | (args.beta1, args.beta2),
77 | )
78 | gen_scheduler = LinearLrDecay(
79 | gen_optimizer, args.g_lr, 0.0, 0, args.max_iter * args.n_critic
80 | )
81 | dis_scheduler = LinearLrDecay(
82 | dis_optimizer, args.d_lr, 0.0, 0, args.max_iter * args.n_critic
83 | )
84 |
85 | # set up data_loader
86 | dataset = datasets.ImageDataset(args)
87 | train_loader = dataset.train
88 |
89 | # fid stat
90 | if args.dataset.lower() == "cifar10":
91 | fid_stat = "fid_stat/cifar10.test.npz"
92 | elif args.dataset.lower() == "stl10":
93 | fid_stat = "fid_stat/stl10_train_unlabeled_fid_stats_48.npz"
94 | else:
95 | raise NotImplementedError(f"no fid stat for {args.dataset.lower()}")
96 | # assert os.path.exists(fid_stat)
97 |
98 | # epoch number for dis_net
99 | # args.max_epoch = args.max_epoch * args.n_critic
100 | # if args.max_iter:
101 | # args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader))
102 |
103 | # initial
104 | fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (25, args.latent_dim)))
105 | gen_avg_param = copy_params(gen_net)
106 | start_epoch = 0
107 | best_fid = 1e4
108 |     train_start = time.time()
109 | assert args.exp_name
110 | args.path_helper = set_log_dir("logs", args.exp_name)
111 | logger = create_logger(args.path_helper["log_path"])
112 |
113 | logger.info(args)
114 | writer_dict = {
115 | "writer": SummaryWriter(args.path_helper["log_path"]),
116 | "train_global_steps": start_epoch * len(train_loader),
117 | "valid_global_steps": start_epoch // args.val_freq,
118 | }
119 |
120 | # train loop
121 | flag_g = 1
122 |     print(f"starting at epoch {start_epoch}")
123 |     pbar = range(1, 1 + args.max_epoch)
124 |     print(f"training for {len(pbar)} epochs")
125 | pbar = tqdm(pbar, initial=start_epoch, desc="total progress", dynamic_ncols=True, smoothing=0.01)
126 | for cur_epoch in pbar:
127 | epoch = cur_epoch + start_epoch
128 |
129 |         if epoch > args.max_epoch:
130 |             print("Done!")
131 |             break
132 | 
133 |
134 |         if args.regan:
135 |             # Warm-up phase: ReGAN training is not enabled yet
136 |             if epoch < args.warmup_epoch + 1:
137 |                 print('current phase: warm-up training')
138 |                 gen_net.train_on_sparse = False
139 | 
140 |             # Warm-up finished: enter the sparse training phase
141 |             elif epoch > args.warmup_epoch and flag_g < args.g + 1:
142 |                 print('epoch %d, current phase: sparse training' % epoch)
143 |                 # switch training mode to sparse and update the pruning masks
144 |                 gen_net.turn_training_mode(mode='sparse')
145 |                 # make sure the sparse phase uses the original learning rate
146 |                 if flag_g == 1:
147 |                     print('restore learning rate to the original value')
148 |                     for params in gen_optimizer.param_groups:
149 |                         params['lr'] = args.g_lr
150 |                 flag_g = flag_g + 1
151 | 
152 |             # Sparse phase finished: enter the dense training phase
153 |             elif epoch > args.warmup_epoch and flag_g < 2 * args.g + 1:
154 |                 print('epoch %d, current phase: dense training' % epoch)
155 |                 # switch training mode back to dense
156 |                 gen_net.turn_training_mode(mode='dense')
157 |                 # make sure the dense phase uses a learning rate 10 times smaller than the original
158 |                 if flag_g == args.g + 1:
159 |                     print('reduce learning rate to one tenth')
160 |                     for params in gen_optimizer.param_groups:
161 |                         params['lr'] = args.g_lr * 0.1
162 |                 flag_g = flag_g + 1
163 | 
164 |             # When the current sparse-dense pair finishes, move on to the next pair
165 |             if flag_g == 2 * args.g + 1:
166 |                 print('reset flag')
167 |                 flag_g = 1
168 |
169 | lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None
170 |
171 | train(
172 | args,
173 | gen_net,
174 | dis_net,
175 | gen_optimizer,
176 | dis_optimizer,
177 | gen_avg_param,
178 | train_loader,
179 | epoch,
180 | writer_dict,
181 | lr_schedulers,
182 | )
183 |
184 |
185 | if epoch % args.val_freq == 0 or epoch == args.max_epoch:
186 | backup_param = copy_params(gen_net)
187 | load_params(gen_net, gen_avg_param)
188 | inception_score, is_std, fid_score = validate(
189 | args, fixed_z, fid_stat, gen_net, writer_dict
190 | )
191 | logger.info('IS: %.4f (%.4f) || FID: %.4f || @ epoch %d.' % (inception_score, is_std, fid_score, epoch))
192 | load_params(gen_net, backup_param)
193 | if fid_score < best_fid:
194 | best_fid = fid_score
195 | is_best = True
196 | else:
197 | is_best = False
198 |             if epoch == int(args.max_epoch):
199 |                 logger.info(
200 |                     f"Total time cost: {time.time() - train_start:.1f}s.")
201 |
202 | else:
203 | is_best = False
204 |
205 | avg_gen_net = deepcopy(gen_net)
206 | load_params(avg_gen_net, gen_avg_param)
207 | if epoch % args.val_freq == 0 or epoch == args.max_epoch:
208 | save_checkpoint(
209 | {
210 | "epoch": epoch,
211 | "gen_model": args.gen_model,
212 | "dis_model": args.dis_model,
213 | "gen_state_dict": gen_net.state_dict(),
214 | "dis_state_dict": dis_net.state_dict(),
215 | "avg_gen_state_dict": avg_gen_net.state_dict(),
216 | "gen_optimizer": gen_optimizer.state_dict(),
217 | "dis_optimizer": dis_optimizer.state_dict(),
218 | "best_fid": best_fid,
219 | "path_helper": args.path_helper,
220 | "regan": args.regan,
221 | "g": args.g,
222 | "sparsity": args.sparsity
223 | },
224 | is_best,
225 | args.path_helper["ckpt_path"], epoch
226 | )
227 | del avg_gen_net
228 |
229 |
230 | if __name__ == "__main__":
231 | main()
232 |
233 |
--------------------------------------------------------------------------------
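
Note on the flag_g bookkeeping in AutoGAN/train.py above: it implements a fixed epoch schedule, a warm-up stretch followed by repeating pairs of g sparse epochs and g dense epochs. The same schedule as a pure function, a sketch for reference (regan_phase is our name, not a repo API):

def regan_phase(epoch, warmup_epoch, g):
    # map a 1-indexed epoch to 'warmup', 'sparse', or 'dense', mirroring
    # the flag_g state machine in the training loop above
    if epoch <= warmup_epoch:
        return 'warmup'
    offset = (epoch - warmup_epoch - 1) % (2 * g)
    return 'sparse' if offset < g else 'dense'

# e.g. warmup_epoch=1, g=2: epoch 1 warms up, 2-3 sparse, 4-5 dense, 6-7 sparse, ...
assert [regan_phase(e, 1, 2) for e in range(1, 8)] == \
    ['warmup', 'sparse', 'sparse', 'dense', 'dense', 'sparse', 'sparse']
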
/AutoGAN/train_derived.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-10-01
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | import os
10 | from copy import deepcopy
11 |
12 | import numpy as np
13 | import torch
14 | import torch.nn as nn
15 | from tensorboardX import SummaryWriter
16 | from tqdm import tqdm
17 |
18 | import cfg
19 | import datasets
20 | import models_search # noqa
21 | from functions import copy_params, LinearLrDecay, load_params, train, validate
22 | from utils.fid_score import check_or_download_inception, create_inception_graph
23 | from utils.inception_score import _init_inception
24 | from utils.utils import create_logger, save_checkpoint, set_log_dir
25 |
26 | torch.backends.cudnn.enabled = True
27 | torch.backends.cudnn.benchmark = True
28 |
29 |
30 | def main():
31 | args = cfg.parse_args()
32 | torch.cuda.manual_seed(args.random_seed)
33 |
34 | # set tf env
35 | _init_inception()
36 | inception_path = check_or_download_inception(None)
37 | create_inception_graph(inception_path)
38 |
39 | # import network
40 | gen_net = eval("models_search." + args.gen_model + ".Generator")(args=args).cuda()
41 | dis_net = eval("models_search." + args.dis_model + ".Discriminator")(
42 | args=args
43 | ).cuda()
44 |
45 | gen_net.set_arch(args.arch, cur_stage=2)
46 | dis_net.cur_stage = 2
47 |
48 | # weight init
49 | def weights_init(m):
50 | classname = m.__class__.__name__
51 | if classname.find("Conv2d") != -1:
52 | if args.init_type == "normal":
53 | nn.init.normal_(m.weight.data, 0.0, 0.02)
54 | elif args.init_type == "orth":
55 | nn.init.orthogonal_(m.weight.data)
56 | elif args.init_type == "xavier_uniform":
57 |                 nn.init.xavier_uniform_(m.weight.data, 1.0)
58 | else:
59 | raise NotImplementedError(
60 |                     "{} unknown initial type".format(args.init_type)
61 | )
62 | elif classname.find("BatchNorm2d") != -1:
63 | nn.init.normal_(m.weight.data, 1.0, 0.02)
64 | nn.init.constant_(m.bias.data, 0.0)
65 |
66 | gen_net.apply(weights_init)
67 | dis_net.apply(weights_init)
68 |
69 | # set optimizer
70 | gen_optimizer = torch.optim.Adam(
71 | filter(lambda p: p.requires_grad, gen_net.parameters()),
72 | args.g_lr,
73 | (args.beta1, args.beta2),
74 | )
75 | dis_optimizer = torch.optim.Adam(
76 | filter(lambda p: p.requires_grad, dis_net.parameters()),
77 | args.d_lr,
78 | (args.beta1, args.beta2),
79 | )
80 | gen_scheduler = LinearLrDecay(
81 | gen_optimizer, args.g_lr, 0.0, 0, args.max_iter * args.n_critic
82 | )
83 | dis_scheduler = LinearLrDecay(
84 | dis_optimizer, args.d_lr, 0.0, 0, args.max_iter * args.n_critic
85 | )
86 |
87 | # set up data_loader
88 | dataset = datasets.ImageDataset(args)
89 | train_loader = dataset.train
90 |
91 | # fid stat
92 | if args.dataset.lower() == "cifar10":
93 | fid_stat = "fid_stat/fid_stats_cifar10_train.npz"
94 | elif args.dataset.lower() == "stl10":
95 | fid_stat = "fid_stat/stl10_train_unlabeled_fid_stats_48.npz"
96 | else:
97 | raise NotImplementedError(f"no fid stat for {args.dataset.lower()}")
98 | assert os.path.exists(fid_stat)
99 |
100 | # epoch number for dis_net
101 | args.max_epoch = args.max_epoch * args.n_critic
102 | if args.max_iter:
103 | args.max_epoch = np.ceil(args.max_iter * args.n_critic / len(train_loader))
104 |
105 | # initial
106 | fixed_z = torch.cuda.FloatTensor(np.random.normal(0, 1, (25, args.latent_dim)))
107 | gen_avg_param = copy_params(gen_net)
108 | start_epoch = 0
109 | best_fid = 1e4
110 |
111 | # set writer
112 | if args.load_path:
113 | print(f"=> resuming from {args.load_path}")
114 | assert os.path.exists(args.load_path)
115 | checkpoint_file = os.path.join(args.load_path, "Model", "checkpoint.pth")
116 | assert os.path.exists(checkpoint_file)
117 | checkpoint = torch.load(checkpoint_file)
118 | start_epoch = checkpoint["epoch"]
119 | best_fid = checkpoint["best_fid"]
120 | gen_net.load_state_dict(checkpoint["gen_state_dict"])
121 | dis_net.load_state_dict(checkpoint["dis_state_dict"])
122 | gen_optimizer.load_state_dict(checkpoint["gen_optimizer"])
123 | dis_optimizer.load_state_dict(checkpoint["dis_optimizer"])
124 | avg_gen_net = deepcopy(gen_net)
125 | avg_gen_net.load_state_dict(checkpoint["avg_gen_state_dict"])
126 | gen_avg_param = copy_params(avg_gen_net)
127 | del avg_gen_net
128 |
129 | args.path_helper = checkpoint["path_helper"]
130 | logger = create_logger(args.path_helper["log_path"])
131 | logger.info(f"=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})")
132 | else:
133 | # create new log dir
134 | assert args.exp_name
135 | args.path_helper = set_log_dir("logs", args.exp_name)
136 | logger = create_logger(args.path_helper["log_path"])
137 |
138 | logger.info(args)
139 | writer_dict = {
140 | "writer": SummaryWriter(args.path_helper["log_path"]),
141 | "train_global_steps": start_epoch * len(train_loader),
142 | "valid_global_steps": start_epoch // args.val_freq,
143 | }
144 |
145 | # train loop
146 | for epoch in tqdm(
147 | range(int(start_epoch), int(args.max_epoch)), desc="total progress"
148 | ):
149 | lr_schedulers = (gen_scheduler, dis_scheduler) if args.lr_decay else None
150 | train(
151 | args,
152 | gen_net,
153 | dis_net,
154 | gen_optimizer,
155 | dis_optimizer,
156 | gen_avg_param,
157 | train_loader,
158 | epoch,
159 | writer_dict,
160 | lr_schedulers,
161 | )
162 |
163 | if epoch and epoch % args.val_freq == 0 or epoch == int(args.max_epoch) - 1:
164 | backup_param = copy_params(gen_net)
165 | load_params(gen_net, gen_avg_param)
166 | inception_score, fid_score = validate(
167 | args, fixed_z, fid_stat, gen_net, writer_dict
168 | )
169 | logger.info(
170 | f"Inception score: {inception_score}, FID score: {fid_score} || @ epoch {epoch}."
171 | )
172 | load_params(gen_net, backup_param)
173 | if fid_score < best_fid:
174 | best_fid = fid_score
175 | is_best = True
176 | else:
177 | is_best = False
178 | else:
179 | is_best = False
180 |
181 | avg_gen_net = deepcopy(gen_net)
182 | load_params(avg_gen_net, gen_avg_param)
183 | save_checkpoint(
184 | {
185 | "epoch": epoch + 1,
186 | "gen_model": args.gen_model,
187 | "dis_model": args.dis_model,
188 | "gen_state_dict": gen_net.state_dict(),
189 | "dis_state_dict": dis_net.state_dict(),
190 | "avg_gen_state_dict": avg_gen_net.state_dict(),
191 | "gen_optimizer": gen_optimizer.state_dict(),
192 | "dis_optimizer": dis_optimizer.state_dict(),
193 | "best_fid": best_fid,
194 | "path_helper": args.path_helper,
195 | },
196 | is_best,
197 | args.path_helper["ckpt_path"],
198 | )
199 | del avg_gen_net
200 |
201 |
202 | if __name__ == "__main__":
203 | main()
204 |
--------------------------------------------------------------------------------
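
Note on train_derived.py above: like train.py, it keeps an averaged copy of the generator (gen_avg_param) and swaps it in for validation and checkpointing via copy_params/load_params. A sketch of the underlying moving-average step, assuming the usual EMA form (the beta=0.999 coefficient is our assumption; the actual update lives in functions.py):

import torch

def update_ema(avg_model, model, beta=0.999):
    # one exponential-moving-average step over generator parameters
    with torch.no_grad():
        for p_avg, p in zip(avg_model.parameters(), model.parameters()):
            p_avg.mul_(beta).add_(p, alpha=1 - beta)
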
/AutoGAN/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | from __future__ import absolute_import, division, print_function
8 |
9 | from utils import utils
10 |
--------------------------------------------------------------------------------
/AutoGAN/utils/cal_fid_stat.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-26
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 |
8 | import argparse
9 | import glob
10 | import os
11 |
12 | import numpy as np
13 | import tensorflow as tf
14 | from imageio import imread
15 |
16 | import utils.fid_score as fid
17 |
18 |
19 | def parse_args():
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument(
22 | "--data_path",
23 | type=str,
24 | required=True,
25 | help="set path to training set jpg images dir",
26 | )
27 | parser.add_argument(
28 | "--output_file",
29 | type=str,
30 | default="fid_stat/fid_stats_cifar10_train.npz",
31 | help="path for where to store the statistics",
32 | )
33 |
34 | opt = parser.parse_args()
35 | print(opt)
36 | return opt
37 |
38 |
39 | def main():
40 | args = parse_args()
41 |
42 | ########
43 | # PATHS
44 | ########
45 | data_path = args.data_path
46 | output_path = args.output_file
47 | # if you have downloaded and extracted
48 | # http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
49 | # set this path to the directory where the extracted files are, otherwise
50 | # just set it to None and the script will later download the files for you
51 | inception_path = None
52 | print("check for inception model..", end=" ", flush=True)
53 | inception_path = fid.check_or_download_inception(
54 | inception_path
55 | ) # download inception if necessary
56 | print("ok")
57 |
58 | # loads all images into memory (this might require a lot of RAM!)
59 | print("load images..", end=" ", flush=True)
60 | image_list = glob.glob(os.path.join(data_path, "*.jpg"))
61 | images = np.array([imread(str(fn)).astype(np.float32) for fn in image_list])
62 | print("%d images found and loaded" % len(images))
63 |
64 | print("create inception graph..", end=" ", flush=True)
65 | fid.create_inception_graph(
66 | inception_path
67 | ) # load the graph into the current TF graph
68 | print("ok")
69 |
70 |     print("calculate FID stats..", end=" ", flush=True)
71 | config = tf.ConfigProto()
72 | config.gpu_options.allow_growth = True
73 | with tf.Session(config=config) as sess:
74 | sess.run(tf.global_variables_initializer())
75 | mu, sigma = fid.calculate_activation_statistics(images, sess, batch_size=100)
76 | np.savez_compressed(output_path, mu=mu, sigma=sigma)
77 | print("finished")
78 |
79 |
80 | if __name__ == "__main__":
81 | main()
82 |
--------------------------------------------------------------------------------
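
Note on cal_fid_stat.py above: the saved (mu, sigma) pair feeds the FID computation in utils/fid_score.py. For reference, the Fréchet distance between the two activation Gaussians is ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2)); a minimal numpy/scipy sketch, not the repo's implementation:

import numpy as np
from scipy import linalg

def frechet_distance(mu1, sigma1, mu2, sigma2):
    # FID between two Gaussians fitted to Inception activations
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from sqrtm
    return diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)

# usage with the stats produced by this script (names are illustrative):
# ref = np.load("fid_stat/fid_stats_cifar10_train.npz")
# fid = frechet_distance(mu_gen, sigma_gen, ref["mu"], ref["sigma"])
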
/AutoGAN/utils/inception_score.py:
--------------------------------------------------------------------------------
1 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
2 | from __future__ import absolute_import, division, print_function
3 |
4 | import math
5 | import os
6 | import os.path
7 | import sys
8 | import tarfile
9 |
10 | import numpy as np
11 | import tensorflow as tf
12 | from six.moves import urllib
13 | from tqdm import tqdm
14 |
15 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
16 | MODEL_DIR = "/tmp/imagenet"
17 | DATA_URL = (
18 | "http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz"
19 | )
20 | softmax = None
21 |
22 | config = tf.ConfigProto()
23 | config.gpu_options.allow_growth = True
24 |
25 |
26 | # Call this function with a list of images. Each element should be a
27 | # numpy array with values ranging from 0 to 255.
28 | def get_inception_score(images, splits=10):
29 | assert type(images) == list
30 | assert type(images[0]) == np.ndarray
31 | assert len(images[0].shape) == 3
32 | assert np.max(images[0]) > 10
33 | assert np.min(images[0]) >= 0.0
34 | inps = []
35 | for img in images:
36 | img = img.astype(np.float32)
37 | inps.append(np.expand_dims(img, 0))
38 | bs = 100
39 | with tf.Session(config=config) as sess:
40 | preds = []
41 | n_batches = int(math.ceil(float(len(inps)) / float(bs)))
42 | for i in tqdm(range(n_batches), desc="Calculate inception score"):
43 | sys.stdout.flush()
44 | inp = inps[(i * bs) : min((i + 1) * bs, len(inps))]
45 | inp = np.concatenate(inp, 0)
46 | pred = sess.run(softmax, {"ExpandDims:0": inp})
47 | preds.append(pred)
48 | preds = np.concatenate(preds, 0)
49 | scores = []
50 | for i in range(splits):
51 | part = preds[
52 | (i * preds.shape[0] // splits) : ((i + 1) * preds.shape[0] // splits), :
53 | ]
54 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
55 | kl = np.mean(np.sum(kl, 1))
56 | scores.append(np.exp(kl))
57 |
58 | sess.close()
59 | return np.mean(scores), np.std(scores)
60 |
61 |
62 | # This function is called automatically.
63 | def _init_inception():
64 | global softmax
65 | if not os.path.exists(MODEL_DIR):
66 | os.makedirs(MODEL_DIR)
67 | filename = DATA_URL.split("/")[-1]
68 | filepath = os.path.join(MODEL_DIR, filename)
69 | if not os.path.exists(filepath):
70 |
71 | def _progress(count, block_size, total_size):
72 | sys.stdout.write(
73 | "\r>> Downloading %s %.1f%%"
74 | % (filename, float(count * block_size) / float(total_size) * 100.0)
75 | )
76 | sys.stdout.flush()
77 |
78 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
79 | print()
80 | statinfo = os.stat(filepath)
81 |         print("Successfully downloaded", filename, statinfo.st_size, "bytes.")
82 | tarfile.open(filepath, "r:gz").extractall(MODEL_DIR)
83 | with tf.gfile.FastGFile(
84 | os.path.join(MODEL_DIR, "classify_image_graph_def.pb"), "rb"
85 | ) as f:
86 | graph_def = tf.GraphDef()
87 | graph_def.ParseFromString(f.read())
88 | _ = tf.import_graph_def(graph_def, name="")
89 | # Works with an arbitrary minibatch size.
90 | with tf.Session(config=config) as sess:
91 | pool3 = sess.graph.get_tensor_by_name("pool_3:0")
92 | ops = pool3.graph.get_operations()
93 | for op_idx, op in enumerate(ops):
94 | for o in op.outputs:
95 | shape = o.get_shape()
96 | if shape._dims != []:
97 | shape = [s.value for s in shape]
98 | new_shape = []
99 | for j, s in enumerate(shape):
100 | if s == 1 and j == 0:
101 | new_shape.append(None)
102 | else:
103 | new_shape.append(s)
104 | o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
105 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
106 | logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
107 | softmax = tf.nn.softmax(logits)
108 | sess.close()
109 |
--------------------------------------------------------------------------------
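
Note on get_inception_score above: it computes IS = exp(E_x[KL(p(y|x) || p(y))]) per split and reports the mean and std over splits. The same arithmetic from a precomputed softmax array, as a standalone sketch (np.array_split handles remainders slightly differently from the contiguous slicing above):

import numpy as np

def inception_score_from_preds(preds, splits=10):
    # preds: (N, num_classes) softmax outputs
    scores = []
    for part in np.array_split(preds, splits):
        p_y = np.mean(part, axis=0, keepdims=True)  # marginal p(y)
        kl = part * (np.log(part) - np.log(p_y))    # per-sample KL terms
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
    return np.mean(scores), np.std(scores)
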
/AutoGAN/utils/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Date : 2019-07-25
3 | # @Author : Xinyu Gong (xy_gong@tamu.edu)
4 | # @Link : None
5 | # @Version : 0.0
6 |
7 | import collections
8 | import logging
9 | import math
10 | import os
11 | import time
12 | from datetime import datetime
13 |
14 | import dateutil.tz
15 | import torch
16 |
17 |
18 | def create_logger(log_dir, phase="train"):
19 | time_str = time.strftime("%Y-%m-%d-%H-%M")
20 | log_file = "{}_{}.log".format(time_str, phase)
21 | final_log_file = os.path.join(log_dir, log_file)
22 | head = "%(asctime)-15s %(message)s"
23 | logging.basicConfig(filename=str(final_log_file), format=head)
24 | logger = logging.getLogger()
25 | logger.setLevel(logging.INFO)
26 | console = logging.StreamHandler()
27 | logging.getLogger("").addHandler(console)
28 |
29 | return logger
30 |
31 |
32 | def set_log_dir(root_dir, exp_name):
33 | path_dict = {}
34 | os.makedirs(root_dir, exist_ok=True)
35 |
36 | # set log path
37 | exp_path = os.path.join(root_dir, exp_name)
38 | now = datetime.now(dateutil.tz.tzlocal())
39 | timestamp = now.strftime("%Y_%m_%d_%H_%M_%S")
40 | prefix = exp_path + "_" + timestamp
41 | os.makedirs(prefix)
42 | path_dict["prefix"] = prefix
43 |
44 | # set checkpoint path
45 | ckpt_path = os.path.join(prefix, "Model")
46 | os.makedirs(ckpt_path)
47 | path_dict["ckpt_path"] = ckpt_path
48 |
49 | log_path = os.path.join(prefix, "Log")
50 | os.makedirs(log_path)
51 | path_dict["log_path"] = log_path
52 |
53 | # set sample image path for fid calculation
54 | sample_path = os.path.join(prefix, "Samples")
55 | os.makedirs(sample_path)
56 | path_dict["sample_path"] = sample_path
57 |
58 | return path_dict
59 |
60 |
61 | def save_checkpoint(states, is_best, output_dir, epoch):
62 | filename = "checkpoint_%d.pth" % epoch
63 | torch.save(states, os.path.join(output_dir, filename))
64 | if is_best:
65 | torch.save(states, os.path.join(output_dir, "checkpoint_best.pth"))
66 |
67 |
68 | class RunningStats:
69 | def __init__(self, WIN_SIZE):
70 | self.mean = 0
71 | self.run_var = 0
72 | self.WIN_SIZE = WIN_SIZE
73 |
74 | self.window = collections.deque(maxlen=WIN_SIZE)
75 |
76 | def clear(self):
77 | self.window.clear()
78 | self.mean = 0
79 | self.run_var = 0
80 |
81 | def is_full(self):
82 | return len(self.window) == self.WIN_SIZE
83 |
84 | def push(self, x):
85 |
86 | if len(self.window) == self.WIN_SIZE:
87 | # Adjusting variance
88 | x_removed = self.window.popleft()
89 | self.window.append(x)
90 | old_m = self.mean
91 | self.mean += (x - x_removed) / self.WIN_SIZE
92 | self.run_var += (x + x_removed - old_m - self.mean) * (x - x_removed)
93 | else:
94 | # Calculating first variance
95 | self.window.append(x)
96 | delta = x - self.mean
97 | self.mean += delta / len(self.window)
98 | self.run_var += delta * (x - self.mean)
99 |
100 | def get_mean(self):
101 | return self.mean if len(self.window) else 0.0
102 |
103 | def get_var(self):
104 | return self.run_var / len(self.window) if len(self.window) > 1 else 0.0
105 |
106 | def get_std(self):
107 | return math.sqrt(self.get_var())
108 |
109 | def get_all(self):
110 | return list(self.window)
111 |
112 | def __str__(self):
113 | return "Current window values: {}".format(list(self.window))
114 |
--------------------------------------------------------------------------------
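
Note on RunningStats above: push() maintains a Welford-style running sum of squared deviations over a fixed window, and get_var() divides by the window length (population-style variance). A quick usage sketch:

from utils.utils import RunningStats

rs = RunningStats(3)
for x in [1.0, 2.0, 3.0, 4.0]:
    rs.push(x)
print(rs.get_mean())  # 3.0 -- mean of the current window (2, 3, 4)
print(rs.get_std())   # ~0.816 -- sqrt(2/3), population variance of the window
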
/ProGAN/main.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from model import Discriminator, Generator
3 | import argparse
4 | from utils import get_loader, gradient_penalty
5 | from regan import Regan_training
6 | import torch
7 |
8 | def train(epoch, num_epochs, critic, gen, loader, step, opt_critic, opt_gen):
9 | for batch_idx, (real, _) in enumerate(loader, 0):
10 |
11 |         alpha = min(1.0, (batch_idx + epoch * len(loader)) / (num_epochs * len(loader)))
12 |
13 | real = real.to(device)
14 | cur_batch_size = real.shape[0]
15 |
16 | # Train Critic: max E[critic(real)] - E[critic(fake)] <-> min -E[critic(real)] + E[critic(fake)]
17 |         # i.e. minimize the negated objective, plus gradient penalty and a small drift term
18 | noise = torch.randn(cur_batch_size, args.z_dim, 1, 1).to(device)
19 |
20 | fake = gen(noise, alpha, step)
21 | critic_real = critic(real, alpha, step)
22 | critic_fake = critic(fake.detach(), alpha, step)
23 | gp = gradient_penalty(critic, real, fake, alpha, step, device=device)
24 | loss_critic = (
25 | -(torch.mean(critic_real) - torch.mean(critic_fake))
26 | + args.lambda_gp * gp
27 | + (0.001 * torch.mean(critic_real ** 2))
28 | )
29 |
30 | opt_critic.zero_grad()
31 | loss_critic.backward()
32 | opt_critic.step()
33 |
34 | # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
35 | gen_fake = critic(fake, alpha, step)
36 | loss_gen = -torch.mean(gen_fake)
37 |
38 | opt_gen.zero_grad()
39 | loss_gen.backward()
40 | # scaler_gen.step(opt_gen)
41 |
42 | if args.regan and gen.train_on_sparse:
43 | gen.apply_masks()
44 |
45 | opt_gen.step()
46 |
47 |         alpha = min(alpha, 1)  # redundant after the clamp above, kept defensively
48 |
49 | # Output training stats
50 | if batch_idx % 50 == 0:
51 | print('[%d][%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f\tAlpha: %.2f'
52 | % (4 * 2 ** step, epoch, num_epochs, batch_idx, len(loader), loss_critic.item(),
53 | loss_gen.item(), critic_real.mean().item(), critic_fake.mean().item(),
54 | gen_fake.mean().item(), alpha))
55 |
56 | return alpha
57 |
58 |
59 | def main():
60 | gen = Regan_training(Generator(args.z_dim, args.in_channels).to(device), sparsity=args.sparsity)
61 | critic = Discriminator(args.in_channels).to(device)
62 |
63 | opt_gen = optim.Adam(gen.parameters(), lr=args.lr, betas=(0.0, 0.99))
64 | opt_critic = optim.Adam(critic.parameters(), lr=args.lr, betas=(0.0, 0.99))
65 |
66 | step = 0
67 |
68 | flag_g = 1
69 | image_size = 4 * 2 ** step
70 | for num_epochs in progressive_epochs[step:]:
71 | print(f"Current image size: {image_size}")
72 | # 4->0, 8->1, 16->2, 32->3, 64 -> 4
73 |
74 | loader, dataset = get_loader(image_size, args.workers, batch_size, '../dataset', args.data_ratio)
75 |
76 | if args.epoch != -1:
77 | num_epochs = args.epoch
78 | for epoch in range(1, 1 + num_epochs):
79 | print(f"Epoch [{epoch}/{num_epochs}]")
80 |
81 |             if args.regan and step == len(batch_size) - 1:
82 |                 # Warm-up phase: ReGAN training is not enabled yet
83 |                 if epoch < args.warmup_epoch + 1:
84 |                     print('current phase: warm-up training')
85 |                     gen.train_on_sparse = False
86 | 
87 |                 # Warm-up finished: enter the sparse training phase
88 |                 elif epoch > args.warmup_epoch and flag_g < args.g + 1:
89 |                     print('epoch %d, current phase: sparse training' % epoch)
90 |                     # switch training mode to sparse and update the pruning masks
91 |                     gen.turn_training_mode(mode='sparse')
92 |                     # make sure the sparse phase uses the original learning rate
93 |                     if flag_g == 1:
94 |                         print('restore learning rate to the original value')
95 |                         for params in opt_gen.param_groups:
96 |                             params['lr'] = args.lr
97 |                     flag_g = flag_g + 1
98 | 
99 |                 # Sparse phase finished: enter the dense training phase
100 |                 elif epoch > args.warmup_epoch and flag_g < 2 * args.g + 1:
101 |                     print('epoch %d, current phase: dense training' % epoch)
102 |                     # switch training mode back to dense
103 |                     gen.turn_training_mode(mode='dense')
104 |                     # make sure the dense phase uses a learning rate 10 times smaller than the original
105 |                     if flag_g == args.g + 1:
106 |                         print('reduce learning rate to one tenth')
107 |                         for params in opt_gen.param_groups:
108 |                             params['lr'] = args.lr * 0.1
109 |                     flag_g = flag_g + 1
110 | 
111 |                 # When the current sparse-dense pair finishes, move on to the next pair
112 |                 if flag_g == 2 * args.g + 1:
113 |                     print('reset flag')
114 |                     flag_g = 1
115 |
116 | alpha = train(epoch, num_epochs, critic, gen, loader, step, opt_critic, opt_gen)
117 |
118 | step += 1 # progress to the next img size
119 | image_size = 4 * 2 ** step
120 |
121 |
122 | if __name__ == "__main__":
123 | argparser = argparse.ArgumentParser()
124 | argparser.add_argument('--load', type=bool, default=False)
125 | argparser.add_argument('--load_epoch', type=int, default=0)
126 | argparser.add_argument('--epoch', type=int, default=-1)
127 | argparser.add_argument('--z_dim', type=int, default=512)
128 | argparser.add_argument('--in_channels', type=int, default=512)
129 | argparser.add_argument('--lr', type=float, default=1e-3)
130 | argparser.add_argument('--workers', type=int, default=4)
131 | argparser.add_argument('--lambda_gp', type=int, default=10)
132 | argparser.add_argument('--data_ratio', type=float, default=1.0)
133 | argparser.add_argument('--regan', action="store_true")
134 | argparser.add_argument('--sparsity', type=float, default=0.3)
135 | argparser.add_argument('--g', type=int, default=2)
136 | argparser.add_argument('--warmup_epoch', type=int, default=1)
137 |
138 |
139 | args = argparser.parse_args()
140 |
141 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
142 | start_training_at = 4
143 | batch_size = [128, 128, 128, 64]
144 | progressive_epochs = [2, 2, 2, 8]
145 |
146 | main()
147 |
--------------------------------------------------------------------------------
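
Note on ProGAN/main.py above: the fade-in coefficient alpha ramps linearly over the batches of the current resolution stage and is clamped at 1 (with the 1-indexed epoch loop it starts at 1/num_epochs rather than at 0). As a small reference function (our name, not a repo API):

def fade_in_alpha(epoch, batch_idx, num_epochs, batches_per_epoch):
    # linear ramp over the stage's batches, clamped at 1
    progress = (batch_idx + epoch * batches_per_epoch) / (num_epochs * batches_per_epoch)
    return min(1.0, progress)

# halfway through a 2-epoch stage of 100 batches:
# fade_in_alpha(epoch=1, batch_idx=0, num_epochs=2, batches_per_epoch=100) -> 0.5
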
/ProGAN/regan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Regan_training(nn.Module):
6 |
7 | def __init__(self, model, sparsity, train_on_sparse=False):
8 | super(Regan_training, self).__init__()
9 |
10 | self.model = model
11 | self.sparsity = sparsity
12 | self.train_on_sparse = train_on_sparse
13 | self.layers = []
14 | self.masks = []
15 |
16 | layers = list(self.model.named_parameters())
17 |
18 | for i in range(0, len(layers)):
19 | w = layers[i]
20 | self.layers.append(w[1])
21 |
22 | self.reset_masks()
23 |
24 | def reset_masks(self):
25 |
26 | for w in self.layers:
27 | mask_w = torch.ones_like(w, dtype=bool)
28 | self.masks.append(mask_w)
29 |
30 | return self.masks
31 |
32 | def update_masks(self):
33 |
34 | for i, w in enumerate(self.layers):
35 | q_w = torch.quantile(torch.abs(w), q=self.sparsity)
36 |             mask_w = torch.abs(w) < q_w  # True marks the weights to prune
37 |
38 | self.masks[i] = mask_w
39 |
40 | def turn_training_mode(self, mode):
41 | if mode == 'dense':
42 | self.train_on_sparse = False
43 | else:
44 | self.train_on_sparse = True
45 | self.update_masks()
46 |
47 | def apply_masks(self):
48 | for w, mask_w in zip(self.layers, self.masks):
49 | w.data[mask_w] = 0
50 | w.grad.data[mask_w] = 0
51 |
52 | def forward(self, x, alpha, step):
53 | return self.model(x, alpha, step)
54 |
--------------------------------------------------------------------------------
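
Note on Regan_training above: update_masks() marks the `sparsity` fraction of smallest-magnitude entries in every parameter tensor, and apply_masks() zeroes those entries and their gradients so pruned weights stay at zero during sparse training. A toy usage sketch (nn.Linear stands in for the generator; the real wrapper is called as net(x, alpha, step)):

import torch
import torch.nn as nn
from regan import Regan_training

net = Regan_training(nn.Linear(8, 8), sparsity=0.3)

net.turn_training_mode(mode='sparse')   # recompute masks from the current weights
out = net.model(torch.randn(4, 8))      # bypass the (x, alpha, step) signature here
out.sum().backward()
if net.train_on_sparse:
    net.apply_masks()                   # zero the ~30% smallest weights and their grads
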
/ProGAN/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torchvision.datasets as datasets
4 | import torchvision.transforms as transforms
5 | from torch.utils.data import DataLoader
6 | from math import log2
7 |
8 |
9 | def get_loader(image_size, workers, bs, dataroot, data_ratio):
10 | transform = transforms.Compose(
11 | [
12 | transforms.Resize((image_size, image_size)),
13 | transforms.ToTensor(),
14 | # transforms.RandomHorizontalFlip(p=0.5),
15 | transforms.Normalize(
16 | [0.5 for _ in range(3)],
17 | [0.5 for _ in range(3)],
18 | ),
19 | ]
20 | )
21 | batch_size = bs[int(log2(image_size / 4))]
22 |     dataset = datasets.CIFAR10(root=dataroot, transform=transform, download=True)
23 | subset = torch.utils.data.Subset(dataset, np.arange(int(len(dataset) * data_ratio)))
24 | loader = DataLoader(
25 | subset,
26 | batch_size=batch_size,
27 | shuffle=True,
28 | num_workers=workers,
29 | pin_memory=True,
30 | )
31 | return loader, dataset
32 |
33 |
34 | def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
35 | BATCH_SIZE, C, H, W = real.shape
36 | beta = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
37 | interpolated_images = real * beta + fake.detach() * (1 - beta)
38 | interpolated_images.requires_grad_(True)
39 |
40 | # Calculate critic scores
41 | mixed_scores = critic(interpolated_images, alpha, train_step)
42 |
43 | # Take the gradient of the scores with respect to the images
44 | gradient = torch.autograd.grad(
45 | inputs=interpolated_images,
46 | outputs=mixed_scores,
47 | grad_outputs=torch.ones_like(mixed_scores),
48 | create_graph=True,
49 | retain_graph=True,
50 | )[0]
51 | gradient = gradient.view(gradient.shape[0], -1)
52 | gradient_norm = gradient.norm(2, dim=1)
53 | gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
54 | return gradient_penalty
55 |
56 |
57 |
--------------------------------------------------------------------------------
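
Note on gradient_penalty above: it enters the critic objective in main.py as the WGAN-GP term, alongside a small drift penalty on real scores. The assembled loss, as a sketch (the helper name is ours):

import torch

def critic_loss(critic_real, critic_fake, gp, lambda_gp=10):
    # WGAN-GP critic objective as assembled in ProGAN/main.py
    wasserstein = torch.mean(critic_real) - torch.mean(critic_fake)
    drift = 0.001 * torch.mean(critic_real ** 2)  # keeps real scores bounded
    return -wasserstein + lambda_gp * gp + drift
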
/SNGAN/main.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import argparse
3 | import os
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import torch.utils.data
8 | import torchvision.datasets as dset
9 | import torchvision.transforms as transforms
10 | from model import ResDiscriminator32, ResGenerator32
11 | from regan import Regan_training
12 | import numpy as np
13 | import warnings
14 | warnings.filterwarnings("ignore")
15 |
16 |
17 | def main():
18 | # Create the dataset
19 | dataset = dset.CIFAR10(root=args.dataroot,
20 | transform=transforms.Compose([
21 | transforms.Resize(args.image_size),
22 | transforms.CenterCrop(args.image_size),
23 | transforms.ToTensor(),
24 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
25 | ]), download=True, train=True)
26 |
27 | # Make sub-training dataset
28 | subset = torch.utils.data.Subset(dataset, np.arange(int(len(dataset) * args.data_ratio)))
29 | # Create the dataloader
30 | dataloader = torch.utils.data.DataLoader(subset, batch_size=args.batch_size,
31 | shuffle=True, num_workers=args.workers)
32 |
33 | netD = ResDiscriminator32().to(device)
34 | netG = Regan_training(ResGenerator32(args.noise_size).to(device), sparsity=args.sparsity)
35 |
36 | # Setup Adam optimizers for both G and D
37 | optimizerD = optim.Adam(netD.parameters(), args.lr, (0, 0.9))
38 | optimizerG = optim.Adam(netG.parameters(), args.lr, (0, 0.9))
39 |
40 | print("Starting Training Loop...")
41 |
42 | flag_g = 1
43 |
44 | for epoch in range(1, args.epoch + 1):
45 |
46 |         if args.regan:
47 |             # Warm-up phase: ReGAN training is not enabled yet
48 |             if epoch < args.warmup_epoch + 1:
49 |                 print('current phase: warm-up training')
50 |                 netG.train_on_sparse = False
51 | 
52 |             # Warm-up finished: enter the sparse training phase
53 |             elif epoch > args.warmup_epoch and flag_g < args.g + 1:
54 |                 print('epoch %d, current phase: sparse training' % epoch)
55 |                 # switch training mode to sparse and update the pruning masks
56 |                 netG.turn_training_mode(mode='sparse')
57 |                 # make sure the sparse phase uses the original learning rate
58 |                 if flag_g == 1:
59 |                     print('restore learning rate to the original value')
60 |                     for params in optimizerG.param_groups:
61 |                         params['lr'] = args.lr
62 |                 flag_g = flag_g + 1
63 | 
64 |             # Sparse phase finished: enter the dense training phase
65 |             elif epoch > args.warmup_epoch and flag_g < 2 * args.g + 1:
66 |                 print('epoch %d, current phase: dense training' % epoch)
67 |                 # switch training mode back to dense
68 |                 netG.turn_training_mode(mode='dense')
69 |                 # make sure the dense phase uses a learning rate 10 times smaller than the original
70 |                 if flag_g == args.g + 1:
71 |                     print('reduce learning rate to one tenth')
72 |                     for params in optimizerG.param_groups:
73 |                         params['lr'] = args.lr * 0.1
74 |                 flag_g = flag_g + 1
75 | 
76 |             # When the current sparse-dense pair finishes, move on to the next pair
77 |             if flag_g == 2 * args.g + 1:
78 |                 print('reset flag')
79 |                 flag_g = 1
80 |
81 | for i, data in enumerate(dataloader, 0):
82 |
83 | netD.zero_grad()
84 | real_cpu = data[0].to(device)
85 | b_size = real_cpu.size(0)
86 | output = netD(real_cpu).view(-1)
87 | errD_real = torch.mean(nn.ReLU(inplace=True)(1.0 - output))
88 | errD_real.backward()
89 | D_x = output.mean().item()
90 |
91 | noise = torch.randn(b_size, args.noise_size, device=device)
92 | fake = netG(noise)
93 |
94 | output = netD(fake.detach()).view(-1)
95 | errD_fake = torch.mean(nn.ReLU(inplace=True)(1 + output))
96 | errD_fake.backward()
97 | D_G_z1 = output.mean().item()
98 | errD = errD_real + errD_fake
99 | optimizerD.step()
100 |
101 | if i % args.n_critic == 0:
102 | netG.zero_grad()
103 | noise = torch.randn(b_size, args.noise_size, device=device)
104 | fake = netG(noise)
105 | output = netD(fake).view(-1)
106 | errG = -torch.mean(output)
107 | errG.backward()
108 | D_G_z2 = output.mean().item()
109 |
110 | # Eliminate weights and their gradients
111 | if args.regan and netG.train_on_sparse:
112 | netG.apply_masks()
113 |
114 | optimizerG.step()
115 |
116 | # Output training stats
117 | if i % 50 == 0:
118 | print('[%4d/%4d][%3d/%3d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
119 | % (epoch, args.epoch, i, len(dataloader),
120 | errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
121 |
122 |
123 | if __name__ == '__main__':
124 | model_name = 'SNGAN'
125 | argparser = argparse.ArgumentParser()
126 | argparser.add_argument('--epoch', type=int, default=20)
127 | argparser.add_argument('--batch_size', type=int, default=64)
128 | argparser.add_argument('--lr', type=float, default=2e-4)
129 | argparser.add_argument('--workers', type=int, default=4)
130 | argparser.add_argument('--image_size', type=int, default=32)
131 | argparser.add_argument('--noise_size', type=int, default=128)
132 | argparser.add_argument('--dataroot', type=str, default='../dataset')
133 | argparser.add_argument('--clip_value', type=float, default=0.01)
134 | argparser.add_argument('--n_critic', type=int, default=5)
135 | argparser.add_argument('--sparsity', type=float, default=0.3)
136 | argparser.add_argument('--g', type=int, default=5)
137 | argparser.add_argument('--warmup_epoch', type=int, default=100)
138 | argparser.add_argument('--data_ratio', type=float, default=1.0)
139 | argparser.add_argument('--regan', action="store_true")
140 | args = argparser.parse_args()
141 |
142 | if not os.path.exists(args.dataroot):
143 | os.makedirs(args.dataroot)
144 |
145 | device = "cuda"
146 |
147 | main()
148 |
--------------------------------------------------------------------------------
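
Note on SNGAN/main.py above: the discriminator is trained with the hinge loss and the generator with the linear (non-saturating) term, exactly as in the inner loop. Factored out for reference (our helper names):

import torch
import torch.nn.functional as F

def d_hinge_loss(d_real, d_fake):
    # E[relu(1 - D(x))] + E[relu(1 + D(G(z)))]
    return torch.mean(F.relu(1.0 - d_real)) + torch.mean(F.relu(1.0 + d_fake))

def g_hinge_loss(d_fake):
    # -E[D(G(z))]
    return -torch.mean(d_fake)
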
/SNGAN/model.py:
--------------------------------------------------------------------------------
1 | # This file is taken from https://github.com/w86763777/pytorch-gan-collections/blob/master/source/models/sngan.py
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.init as init
7 | from torch.nn.utils.spectral_norm import spectral_norm
8 |
9 |
10 | class Generator(nn.Module):
11 | def __init__(self, z_dim, M=4):
12 | super().__init__()
13 | self.M = M
14 | self.linear = nn.Linear(z_dim, M * M * 512)
15 | self.main = nn.Sequential(
16 | nn.BatchNorm2d(512),
17 | nn.ReLU(True),
18 | nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
19 | nn.BatchNorm2d(256),
20 | nn.ReLU(True),
21 | nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
22 | nn.BatchNorm2d(128),
23 | nn.ReLU(True),
24 | nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
25 | nn.BatchNorm2d(64),
26 | nn.ReLU(True),
27 | nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
28 | nn.Tanh())
29 | self.initialize()
30 |
31 | def initialize(self):
32 | for m in self.modules():
33 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
34 | init.normal_(m.weight, std=0.02)
35 | init.zeros_(m.bias)
36 |
37 | def forward(self, z, *args, **kwargs):
38 | x = self.linear(z)
39 | x = x.view(x.size(0), -1, self.M, self.M)
40 | x = self.main(x)
41 | return x
42 |
43 |
44 | class Discriminator(nn.Module):
45 | def __init__(self, M=32):
46 | super().__init__()
47 | self.M = M
48 |
49 | self.main = nn.Sequential(
50 | # M
51 | nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
52 | nn.LeakyReLU(0.1, inplace=True),
53 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
54 | nn.LeakyReLU(0.1, inplace=True),
55 | # M / 2
56 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
57 | nn.LeakyReLU(0.1, inplace=True),
58 | nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
59 | nn.LeakyReLU(0.1, inplace=True),
60 | # M / 4
61 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
62 | nn.LeakyReLU(0.1, inplace=True),
63 | nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
64 | nn.LeakyReLU(0.1, inplace=True),
65 | # M / 8
66 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
67 | nn.LeakyReLU(0.1, inplace=True))
68 |
69 | self.linear = nn.Linear(M // 8 * M // 8 * 512, 1)
70 | self.initialize()
71 |
72 | def initialize(self):
73 | for m in self.modules():
74 | if isinstance(m, (nn.Conv2d, nn.Linear)):
75 | init.normal_(m.weight, std=0.02)
76 | init.zeros_(m.bias)
77 | spectral_norm(m)
78 |
79 | def forward(self, x, *args, **kwargs):
80 | x = self.main(x)
81 | x = torch.flatten(x, start_dim=1)
82 | x = self.linear(x)
83 | return x
84 |
85 |
86 | class Generator32(Generator):
87 | def __init__(self, z_dim):
88 | super().__init__(z_dim, M=4)
89 |
90 |
91 | class Discriminator32(Discriminator):
92 | def __init__(self):
93 | super().__init__(M=32)
94 |
95 |
96 | class ResGenBlock(nn.Module):
97 | def __init__(self, in_channels, out_channels):
98 | super().__init__()
99 | self.residual = nn.Sequential(
100 | nn.BatchNorm2d(in_channels),
101 | nn.ReLU(),
102 | nn.Upsample(scale_factor=2),
103 | nn.Conv2d(in_channels, out_channels, 3, stride=1, padding=1),
104 | nn.BatchNorm2d(out_channels),
105 | nn.ReLU(),
106 | nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1),
107 | )
108 | self.shortcut = nn.Sequential(
109 | nn.Upsample(scale_factor=2),
110 | nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0)
111 | )
112 | self.initialize()
113 |
114 | def initialize(self):
115 | for m in self.residual.modules():
116 | if isinstance(m, nn.Conv2d):
117 | init.xavier_uniform_(m.weight, math.sqrt(2))
118 | init.zeros_(m.bias)
119 | for m in self.shortcut.modules():
120 | if isinstance(m, nn.Conv2d):
121 | init.xavier_uniform_(m.weight)
122 | init.zeros_(m.bias)
123 |
124 | def forward(self, x):
125 | return self.residual(x) + self.shortcut(x)
126 |
127 |
128 | class ResGenerator32(nn.Module):
129 | def __init__(self, z_dim):
130 | super().__init__()
131 | self.z_dim = z_dim
132 | self.linear = nn.Linear(z_dim, 4 * 4 * 256)
133 |
134 | self.blocks = nn.Sequential(
135 | ResGenBlock(256, 256),
136 | ResGenBlock(256, 256),
137 | ResGenBlock(256, 256),
138 | )
139 | self.output = nn.Sequential(
140 | nn.BatchNorm2d(256),
141 | nn.ReLU(True),
142 | nn.Conv2d(256, 3, 3, stride=1, padding=1),
143 | nn.Tanh(),
144 | )
145 | self.initialize()
146 |
147 | def initialize(self):
148 | init.xavier_uniform_(self.linear.weight)
149 | init.zeros_(self.linear.bias)
150 | for m in self.output.modules():
151 | if isinstance(m, nn.Conv2d):
152 | init.xavier_uniform_(m.weight)
153 | init.zeros_(m.bias)
154 |
155 | def forward(self, z):
156 | z = self.linear(z)
157 | z = z.view(-1, 256, 4, 4)
158 | return self.output(self.blocks(z))
159 |
160 |
161 | class OptimizedResDisblock(nn.Module):
162 | def __init__(self, in_channels, out_channels):
163 | super().__init__()
164 | self.shortcut = nn.Sequential(
165 | nn.AvgPool2d(2),
166 | nn.Conv2d(in_channels, out_channels, 1, 1, 0))
167 | self.residual = nn.Sequential(
168 | nn.Conv2d(in_channels, out_channels, 3, 1, 1),
169 | nn.ReLU(),
170 | nn.Conv2d(out_channels, out_channels, 3, 1, 1),
171 | nn.AvgPool2d(2))
172 | self.initialize()
173 |
174 | def initialize(self):
175 | for m in self.residual.modules():
176 | if isinstance(m, nn.Conv2d):
177 | init.xavier_uniform_(m.weight, math.sqrt(2))
178 | init.zeros_(m.bias)
179 | spectral_norm(m)
180 | for m in self.shortcut.modules():
181 | if isinstance(m, nn.Conv2d):
182 | init.xavier_uniform_(m.weight)
183 | init.zeros_(m.bias)
184 | spectral_norm(m)
185 |
186 | def forward(self, x):
187 | return self.residual(x) + self.shortcut(x)
188 |
189 |
190 | class ResDisBlock(nn.Module):
191 | def __init__(self, in_channels, out_channels, down=False):
192 | super().__init__()
193 | shortcut = []
194 | if in_channels != out_channels or down:
195 | shortcut.append(
196 | nn.Conv2d(in_channels, out_channels, 1, 1, 0))
197 | if down:
198 | shortcut.append(nn.AvgPool2d(2))
199 | self.shortcut = nn.Sequential(*shortcut)
200 |
201 | residual = [
202 | nn.ReLU(),
203 | nn.Conv2d(in_channels, out_channels, 3, 1, 1),
204 | nn.ReLU(),
205 | nn.Conv2d(out_channels, out_channels, 3, 1, 1),
206 | ]
207 | if down:
208 | residual.append(nn.AvgPool2d(2))
209 | self.residual = nn.Sequential(*residual)
210 | self.initialize()
211 |
212 | def initialize(self):
213 | for m in self.residual.modules():
214 | if isinstance(m, nn.Conv2d):
215 | init.xavier_uniform_(m.weight, math.sqrt(2))
216 | init.zeros_(m.bias)
217 | spectral_norm(m)
218 | for m in self.shortcut.modules():
219 | if isinstance(m, nn.Conv2d):
220 | init.xavier_uniform_(m.weight)
221 | init.zeros_(m.bias)
222 | spectral_norm(m)
223 |
224 | def forward(self, x):
225 | return (self.residual(x) + self.shortcut(x))
226 |
227 |
228 | class ResDiscriminator32(nn.Module):
229 | def __init__(self):
230 | super().__init__()
231 | self.model = nn.Sequential(
232 | OptimizedResDisblock(3, 128),
233 | ResDisBlock(128, 128, down=True),
234 | ResDisBlock(128, 128),
235 | ResDisBlock(128, 128),
236 | nn.ReLU())
237 | self.linear = nn.Linear(128, 1, bias=False)
238 | self.initialize()
239 |
240 | def initialize(self):
241 | init.xavier_uniform_(self.linear.weight)
242 | spectral_norm(self.linear)
243 |
244 | def forward(self, x):
245 | x = self.model(x).sum(dim=[2, 3])
246 | x = self.linear(x)
247 | return x
--------------------------------------------------------------------------------
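
Note on SNGAN/model.py above: spectral_norm(m) reparametrizes a module's weight in place (registering weight_orig plus a power-iteration forward pre-hook), which is why initialize() can call it and discard the return value. A minimal standalone example:

import torch.nn as nn
from torch.nn.utils.spectral_norm import spectral_norm

layer = spectral_norm(nn.Conv2d(3, 64, 3, padding=1))
print(hasattr(layer, 'weight_orig'))  # True: raw weight kept, 'weight' is derived
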
/SNGAN/regan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Regan_training(nn.Module):
6 |
7 | def __init__(self, model, sparsity, train_on_sparse=False):
8 | super(Regan_training, self).__init__()
9 |
10 | self.model = model
11 | self.sparsity = sparsity
12 | self.train_on_sparse = train_on_sparse
13 | self.layers = []
14 | self.masks = []
15 |
16 | layers = list(self.model.named_parameters())
17 |
18 | for i in range(0, len(layers)):
19 | w = layers[i]
20 | self.layers.append(w[1])
21 |
22 | self.reset_masks()
23 |
24 | def reset_masks(self):
25 |
26 | for w in self.layers:
27 | mask_w = torch.ones_like(w, dtype=bool)
28 | self.masks.append(mask_w)
29 |
30 | return self.masks
31 |
32 | def update_masks(self):
33 |
34 | for i, w in enumerate(self.layers):
35 | q_w = torch.quantile(torch.abs(w), q=self.sparsity)
36 |             mask_w = torch.abs(w) < q_w  # True marks the weights to prune
37 |
38 | self.masks[i] = mask_w
39 |
40 | def turn_training_mode(self, mode):
41 | if mode == 'dense':
42 | self.train_on_sparse = False
43 | else:
44 | self.train_on_sparse = True
45 | self.update_masks()
46 |
47 | def apply_masks(self):
48 | for w, mask_w in zip(self.layers, self.masks):
49 | w.data[mask_w] = 0
50 | w.grad.data[mask_w] = 0
51 |
52 | def forward(self, x):
53 | return self.model(x)
54 |
--------------------------------------------------------------------------------
/StyleGAN2/apply_factor.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from dsd import *
3 | import torch
4 | from torchvision import utils
5 | 
6 | from model import Generator
7 | 
8 |
9 | if __name__ == "__main__":
10 | 
11 | torch.set_grad_enabled(False)
12 |
13 | parser = argparse.ArgumentParser(description="Apply closed form factorization")
14 |
15 | parser.add_argument(
16 | "-i", "--index", type=int, default=0, help="index of eigenvector"
17 | )
18 | parser.add_argument(
19 | "-d",
20 | "--degree",
21 | type=float,
22 | default=5,
23 | help="scalar factors for moving latent vectors along eigenvector",
24 | )
25 | parser.add_argument(
26 | "--channel_multiplier",
27 | type=int,
28 | default=2,
29 | help='channel multiplier factor. config-f = 2, else = 1',
30 | )
31 | parser.add_argument("--ckpt", type=str, required=True, help="stylegan2 checkpoints")
32 | parser.add_argument(
33 | "--size", type=int, default=256, help="output image size of the generator"
34 | )
35 | parser.add_argument(
36 | "-n", "--n_sample", type=int, default=7, help="number of samples created"
37 | )
38 | parser.add_argument(
39 | "--truncation", type=float, default=0.7, help="truncation factor"
40 | )
41 | parser.add_argument(
42 | "--device", type=str, default="cuda", help="device to run the model"
43 | )
44 | parser.add_argument(
45 | "--out_prefix",
46 | type=str,
47 | default="factor",
48 | help="filename prefix to result samples",
49 | )
50 | parser.add_argument(
51 | "factor",
52 | type=str,
53 | help="name of the closed form factorization result factor file",
54 | )
55 |
56 | args = parser.parse_args()
57 | 
58 |
59 | eigvec = torch.load(args.factor)["eigvec"].to(args.device)
60 | ckpt = torch.load(args.ckpt)
61 | g = Generator(args.size, 512, 8, channel_multiplier=args.channel_multiplier).to(args.device)
62 | g = DSDTraining(g, 0.3)
63 | g.load_state_dict(ckpt["g_ema"], strict=False)
64 |
65 | trunc = g.mean_latent(4096)
66 | latent = torch.randn(args.n_sample, 512, device=args.device)
67 |
68 | latent = g.get_latent(latent)
69 |
70 | for xx in [10, 20, 30, 50, 100]:
71 |
72 | args.index = xx
73 | direction = args.degree * eigvec[:, args.index].unsqueeze(0)
74 |
75 | img, _ = g(
76 | [latent],
77 | truncation=args.truncation,
78 | truncation_latent=trunc,
79 | input_is_latent=True,
80 | )
81 | img1, _ = g(
82 | [latent + direction],
83 | truncation=args.truncation,
84 | truncation_latent=trunc,
85 | input_is_latent=True,
86 | )
87 | img2, _ = g(
88 | [latent - direction],
89 | truncation=args.truncation,
90 | truncation_latent=trunc,
91 | input_is_latent=True,
92 | )
93 |
94 | img3, _ = g(
95 | [latent + 2 * direction],
96 | truncation=args.truncation,
97 | truncation_latent=trunc,
98 | input_is_latent=True,
99 | )
100 |
101 | img4, _ = g(
102 | [latent - 2 * direction],
103 | truncation=args.truncation,
104 | truncation_latent=trunc,
105 | input_is_latent=True,
106 | )
107 | img5, _ = g(
108 | [latent + 3 * direction],
109 | truncation=args.truncation,
110 | truncation_latent=trunc,
111 | input_is_latent=True,
112 | )
113 |
114 | img6, _ = g(
115 | [latent - 3 * direction],
116 | truncation=args.truncation,
117 | truncation_latent=trunc,
118 | input_is_latent=True,
119 | )
120 |
121 | img7, _ = g(
122 | [latent - 4 * direction],
123 | truncation=args.truncation,
124 | truncation_latent=trunc,
125 | input_is_latent=True,
126 | )
127 | # print(torch.cat([img4, img1, img, img2, img3], 0).shape)
128 | # print(img4[0].shape)
129 | # print(img4[0].unsqueeze(0).shape)
130 | # print(torch.cat([img4[0].unsqueeze(0), img1[0].unsqueeze(0),
131 | # img[0].unsqueeze(0),
132 | # img2[0].unsqueeze(0), img3[0].unsqueeze(0)], 0).shape)
133 |
134 | for i in range(args.n_sample):
135 | grid = utils.save_image(
136 | torch.cat([img7[i].unsqueeze(0), img6[i].unsqueeze(0), img4[i].unsqueeze(0), img2[i].unsqueeze(0),
137 | img[i].unsqueeze(0),
138 | img1[i].unsqueeze(0), img3[i].unsqueeze(0), img5[i].unsqueeze(0)], 0),
139 | f"./factor/{args.out_prefix}_index-{args.index}_degree-{args.degree}_{i}.png",
140 | normalize=True,
141 | # range=(-1, 1),
142 | nrow=args.n_sample,
143 | )
144 |
145 | # for i in range(args.n_sample):
146 | # grid = utils.save_image(
147 | # torch.cat([img[i].unsqueeze(0), img1[i].unsqueeze(0), img2[i].unsqueeze(0), img3[i].unsqueeze(0)], 0),
148 | # f"./factor/{args.out_prefix}_index-{args.index}_degree-{args.degree}_{i}.png",
149 | # normalize=True,
150 | # range=(-1, 1),
151 | # nrow=args.n_sample,
152 | # )
153 |
154 | # grid = utils.save_image(
155 | # torch.cat([img4, img1, img, img2, img3], 0),
156 | # f"./factor/{args.out_prefix}_index-{args.index}_degree-{args.degree}.png",
157 | # normalize=True,
158 | # range=(-1, 1),
159 | # nrow=args.n_sample,
160 | # )
161 |
--------------------------------------------------------------------------------
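
Note on apply_factor.py above: the img/img1/.../img7 block decodes the latent shifted along one eigenvector at offsets -4..3 times `degree`. A condensed equivalent, as a sketch (walk_eigvec is our name; g is the generator loaded by the script):

def walk_eigvec(g, latent, trunc, eigvec, index, degree, truncation):
    # decode `latent` shifted along eigenvector `index`; offsets -4..3
    # reproduce the image grid saved above, in order
    direction = degree * eigvec[:, index].unsqueeze(0)
    return [
        g([latent + k * direction], truncation=truncation,
          truncation_latent=trunc, input_is_latent=True)[0]
        for k in range(-4, 4)
    ]
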
/StyleGAN2/calc_inception.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import pickle
3 | import os
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn import functional as F
8 | from torch.utils.data import DataLoader
9 | from torchvision import transforms
10 | from torchvision.models import inception_v3, Inception3
11 | import numpy as np
12 | from tqdm import tqdm
13 |
14 | from inception import InceptionV3
15 | from dataset import MultiResolutionDataset
16 |
17 |
18 | class Inception3Feature(Inception3):
19 | def forward(self, x):
20 | if x.shape[2] != 299 or x.shape[3] != 299:
21 | x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True)
22 |
23 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3
24 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32
25 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32
26 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64
27 |
28 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64
29 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80
30 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192
31 |
32 | x = self.Mixed_5b(x) # 35 x 35 x 192
33 | x = self.Mixed_5c(x) # 35 x 35 x 256
34 | x = self.Mixed_5d(x) # 35 x 35 x 288
35 |
36 | x = self.Mixed_6a(x) # 35 x 35 x 288
37 | x = self.Mixed_6b(x) # 17 x 17 x 768
38 | x = self.Mixed_6c(x) # 17 x 17 x 768
39 | x = self.Mixed_6d(x) # 17 x 17 x 768
40 | x = self.Mixed_6e(x) # 17 x 17 x 768
41 |
42 | x = self.Mixed_7a(x) # 17 x 17 x 768
43 | x = self.Mixed_7b(x) # 8 x 8 x 1280
44 | x = self.Mixed_7c(x) # 8 x 8 x 2048
45 |
46 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048
47 |
48 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048
49 |
50 |
51 | def load_patched_inception_v3():
52 | # inception = inception_v3(pretrained=True)
53 | # inception_feat = Inception3Feature()
54 | # inception_feat.load_state_dict(inception.state_dict())
55 | inception_feat = InceptionV3([3], normalize_input=False)
56 |
57 | return inception_feat
58 |
59 |
60 | @torch.no_grad()
61 | def extract_features(loader, inception, device):
62 | pbar = tqdm(loader)
63 |
64 | feature_list = []
65 |
66 | for img in pbar:
67 | img = img.to(device)
68 | feature = inception(img)[0].view(img.shape[0], -1)
69 | feature_list.append(feature.to("cpu"))
70 |
71 | features = torch.cat(feature_list, 0)
72 |
73 | return features
74 |
75 |
76 | if __name__ == "__main__":
77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78 |
79 | parser = argparse.ArgumentParser(
80 | description="Calculate Inception v3 features for datasets"
81 | )
82 | parser.add_argument(
83 | "--size",
84 | type=int,
85 | default=256,
86 | help="image sizes used for embedding calculation",
87 | )
88 | parser.add_argument(
89 | "--batch", default=64, type=int, help="batch size for inception networks"
90 | )
91 | parser.add_argument(
92 | "--n_sample",
93 | type=int,
94 | default=50000,
95 | help="number of samples used for embedding calculation",
96 | )
97 | parser.add_argument(
98 | "--flip", action="store_true", help="apply random flipping to real images"
99 | )
100 | parser.add_argument("path", metavar="PATH", help="path to datset lmdb file")
101 |
102 | args = parser.parse_args()
103 |
104 | inception = load_patched_inception_v3()
105 | inception = nn.DataParallel(inception).eval().to(device)
106 |
107 | transform = transforms.Compose(
108 | [
109 | transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
110 | transforms.ToTensor(),
111 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
112 | ]
113 | )
114 |
115 | dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size)
116 | loader = DataLoader(dset, batch_size=args.batch, num_workers=4)
117 |
118 | features = extract_features(loader, inception, device).numpy()
119 |
120 | features = features[: args.n_sample]
121 |
122 | print(f"extracted {features.shape[0]} features")
123 |
124 | mean = np.mean(features, 0)
125 | cov = np.cov(features, rowvar=False)
126 |
127 | name = os.path.splitext(os.path.basename(args.path))[0]
128 |
129 | with open(f"inception_{name}.pkl", "wb") as f:
130 | pickle.dump({"mean": mean, "cov": cov, "size": args.size, "path": args.path}, f)
131 |
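# A minimal consumption sketch (not part of the original scripts): the pickle
# above stores the feature mean and covariance, and two such (mean, cov) pairs
# compare via the Frechet distance ||mu1 - mu2||^2 + Tr(C1 + C2 - 2 (C1 C2)^{1/2}).
# Assumes scipy is available; `fid_from_stats` is an illustrative helper name.
def fid_from_stats(mu1, cov1, mu2, cov2, eps=1e-6):
    from scipy import linalg

    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(cov1 @ cov2, disp=False)
    if not np.isfinite(covmean).all():
        # regularize near-singular covariances before the matrix square root
        offset = np.eye(cov1.shape[0]) * eps
        covmean = linalg.sqrtm((cov1 + offset) @ (cov2 + offset))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return float(diff @ diff + np.trace(cov1) + np.trace(cov2) - 2 * np.trace(covmean))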
--------------------------------------------------------------------------------
/StyleGAN2/closed_form_factorization.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 |
5 |
6 | if __name__ == "__main__":
7 | parser = argparse.ArgumentParser(
8 | description="Extract factor/eigenvectors of latent spaces using closed form factorization"
9 | )
10 |
11 | parser.add_argument(
12 | "--out", type=str, default="factor.pt", help="name of the result factor file"
13 | )
14 | parser.add_argument("ckpt", type=str, help="name of the model checkpoint")
15 |
16 | args = parser.parse_args()
17 |
18 | ckpt = torch.load(args.ckpt)
19 | modulate = {
20 | k: v
21 | for k, v in ckpt["g_ema"].items()
22 | if "modulation" in k and "to_rgbs" not in k and "weight" in k
23 | }
24 |
25 | weight_mat = []
26 | for k, v in modulate.items():
27 | weight_mat.append(v)
28 |
29 | W = torch.cat(weight_mat, 0)
30 | eigvec = torch.svd(W).V.to("cpu")
31 |
32 | torch.save({"ckpt": args.ckpt, "eigvec": eigvec}, args.out)
33 |
34 |
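# A minimal usage sketch for the saved factors, assuming a Generator `g` loaded
# from the same checkpoint and a latent `w` in W space; apply_factor.py in this
# repo is the full version of this.
def apply_direction(g, w, eigvec, index=0, degree=5.0):
    direction = degree * eigvec[:, index].unsqueeze(0)
    img_pos, _ = g([w + direction], input_is_latent=True)
    img_neg, _ = g([w - direction], input_is_latent=True)
    return img_pos, img_neg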
--------------------------------------------------------------------------------
/StyleGAN2/dataset.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import lmdb
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class MultiResolutionDataset(Dataset):
9 | def __init__(self, path, transform, resolution=256):
10 | self.env = lmdb.open(
11 | path,
12 | max_readers=32,
13 | readonly=True,
14 | lock=False,
15 | readahead=False,
16 | meminit=False,
17 | )
18 |
19 | if not self.env:
20 | raise IOError('Cannot open lmdb dataset', path)
21 |
22 | with self.env.begin(write=False) as txn:
23 | self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
24 |
25 | self.resolution = resolution
26 | self.transform = transform
27 |
28 | def __len__(self):
29 | return self.length
30 |
31 | def __getitem__(self, index):
32 | with self.env.begin(write=False) as txn:
33 | key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
34 | img_bytes = txn.get(key)
35 |
36 | buffer = BytesIO(img_bytes)
37 | img = Image.open(buffer)
38 | img = self.transform(img)
39 |
40 | return img
41 |
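# A minimal writer sketch matching the key scheme read above; prepare_data.py in
# this repo is the real tool. `image_bytes_list` is a hypothetical list of
# encoded image payloads (e.g. JPEG/PNG bytes) at the target resolution.
def write_lmdb(path, image_bytes_list, resolution=256):
    with lmdb.open(path, map_size=1024 ** 4) as env:
        with env.begin(write=True) as txn:
            for i, img_bytes in enumerate(image_bytes_list):
                key = f'{resolution}-{str(i).zfill(5)}'.encode('utf-8')
                txn.put(key, img_bytes)
            txn.put('length'.encode('utf-8'), str(len(image_bytes_list)).encode('utf-8'))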
--------------------------------------------------------------------------------
/StyleGAN2/diffaug.py:
--------------------------------------------------------------------------------
1 | # Differentiable Augmentation for Data-Efficient GAN Training
2 | # Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
3 | # https://arxiv.org/pdf/2006.10738
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def DiffAugment(x, policy='', channels_first=True):
10 | if policy:
11 | if not channels_first:
12 | x = x.permute(0, 3, 1, 2)
13 | for p in policy.split(','):
14 | for f in AUGMENT_FNS[p]:
15 | x = f(x)
16 | if not channels_first:
17 | x = x.permute(0, 2, 3, 1)
18 | x = x.contiguous()
19 | return x
20 |
21 |
22 | def rand_brightness(x):
23 | x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
24 | return x
25 |
26 |
27 | def rand_saturation(x):
28 | x_mean = x.mean(dim=1, keepdim=True)
29 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2) + x_mean
30 | return x
31 |
32 |
33 | def rand_contrast(x):
34 | x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
35 | x = (x - x_mean) * (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5) + x_mean
36 | return x
37 |
38 |
39 | def rand_translation(x, ratio=0.125):
40 | shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
41 | translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
42 | translation_y = torch.randint(-shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device)
43 | grid_batch, grid_x, grid_y = torch.meshgrid(
44 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
45 | torch.arange(x.size(2), dtype=torch.long, device=x.device),
46 | torch.arange(x.size(3), dtype=torch.long, device=x.device),
47 | )
48 | grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
49 | grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
50 | x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
51 | x = x_pad.permute(0, 2, 3, 1).contiguous()[grid_batch, grid_x, grid_y].permute(0, 3, 1, 2).contiguous()
52 | return x
53 |
54 |
55 | def rand_cutout(x, ratio=0.5):
56 | cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
57 | offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)
58 | offset_y = torch.randint(0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device)
59 | grid_batch, grid_x, grid_y = torch.meshgrid(
60 | torch.arange(x.size(0), dtype=torch.long, device=x.device),
61 | torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
62 | torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
63 | )
64 | grid_x = torch.clamp(grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1)
65 | grid_y = torch.clamp(grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1)
66 | mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
67 | mask[grid_batch, grid_x, grid_y] = 0
68 | x = x * mask.unsqueeze(1)
69 | return x
70 |
71 |
72 | AUGMENT_FNS = {
73 | 'color': [rand_brightness, rand_saturation, rand_contrast],
74 | 'translation': [rand_translation],
75 | 'cutout': [rand_cutout],
76 | }
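

# A minimal self-test sketch: the policy string picks comma-separated groups
# from AUGMENT_FNS, and the same policy is normally applied to both real and
# generated batches before the discriminator sees them.
if __name__ == '__main__':
    x = torch.randn(4, 3, 32, 32)  # dummy channels-first batch
    y = DiffAugment(x, policy='color,translation,cutout')
    print(y.shape)  # shape is preserved: torch.Size([4, 3, 32, 32])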
--------------------------------------------------------------------------------
/StyleGAN2/distributed.py:
--------------------------------------------------------------------------------
1 | import math
2 | import pickle
3 |
4 | import torch
5 | from torch import distributed as dist
6 | from torch.utils.data.sampler import Sampler
7 |
8 |
9 | def get_rank():
10 | if not dist.is_available():
11 | return 0
12 |
13 | if not dist.is_initialized():
14 | return 0
15 |
16 | return dist.get_rank()
17 |
18 |
19 | def synchronize():
20 | if not dist.is_available():
21 | return
22 |
23 | if not dist.is_initialized():
24 | return
25 |
26 | world_size = dist.get_world_size()
27 |
28 | if world_size == 1:
29 | return
30 |
31 | dist.barrier()
32 |
33 |
34 | def get_world_size():
35 | if not dist.is_available():
36 | return 1
37 |
38 | if not dist.is_initialized():
39 | return 1
40 |
41 | return dist.get_world_size()
42 |
43 |
44 | def reduce_sum(tensor):
45 | if not dist.is_available():
46 | return tensor
47 |
48 | if not dist.is_initialized():
49 | return tensor
50 |
51 | tensor = tensor.clone()
52 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
53 |
54 | return tensor
55 |
56 |
57 | def gather_grad(params):
58 | world_size = get_world_size()
59 |
60 | if world_size == 1:
61 | return
62 |
63 | for param in params:
64 | if param.grad is not None:
65 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
66 | param.grad.data.div_(world_size)
67 |
68 |
69 | def all_gather(data):
70 | world_size = get_world_size()
71 |
72 | if world_size == 1:
73 | return [data]
74 |
75 | buffer = pickle.dumps(data)
76 | storage = torch.ByteStorage.from_buffer(buffer)
77 | tensor = torch.ByteTensor(storage).to('cuda')
78 |
79 | local_size = torch.IntTensor([tensor.numel()]).to('cuda')
80 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
81 | dist.all_gather(size_list, local_size)
82 | size_list = [int(size.item()) for size in size_list]
83 | max_size = max(size_list)
84 |
85 | tensor_list = []
86 | for _ in size_list:
87 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))
88 |
89 | if local_size != max_size:
90 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda')
91 | tensor = torch.cat((tensor, padding), 0)
92 |
93 | dist.all_gather(tensor_list, tensor)
94 |
95 | data_list = []
96 |
97 | for size, tensor in zip(size_list, tensor_list):
98 | buffer = tensor.cpu().numpy().tobytes()[:size]
99 | data_list.append(pickle.loads(buffer))
100 |
101 | return data_list
102 |
103 |
104 | def reduce_loss_dict(loss_dict):
105 | world_size = get_world_size()
106 |
107 | if world_size < 2:
108 | return loss_dict
109 |
110 | with torch.no_grad():
111 | keys = []
112 | losses = []
113 |
114 | for k in sorted(loss_dict.keys()):
115 | keys.append(k)
116 | losses.append(loss_dict[k])
117 |
118 | losses = torch.stack(losses, 0)
119 | dist.reduce(losses, dst=0)
120 |
121 | if dist.get_rank() == 0:
122 | losses /= world_size
123 |
124 | reduced_losses = {k: v for k, v in zip(keys, losses)}
125 |
126 | return reduced_losses
127 |
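# A minimal usage sketch (the key names are hypothetical): training loops
# average per-rank scalar losses onto rank 0 before logging.
def log_reduced_losses(loss_dict):
    loss_reduced = reduce_loss_dict(loss_dict)
    if get_rank() == 0:
        print({k: v.mean().item() for k, v in loss_reduced.items()})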
--------------------------------------------------------------------------------
/StyleGAN2/generate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from torchvision import utils
5 | from model import Generator
6 | from tqdm import tqdm
7 |
8 |
9 | def generate(args, g_ema, device, mean_latent):
10 |
11 | with torch.no_grad():
12 | g_ema.eval()
13 | for i in tqdm(range(args.pics)):
14 | sample_z = torch.randn(args.sample, args.latent, device=device)
15 |
16 | sample, _ = g_ema(
17 | [sample_z], truncation=args.truncation, truncation_latent=mean_latent
18 | )
19 |
20 | utils.save_image(
21 | sample,
22 | f"sample/{str(i).zfill(6)}.png",
23 | nrow=1,
24 | normalize=True,
25 | range=(-1, 1),
26 | )
27 |
28 |
29 | if __name__ == "__main__":
30 | device = "cuda"
31 |
32 | parser = argparse.ArgumentParser(description="Generate samples from the generator")
33 |
34 | parser.add_argument(
35 | "--size", type=int, default=1024, help="output image size of the generator"
36 | )
37 | parser.add_argument(
38 | "--sample",
39 | type=int,
40 | default=1,
41 | help="number of samples to be generated for each image",
42 | )
43 | parser.add_argument(
44 | "--pics", type=int, default=20, help="number of images to be generated"
45 | )
46 | parser.add_argument("--truncation", type=float, default=1, help="truncation ratio")
47 | parser.add_argument(
48 | "--truncation_mean",
49 | type=int,
50 | default=4096,
51 | help="number of vectors to calculate mean for the truncation",
52 | )
53 | parser.add_argument(
54 | "--ckpt",
55 | type=str,
56 | default="stylegan2-ffhq-config-f.pt",
57 | help="path to the model checkpoint",
58 | )
59 | parser.add_argument(
60 | "--channel_multiplier",
61 | type=int,
62 | default=2,
63 | help="channel multiplier of the generator. config-f = 2, else = 1",
64 | )
65 |
66 | args = parser.parse_args()
67 |
68 | args.latent = 512
69 | args.n_mlp = 8
70 |
71 | g_ema = Generator(
72 | args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier
73 | ).to(device)
74 | checkpoint = torch.load(args.ckpt)
75 |
76 | g_ema.load_state_dict(checkpoint["g_ema"])
77 |
78 | if args.truncation < 1:
79 | with torch.no_grad():
80 | mean_latent = g_ema.mean_latent(args.truncation_mean)
81 | else:
82 | mean_latent = None
83 |
84 | generate(args, g_ema, device, mean_latent)
85 |
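# A typical invocation sketch, assuming a trained checkpoint:
#   python generate.py --size 1024 --pics 20 --truncation 0.7 --ckpt stylegan2-ffhq-config-f.pt
# Samples are written to ./sample/, which must already exist.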
--------------------------------------------------------------------------------
/StyleGAN2/lpips/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | from skimage.measure import compare_ssim
8 | import torch
9 | from torch.autograd import Variable
10 |
11 | from lpips import dist_model
12 |
13 | class PerceptualLoss(torch.nn.Module):
14 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric)
15 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
16 | super(PerceptualLoss, self).__init__()
17 | print('Setting up Perceptual loss...')
18 | self.use_gpu = use_gpu
19 | self.spatial = spatial
20 | self.gpu_ids = gpu_ids
21 | self.model = dist_model.DistModel()
22 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids)
23 | print('...[%s] initialized'%self.model.name())
24 | print('...Done')
25 |
26 | def forward(self, pred, target, normalize=False):
27 | """
28 | Pred and target are Variables.
29 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
30 | If normalize is False, assumes the images are already between [-1,+1]
31 |
32 | Inputs pred and target are Nx3xHxW
33 | Output pytorch Variable N long
34 | """
35 |
36 | if normalize:
37 | target = 2 * target - 1
38 | pred = 2 * pred - 1
39 |
40 | return self.model.forward(target, pred)
41 |
42 | def normalize_tensor(in_feat,eps=1e-10):
43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
44 | return in_feat/(norm_factor+eps)
45 |
46 | def l2(p0, p1, range=255.):
47 | return .5*np.mean((p0 / range - p1 / range)**2)
48 |
49 | def psnr(p0, p1, peak=255.):
50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
51 |
52 | def dssim(p0, p1, range=255.):
53 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
54 |
55 | def rgb2lab(in_img,mean_cent=False):
56 | from skimage import color
57 | img_lab = color.rgb2lab(in_img)
58 | if(mean_cent):
59 | img_lab[:,:,0] = img_lab[:,:,0]-50
60 | return img_lab
61 |
62 | def tensor2np(tensor_obj):
63 | # change dimension of a tensor object into a numpy array
64 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
65 |
66 | def np2tensor(np_obj):
67 | # change dimension of an np array into a tensor
68 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
69 |
70 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
71 | # image tensor to lab tensor
72 | from skimage import color
73 |
74 | img = tensor2im(image_tensor)
75 | img_lab = color.rgb2lab(img)
76 | if(mc_only):
77 | img_lab[:,:,0] = img_lab[:,:,0]-50
78 | if(to_norm and not mc_only):
79 | img_lab[:,:,0] = img_lab[:,:,0]-50
80 | img_lab = img_lab/100.
81 |
82 | return np2tensor(img_lab)
83 |
84 | def tensorlab2tensor(lab_tensor,return_inbnd=False):
85 | from skimage import color
86 | import warnings
87 | warnings.filterwarnings("ignore")
88 |
89 | lab = tensor2np(lab_tensor)*100.
90 | lab[:,:,0] = lab[:,:,0]+50
91 |
92 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
93 | if(return_inbnd):
94 | # convert back to lab, see if we match
95 | lab_back = color.rgb2lab(rgb_back.astype('uint8'))
96 | mask = 1.*np.isclose(lab_back,lab,atol=2.)
97 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
98 | return (im2tensor(rgb_back),mask)
99 | else:
100 | return im2tensor(rgb_back)
101 |
102 | def rgb2lab(input):
103 | from skimage import color
104 | return color.rgb2lab(input / 255.)
105 |
106 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
107 | image_numpy = image_tensor[0].cpu().float().numpy()
108 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
109 | return image_numpy.astype(imtype)
110 |
111 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
112 | return torch.Tensor((image / factor - cent)
113 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
114 |
115 | def tensor2vec(vector_tensor):
116 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
117 |
118 | def voc_ap(rec, prec, use_07_metric=False):
119 | """ ap = voc_ap(rec, prec, [use_07_metric])
120 | Compute VOC AP given precision and recall.
121 | If use_07_metric is true, uses the
122 | VOC 07 11 point method (default:False).
123 | """
124 | if use_07_metric:
125 | # 11 point metric
126 | ap = 0.
127 | for t in np.arange(0., 1.1, 0.1):
128 | if np.sum(rec >= t) == 0:
129 | p = 0
130 | else:
131 | p = np.max(prec[rec >= t])
132 | ap = ap + p / 11.
133 | else:
134 | # correct AP calculation
135 | # first append sentinel values at the end
136 | mrec = np.concatenate(([0.], rec, [1.]))
137 | mpre = np.concatenate(([0.], prec, [0.]))
138 |
139 | # compute the precision envelope
140 | for i in range(mpre.size - 1, 0, -1):
141 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
142 |
143 | # to calculate area under PR curve, look for points
144 | # where X axis (recall) changes value
145 | i = np.where(mrec[1:] != mrec[:-1])[0]
146 |
147 | # and sum (\Delta recall) * prec
148 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
149 | return ap
150 |
161 |
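# A minimal usage sketch: ppl.py in this repo constructs the metric the same
# way. Inputs are Nx3xHxW tensors in [-1, 1]; pass normalize=True for [0, 1].
def lpips_distance(img0, img1, net='vgg', use_gpu=True):
    percept = PerceptualLoss(model='net-lin', net=net, use_gpu=use_gpu)
    return percept(img0, img1)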
--------------------------------------------------------------------------------
/StyleGAN2/lpips/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
7 |
8 | class BaseModel():
9 | def __init__(self):
10 |         pass
11 |
12 | def name(self):
13 | return 'BaseModel'
14 |
15 | def initialize(self, use_gpu=True, gpu_ids=[0]):
16 | self.use_gpu = use_gpu
17 | self.gpu_ids = gpu_ids
18 |
19 | def forward(self):
20 | pass
21 |
25 | def optimize_parameters(self):
26 | pass
27 |
28 | def get_current_visuals(self):
29 | return self.input
30 |
31 | def get_current_errors(self):
32 | return {}
33 |
34 | def save(self, label):
35 | pass
36 |
37 | # helper saving function that can be used by subclasses
38 | def save_network(self, network, path, network_label, epoch_label):
39 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
40 | save_path = os.path.join(path, save_filename)
41 | torch.save(network.state_dict(), save_path)
42 |
43 | # helper loading function that can be used by subclasses
44 | def load_network(self, network, network_label, epoch_label):
45 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
46 | save_path = os.path.join(self.save_dir, save_filename)
47 | print('Loading network from %s'%save_path)
48 | network.load_state_dict(torch.load(save_path))
49 |
50 |     def update_learning_rate(self):
51 | pass
52 |
53 | def get_image_paths(self):
54 | return self.image_paths
55 |
56 | def save_done(self, flag=False):
57 | np.save(os.path.join(self.save_dir, 'done_flag'),flag)
58 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
59 |
--------------------------------------------------------------------------------
/StyleGAN2/lpips/networks_basic.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import sys
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.init as init
8 | from torch.autograd import Variable
9 | import numpy as np
11 | from skimage import color
13 | from . import pretrained_networks as pn
14 |
15 | import lpips as util
16 |
17 | def spatial_average(in_tens, keepdim=True):
18 | return in_tens.mean([2,3],keepdim=keepdim)
19 |
20 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W
21 | in_H = in_tens.shape[2]
22 | scale_factor = 1.*out_H/in_H
23 |
24 | return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens)
25 |
26 | # Learned perceptual metric
27 | class PNetLin(nn.Module):
28 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True):
29 | super(PNetLin, self).__init__()
30 |
31 | self.pnet_type = pnet_type
32 | self.pnet_tune = pnet_tune
33 | self.pnet_rand = pnet_rand
34 | self.spatial = spatial
35 | self.lpips = lpips
36 | self.version = version
37 | self.scaling_layer = ScalingLayer()
38 |
39 | if(self.pnet_type in ['vgg','vgg16']):
40 | net_type = pn.vgg16
41 | self.chns = [64,128,256,512,512]
42 | elif(self.pnet_type=='alex'):
43 | net_type = pn.alexnet
44 | self.chns = [64,192,384,256,256]
45 | elif(self.pnet_type=='squeeze'):
46 | net_type = pn.squeezenet
47 | self.chns = [64,128,256,384,384,512,512]
48 | self.L = len(self.chns)
49 |
50 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
51 |
52 | if(lpips):
53 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
54 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
55 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
56 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
57 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
58 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
59 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
60 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
61 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
62 | self.lins+=[self.lin5,self.lin6]
63 |
64 | def forward(self, in0, in1, retPerLayer=False):
65 | # v0.0 - original release had a bug, where input was not scaled
66 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
67 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
68 | feats0, feats1, diffs = {}, {}, {}
69 |
70 | for kk in range(self.L):
71 | feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk])
72 | diffs[kk] = (feats0[kk]-feats1[kk])**2
73 |
74 | if(self.lpips):
75 | if(self.spatial):
76 | res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)]
77 | else:
78 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)]
79 | else:
80 | if(self.spatial):
81 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)]
82 | else:
83 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
84 |
85 | val = res[0]
86 | for l in range(1,self.L):
87 | val += res[l]
88 |
89 | if(retPerLayer):
90 | return (val, res)
91 | else:
92 | return val
93 |
94 | class ScalingLayer(nn.Module):
95 | def __init__(self):
96 | super(ScalingLayer, self).__init__()
97 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
98 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
99 |
100 | def forward(self, inp):
101 | return (inp - self.shift) / self.scale
102 |
103 |
104 | class NetLinLayer(nn.Module):
105 | ''' A single linear layer which does a 1x1 conv '''
106 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
107 | super(NetLinLayer, self).__init__()
108 |
109 | layers = [nn.Dropout(),] if(use_dropout) else []
110 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
111 | self.model = nn.Sequential(*layers)
112 |
113 |
114 | class Dist2LogitLayer(nn.Module):
115 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
116 | def __init__(self, chn_mid=32, use_sigmoid=True):
117 | super(Dist2LogitLayer, self).__init__()
118 |
119 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
120 | layers += [nn.LeakyReLU(0.2,True),]
121 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
122 | layers += [nn.LeakyReLU(0.2,True),]
123 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
124 | if(use_sigmoid):
125 | layers += [nn.Sigmoid(),]
126 | self.model = nn.Sequential(*layers)
127 |
128 | def forward(self,d0,d1,eps=0.1):
129 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
130 |
131 | class BCERankingLoss(nn.Module):
132 | def __init__(self, chn_mid=32):
133 | super(BCERankingLoss, self).__init__()
134 | self.net = Dist2LogitLayer(chn_mid=chn_mid)
135 | # self.parameters = list(self.net.parameters())
136 | self.loss = torch.nn.BCELoss()
137 |
138 | def forward(self, d0, d1, judge):
139 | per = (judge+1.)/2.
140 | self.logit = self.net.forward(d0,d1)
141 | return self.loss(self.logit, per)
142 |
143 | # L2, DSSIM metrics
144 | class FakeNet(nn.Module):
145 | def __init__(self, use_gpu=True, colorspace='Lab'):
146 | super(FakeNet, self).__init__()
147 | self.use_gpu = use_gpu
148 | self.colorspace=colorspace
149 |
150 | class L2(FakeNet):
151 |
152 | def forward(self, in0, in1, retPerLayer=None):
153 | assert(in0.size()[0]==1) # currently only supports batchSize 1
154 |
155 | if(self.colorspace=='RGB'):
156 | (N,C,X,Y) = in0.size()
157 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
158 | return value
159 | elif(self.colorspace=='Lab'):
160 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
161 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
162 | ret_var = Variable( torch.Tensor((value,) ) )
163 | if(self.use_gpu):
164 | ret_var = ret_var.cuda()
165 | return ret_var
166 |
167 | class DSSIM(FakeNet):
168 |
169 | def forward(self, in0, in1, retPerLayer=None):
170 | assert(in0.size()[0]==1) # currently only supports batchSize 1
171 |
172 | if(self.colorspace=='RGB'):
173 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float')
174 | elif(self.colorspace=='Lab'):
175 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)),
176 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
177 | ret_var = Variable( torch.Tensor((value,) ) )
178 | if(self.use_gpu):
179 | ret_var = ret_var.cuda()
180 | return ret_var
181 |
182 | def print_network(net):
183 | num_params = 0
184 | for param in net.parameters():
185 | num_params += param.numel()
186 | print('Network',net)
187 | print('Total number of parameters: %d' % num_params)
188 |
--------------------------------------------------------------------------------
/StyleGAN2/lpips/pretrained_networks.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import torch
3 | from torchvision import models as tv
5 |
6 | class squeezenet(torch.nn.Module):
7 | def __init__(self, requires_grad=False, pretrained=True):
8 | super(squeezenet, self).__init__()
9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
10 | self.slice1 = torch.nn.Sequential()
11 | self.slice2 = torch.nn.Sequential()
12 | self.slice3 = torch.nn.Sequential()
13 | self.slice4 = torch.nn.Sequential()
14 | self.slice5 = torch.nn.Sequential()
15 | self.slice6 = torch.nn.Sequential()
16 | self.slice7 = torch.nn.Sequential()
17 | self.N_slices = 7
18 | for x in range(2):
19 | self.slice1.add_module(str(x), pretrained_features[x])
20 | for x in range(2,5):
21 | self.slice2.add_module(str(x), pretrained_features[x])
22 | for x in range(5, 8):
23 | self.slice3.add_module(str(x), pretrained_features[x])
24 | for x in range(8, 10):
25 | self.slice4.add_module(str(x), pretrained_features[x])
26 | for x in range(10, 11):
27 | self.slice5.add_module(str(x), pretrained_features[x])
28 | for x in range(11, 12):
29 | self.slice6.add_module(str(x), pretrained_features[x])
30 | for x in range(12, 13):
31 | self.slice7.add_module(str(x), pretrained_features[x])
32 | if not requires_grad:
33 | for param in self.parameters():
34 | param.requires_grad = False
35 |
36 | def forward(self, X):
37 | h = self.slice1(X)
38 | h_relu1 = h
39 | h = self.slice2(h)
40 | h_relu2 = h
41 | h = self.slice3(h)
42 | h_relu3 = h
43 | h = self.slice4(h)
44 | h_relu4 = h
45 | h = self.slice5(h)
46 | h_relu5 = h
47 | h = self.slice6(h)
48 | h_relu6 = h
49 | h = self.slice7(h)
50 | h_relu7 = h
51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
53 |
54 | return out
55 |
56 |
57 | class alexnet(torch.nn.Module):
58 | def __init__(self, requires_grad=False, pretrained=True):
59 | super(alexnet, self).__init__()
60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
61 | self.slice1 = torch.nn.Sequential()
62 | self.slice2 = torch.nn.Sequential()
63 | self.slice3 = torch.nn.Sequential()
64 | self.slice4 = torch.nn.Sequential()
65 | self.slice5 = torch.nn.Sequential()
66 | self.N_slices = 5
67 | for x in range(2):
68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x])
69 | for x in range(2, 5):
70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x])
71 | for x in range(5, 8):
72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x])
73 | for x in range(8, 10):
74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x])
75 | for x in range(10, 12):
76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x])
77 | if not requires_grad:
78 | for param in self.parameters():
79 | param.requires_grad = False
80 |
81 | def forward(self, X):
82 | h = self.slice1(X)
83 | h_relu1 = h
84 | h = self.slice2(h)
85 | h_relu2 = h
86 | h = self.slice3(h)
87 | h_relu3 = h
88 | h = self.slice4(h)
89 | h_relu4 = h
90 | h = self.slice5(h)
91 | h_relu5 = h
92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
94 |
95 | return out
96 |
97 | class vgg16(torch.nn.Module):
98 | def __init__(self, requires_grad=False, pretrained=True):
99 | super(vgg16, self).__init__()
100 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
101 | self.slice1 = torch.nn.Sequential()
102 | self.slice2 = torch.nn.Sequential()
103 | self.slice3 = torch.nn.Sequential()
104 | self.slice4 = torch.nn.Sequential()
105 | self.slice5 = torch.nn.Sequential()
106 | self.N_slices = 5
107 | for x in range(4):
108 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
109 | for x in range(4, 9):
110 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
111 | for x in range(9, 16):
112 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
113 | for x in range(16, 23):
114 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
115 | for x in range(23, 30):
116 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
117 | if not requires_grad:
118 | for param in self.parameters():
119 | param.requires_grad = False
120 |
121 | def forward(self, X):
122 | h = self.slice1(X)
123 | h_relu1_2 = h
124 | h = self.slice2(h)
125 | h_relu2_2 = h
126 | h = self.slice3(h)
127 | h_relu3_3 = h
128 | h = self.slice4(h)
129 | h_relu4_3 = h
130 | h = self.slice5(h)
131 | h_relu5_3 = h
132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
134 |
135 | return out
136 |
137 |
138 |
139 | class resnet(torch.nn.Module):
140 | def __init__(self, requires_grad=False, pretrained=True, num=18):
141 | super(resnet, self).__init__()
142 | if(num==18):
143 | self.net = tv.resnet18(pretrained=pretrained)
144 | elif(num==34):
145 | self.net = tv.resnet34(pretrained=pretrained)
146 | elif(num==50):
147 | self.net = tv.resnet50(pretrained=pretrained)
148 | elif(num==101):
149 | self.net = tv.resnet101(pretrained=pretrained)
150 | elif(num==152):
151 | self.net = tv.resnet152(pretrained=pretrained)
152 | self.N_slices = 5
153 |
154 | self.conv1 = self.net.conv1
155 | self.bn1 = self.net.bn1
156 | self.relu = self.net.relu
157 | self.maxpool = self.net.maxpool
158 | self.layer1 = self.net.layer1
159 | self.layer2 = self.net.layer2
160 | self.layer3 = self.net.layer3
161 | self.layer4 = self.net.layer4
162 |
163 | def forward(self, X):
164 | h = self.conv1(X)
165 | h = self.bn1(h)
166 | h = self.relu(h)
167 | h_relu1 = h
168 | h = self.maxpool(h)
169 | h = self.layer1(h)
170 | h_conv2 = h
171 | h = self.layer2(h)
172 | h_conv3 = h
173 | h = self.layer3(h)
174 | h_conv4 = h
175 | h = self.layer4(h)
176 | h_conv5 = h
177 |
178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
180 |
181 | return out
182 |
--------------------------------------------------------------------------------
/StyleGAN2/lpips/weights/v0.0/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntellicentAI-Lab/Re-GAN/785135c820777149e8a4da941ee369b7d994f951/StyleGAN2/lpips/weights/v0.0/alex.pth
--------------------------------------------------------------------------------
/StyleGAN2/lpips/weights/v0.0/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntellicentAI-Lab/Re-GAN/785135c820777149e8a4da941ee369b7d994f951/StyleGAN2/lpips/weights/v0.0/squeeze.pth
--------------------------------------------------------------------------------
/StyleGAN2/lpips/weights/v0.0/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntellicentAI-Lab/Re-GAN/785135c820777149e8a4da941ee369b7d994f951/StyleGAN2/lpips/weights/v0.0/vgg.pth
--------------------------------------------------------------------------------
/StyleGAN2/lpips/weights/v0.1/alex.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntellicentAI-Lab/Re-GAN/785135c820777149e8a4da941ee369b7d994f951/StyleGAN2/lpips/weights/v0.1/alex.pth
--------------------------------------------------------------------------------
/StyleGAN2/lpips/weights/v0.1/squeeze.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntellicentAI-Lab/Re-GAN/785135c820777149e8a4da941ee369b7d994f951/StyleGAN2/lpips/weights/v0.1/squeeze.pth
--------------------------------------------------------------------------------
/StyleGAN2/lpips/weights/v0.1/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntellicentAI-Lab/Re-GAN/785135c820777149e8a4da941ee369b7d994f951/StyleGAN2/lpips/weights/v0.1/vgg.pth
--------------------------------------------------------------------------------
/StyleGAN2/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/StyleGAN2/op/conv2d_gradfix.py:
--------------------------------------------------------------------------------
2 | import contextlib
3 | import warnings
4 | import torch
5 | from torch import autograd
6 | from torch.nn import functional as F
7 | enabled = True
8 | weight_gradients_disabled = False
9 |
10 |
11 | @contextlib.contextmanager
12 | def no_weight_gradients():
13 | global weight_gradients_disabled
14 |
15 | old = weight_gradients_disabled
16 | weight_gradients_disabled = True
17 | yield
18 | weight_gradients_disabled = old
19 |
20 |
21 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
22 | if could_use_op(input):
23 | return conv2d_gradfix(
24 | transpose=False,
25 | weight_shape=weight.shape,
26 | stride=stride,
27 | padding=padding,
28 | output_padding=0,
29 | dilation=dilation,
30 | groups=groups,
31 | ).apply(input, weight, bias)
32 |
33 | return F.conv2d(
34 | input=input,
35 | weight=weight,
36 | bias=bias,
37 | stride=stride,
38 | padding=padding,
39 | dilation=dilation,
40 | groups=groups,
41 | )
42 |
43 |
44 | def conv_transpose2d(
45 | input,
46 | weight,
47 | bias=None,
48 | stride=1,
49 | padding=0,
50 | output_padding=0,
51 | groups=1,
52 | dilation=1,
53 | ):
54 | if could_use_op(input):
55 | return conv2d_gradfix(
56 | transpose=True,
57 | weight_shape=weight.shape,
58 | stride=stride,
59 | padding=padding,
60 | output_padding=output_padding,
61 | groups=groups,
62 | dilation=dilation,
63 | ).apply(input, weight, bias)
64 |
65 | return F.conv_transpose2d(
66 | input=input,
67 | weight=weight,
68 | bias=bias,
69 | stride=stride,
70 | padding=padding,
71 | output_padding=output_padding,
72 | dilation=dilation,
73 | groups=groups,
74 | )
75 |
76 |
77 | def could_use_op(input):
78 | if (not enabled) or (not torch.backends.cudnn.enabled):
79 | return False
80 |
81 | if input.device.type != "cuda":
82 | return False
83 |
84 | if any(torch.__version__.startswith(x) for x in ["1.7.", "1.8."]):
85 | return True
86 |
87 | warnings.warn(
88 | f"conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d()."
89 | )
90 |
91 | return False
92 |
93 |
94 | def ensure_tuple(xs, ndim):
95 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim
96 |
97 | return xs
98 |
99 |
100 | conv2d_gradfix_cache = dict()
101 |
102 |
103 | def conv2d_gradfix(
104 | transpose, weight_shape, stride, padding, output_padding, dilation, groups
105 | ):
106 | ndim = 2
107 | weight_shape = tuple(weight_shape)
108 | stride = ensure_tuple(stride, ndim)
109 | padding = ensure_tuple(padding, ndim)
110 | output_padding = ensure_tuple(output_padding, ndim)
111 | dilation = ensure_tuple(dilation, ndim)
112 |
113 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
114 | if key in conv2d_gradfix_cache:
115 | return conv2d_gradfix_cache[key]
116 |
117 | common_kwargs = dict(
118 | stride=stride, padding=padding, dilation=dilation, groups=groups
119 | )
120 |
121 | def calc_output_padding(input_shape, output_shape):
122 | if transpose:
123 | return [0, 0]
124 |
125 | return [
126 | input_shape[i + 2]
127 | - (output_shape[i + 2] - 1) * stride[i]
128 | - (1 - 2 * padding[i])
129 | - dilation[i] * (weight_shape[i + 2] - 1)
130 | for i in range(ndim)
131 | ]
132 |
133 | class Conv2d(autograd.Function):
134 | @staticmethod
135 | def forward(ctx, input, weight, bias):
136 | if not transpose:
137 | out = F.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
138 |
139 | else:
140 | out = F.conv_transpose2d(
141 | input=input,
142 | weight=weight,
143 | bias=bias,
144 | output_padding=output_padding,
145 | **common_kwargs,
146 | )
147 |
148 | ctx.save_for_backward(input, weight)
149 |
150 | return out
151 |
152 | @staticmethod
153 | def backward(ctx, grad_output):
154 | input, weight = ctx.saved_tensors
155 | grad_input, grad_weight, grad_bias = None, None, None
156 |
157 | if ctx.needs_input_grad[0]:
158 | p = calc_output_padding(
159 | input_shape=input.shape, output_shape=grad_output.shape
160 | )
161 | grad_input = conv2d_gradfix(
162 | transpose=(not transpose),
163 | weight_shape=weight_shape,
164 | output_padding=p,
165 | **common_kwargs,
166 | ).apply(grad_output, weight, None)
167 |
168 | if ctx.needs_input_grad[1] and not weight_gradients_disabled:
169 | grad_weight = Conv2dGradWeight.apply(grad_output, input)
170 |
171 | if ctx.needs_input_grad[2]:
172 | grad_bias = grad_output.sum((0, 2, 3))
173 |
174 | return grad_input, grad_weight, grad_bias
175 |
176 | class Conv2dGradWeight(autograd.Function):
177 | @staticmethod
178 | def forward(ctx, grad_output, input):
179 | op = torch._C._jit_get_operation(
180 | "aten::cudnn_convolution_backward_weight"
181 | if not transpose
182 | else "aten::cudnn_convolution_transpose_backward_weight"
183 | )
184 | flags = [
185 | torch.backends.cudnn.benchmark,
186 | torch.backends.cudnn.deterministic,
187 | torch.backends.cudnn.allow_tf32,
188 | ]
189 | grad_weight = op(
190 | weight_shape,
191 | grad_output,
192 | input,
193 | padding,
194 | stride,
195 | dilation,
196 | groups,
197 | *flags,
198 | )
199 | ctx.save_for_backward(grad_output, input)
200 |
201 | return grad_weight
202 |
203 | @staticmethod
204 | def backward(ctx, grad_grad_weight):
205 | grad_output, input = ctx.saved_tensors
206 | grad_grad_output, grad_grad_input = None, None
207 |
208 | if ctx.needs_input_grad[0]:
209 | grad_grad_output = Conv2d.apply(input, grad_grad_weight, None)
210 |
211 | if ctx.needs_input_grad[1]:
212 | p = calc_output_padding(
213 | input_shape=input.shape, output_shape=grad_output.shape
214 | )
215 | grad_grad_input = conv2d_gradfix(
216 | transpose=(not transpose),
217 | weight_shape=weight_shape,
218 | output_padding=p,
219 | **common_kwargs,
220 | ).apply(grad_output, grad_grad_weight, None)
221 |
222 | return grad_grad_output, grad_grad_input
223 |
224 | conv2d_gradfix_cache[key] = Conv2d
225 |
226 | return Conv2d
227 |
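
# A minimal sketch of the intended use, assuming a CUDA device: StyleGAN2's
# path-length regularization needs gradients w.r.t. activations but not conv
# weights, so that forward pass runs under no_weight_gradients(), letting the
# custom op skip the weight-gradient branch in backward().
def demo_no_weight_gradients():
    w = torch.randn(8, 4, 3, 3, device='cuda', requires_grad=True)
    x = torch.randn(2, 4, 16, 16, device='cuda', requires_grad=True)
    with no_weight_gradients():
        y = conv2d(x, w, padding=1)
    (grad_x,) = autograd.grad(y.sum(), x)  # the input gradient still flows
    return grad_x.shape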
--------------------------------------------------------------------------------
/StyleGAN2/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.autograd import Function
6 | # from . import fused
7 | from torch.utils.cpp_extension import load
8 |
9 |
10 |
11 | module_path = os.path.dirname(__file__)
12 |
13 | fused = load(
14 | "fused",
15 | sources=[
16 | os.path.join(module_path, "fused_bias_act.cpp"),
17 | os.path.join(module_path, "fused_bias_act_kernel.cu"),
18 | ],
19 | )
21 |
22 |
23 | class FusedLeakyReLUFunctionBackward(Function):
24 | @staticmethod
25 | def forward(ctx, grad_output, out, bias, negative_slope, scale):
26 | ctx.save_for_backward(out)
27 | ctx.negative_slope = negative_slope
28 | ctx.scale = scale
29 |
30 | empty = grad_output.new_empty(0)
31 |
32 | grad_input = fused.fused_bias_act(
33 | grad_output.contiguous(), empty, out, 3, 1, negative_slope, scale
34 | )
35 |
36 | dim = [0]
37 |
38 | if grad_input.ndim > 2:
39 | dim += list(range(2, grad_input.ndim))
40 |
41 | if bias:
42 | grad_bias = grad_input.sum(dim).detach()
43 |
44 | else:
45 | grad_bias = empty
46 |
47 | return grad_input, grad_bias
48 |
49 | @staticmethod
50 | def backward(ctx, gradgrad_input, gradgrad_bias):
51 | out, = ctx.saved_tensors
52 | gradgrad_out = fused.fused_bias_act(
53 | gradgrad_input.contiguous(),
54 | gradgrad_bias,
55 | out,
56 | 3,
57 | 1,
58 | ctx.negative_slope,
59 | ctx.scale,
60 | )
61 |
62 | return gradgrad_out, None, None, None, None
63 |
64 |
65 | class FusedLeakyReLUFunction(Function):
66 | @staticmethod
67 | def forward(ctx, input, bias, negative_slope, scale):
68 | empty = input.new_empty(0)
69 |
70 | ctx.bias = bias is not None
71 |
72 | if bias is None:
73 | bias = empty
74 |
75 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
76 | ctx.save_for_backward(out)
77 | ctx.negative_slope = negative_slope
78 | ctx.scale = scale
79 |
80 | return out
81 |
82 | @staticmethod
83 | def backward(ctx, grad_output):
84 | out, = ctx.saved_tensors
85 |
86 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
87 | grad_output, out, ctx.bias, ctx.negative_slope, ctx.scale
88 | )
89 |
90 | if not ctx.bias:
91 | grad_bias = None
92 |
93 | return grad_input, grad_bias, None, None
94 |
95 |
96 | class FusedLeakyReLU(nn.Module):
97 | def __init__(self, channel, bias=True, negative_slope=0.2, scale=2 ** 0.5):
98 | super().__init__()
99 |
100 | if bias:
101 | self.bias = nn.Parameter(torch.zeros(channel))
102 |
103 | else:
104 | self.bias = None
105 |
106 | self.negative_slope = negative_slope
107 | self.scale = scale
108 |
109 | def forward(self, input):
110 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
111 |
112 |
113 | def fused_leaky_relu(input, bias=None, negative_slope=0.2, scale=2 ** 0.5):
114 | if input.device.type == "cpu":
115 | if bias is not None:
116 | rest_dim = [1] * (input.ndim - bias.ndim - 1)
117 | return (
118 | F.leaky_relu(
119 |                 input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
120 | )
121 | * scale
122 | )
123 |
124 | else:
125 |             return F.leaky_relu(input, negative_slope=negative_slope) * scale
126 |
127 | else:
128 | return FusedLeakyReLUFunction.apply(
129 | input.contiguous(), bias, negative_slope, scale
130 | )
131 |
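# A minimal sketch: fused_leaky_relu folds the bias add, the leaky ReLU, and
# the sqrt(2) rescale into one op; on CPU the pure-PyTorch fallback above runs.
def demo_fused_leaky_relu():
    x = torch.randn(2, 8, 4, 4)
    bias = torch.zeros(8)
    out = fused_leaky_relu(x, bias)  # leaky_relu(x + bias, 0.2) * sqrt(2)
    return out.shape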
--------------------------------------------------------------------------------
/StyleGAN2/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 |
2 | #include <torch/extension.h>
3 | #include <ATen/ATen.h>
4 |
5 | torch::Tensor fused_bias_act_op(const torch::Tensor &input,
6 | const torch::Tensor &bias,
7 | const torch::Tensor &refer, int act, int grad,
8 | float alpha, float scale);
9 |
10 | #define CHECK_CUDA(x) \
11 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
12 | #define CHECK_CONTIGUOUS(x) \
13 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
14 | #define CHECK_INPUT(x) \
15 | CHECK_CUDA(x); \
16 | CHECK_CONTIGUOUS(x)
17 |
18 | torch::Tensor fused_bias_act(const torch::Tensor &input,
19 | const torch::Tensor &bias,
20 | const torch::Tensor &refer, int act, int grad,
21 | float alpha, float scale) {
22 | CHECK_INPUT(input);
23 | CHECK_INPUT(bias);
24 |
25 | at::DeviceGuard guard(input.device());
26 |
27 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
28 | }
29 |
30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
32 | }
--------------------------------------------------------------------------------
/StyleGAN2/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 |
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 |
14 |
15 | #include <cuda.h>
16 | #include <cuda_runtime.h>
17 |
18 | template <typename scalar_t>
19 | static __global__ void
20 | fused_bias_act_kernel(scalar_t *out, const scalar_t *p_x, const scalar_t *p_b,
21 | const scalar_t *p_ref, int act, int grad, scalar_t alpha,
22 | scalar_t scale, int loop_x, int size_x, int step_b,
23 | int size_b, int use_bias, int use_ref) {
24 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
25 |
26 | scalar_t zero = 0.0;
27 |
28 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x;
29 | loop_idx++, xi += blockDim.x) {
30 | scalar_t x = p_x[xi];
31 |
32 | if (use_bias) {
33 | x += p_b[(xi / step_b) % size_b];
34 | }
35 |
36 | scalar_t ref = use_ref ? p_ref[xi] : zero;
37 |
38 | scalar_t y;
39 |
40 | switch (act * 10 + grad) {
41 | default:
42 | case 10:
43 | y = x;
44 | break;
45 | case 11:
46 | y = x;
47 | break;
48 | case 12:
49 | y = 0.0;
50 | break;
51 |
52 | case 30:
53 | y = (x > 0.0) ? x : x * alpha;
54 | break;
55 | case 31:
56 | y = (ref > 0.0) ? x : x * alpha;
57 | break;
58 | case 32:
59 | y = 0.0;
60 | break;
61 | }
62 |
63 | out[xi] = y * scale;
64 | }
65 | }
66 |
67 | torch::Tensor fused_bias_act_op(const torch::Tensor &input,
68 | const torch::Tensor &bias,
69 | const torch::Tensor &refer, int act, int grad,
70 | float alpha, float scale) {
71 | int curDevice = -1;
72 | cudaGetDevice(&curDevice);
73 | cudaStream_t stream = at::cuda::getCurrentCUDAStream();
74 |
75 | auto x = input.contiguous();
76 | auto b = bias.contiguous();
77 | auto ref = refer.contiguous();
78 |
79 | int use_bias = b.numel() ? 1 : 0;
80 | int use_ref = ref.numel() ? 1 : 0;
81 |
82 | int size_x = x.numel();
83 | int size_b = b.numel();
84 | int step_b = 1;
85 |
86 | for (int i = 1 + 1; i < x.dim(); i++) {
87 | step_b *= x.size(i);
88 | }
89 |
90 | int loop_x = 4;
91 | int block_size = 4 * 32;
92 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
93 |
94 | auto y = torch::empty_like(x);
95 |
96 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(
97 | x.scalar_type(), "fused_bias_act_kernel", [&] {
98 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
99 |             y.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(),
100 |             b.data_ptr<scalar_t>(), ref.data_ptr<scalar_t>(), act, grad, alpha,
101 | scale, loop_x, size_x, step_b, size_b, use_bias, use_ref);
102 | });
103 |
104 | return y;
105 | }
--------------------------------------------------------------------------------
/StyleGAN2/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include <ATen/ATen.h>
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor &input,
5 | const torch::Tensor &kernel, int up_x, int up_y,
6 | int down_x, int down_y, int pad_x0, int pad_x1,
7 | int pad_y0, int pad_y1);
8 |
9 | #define CHECK_CUDA(x) \
10 | TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
11 | #define CHECK_CONTIGUOUS(x) \
12 | TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
13 | #define CHECK_INPUT(x) \
14 | CHECK_CUDA(x); \
15 | CHECK_CONTIGUOUS(x)
16 |
17 | torch::Tensor upfirdn2d(const torch::Tensor &input, const torch::Tensor &kernel,
18 | int up_x, int up_y, int down_x, int down_y, int pad_x0,
19 | int pad_x1, int pad_y0, int pad_y1) {
20 | CHECK_INPUT(input);
21 | CHECK_INPUT(kernel);
22 |
23 | at::DeviceGuard guard(input.device());
24 |
25 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
26 | pad_y0, pad_y1);
27 | }
28 |
29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
30 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
31 | }
--------------------------------------------------------------------------------
/StyleGAN2/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | from collections import abc
2 | import os
3 |
4 | import torch
5 | from torch.nn import functional as F
6 | from torch.autograd import Function
7 | # from . import upfirdn2d as upfirdn2d_op
8 | from torch.utils.cpp_extension import load
9 |
10 |
11 | module_path = os.path.dirname(__file__)
12 | upfirdn2d_op = load(
13 | "upfirdn2d",
14 | sources=[
15 | os.path.join(module_path, "upfirdn2d.cpp"),
16 | os.path.join(module_path, "upfirdn2d_kernel.cu"),
17 | ],
18 | )
19 |
20 |
21 | class UpFirDn2dBackward(Function):
22 | @staticmethod
23 | def forward(
24 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
25 | ):
26 |
27 | up_x, up_y = up
28 | down_x, down_y = down
29 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
30 |
31 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
32 |
33 | grad_input = upfirdn2d_op.upfirdn2d(
34 | grad_output,
35 | grad_kernel,
36 | down_x,
37 | down_y,
38 | up_x,
39 | up_y,
40 | g_pad_x0,
41 | g_pad_x1,
42 | g_pad_y0,
43 | g_pad_y1,
44 | )
45 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
46 |
47 | ctx.save_for_backward(kernel)
48 |
49 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
50 |
51 | ctx.up_x = up_x
52 | ctx.up_y = up_y
53 | ctx.down_x = down_x
54 | ctx.down_y = down_y
55 | ctx.pad_x0 = pad_x0
56 | ctx.pad_x1 = pad_x1
57 | ctx.pad_y0 = pad_y0
58 | ctx.pad_y1 = pad_y1
59 | ctx.in_size = in_size
60 | ctx.out_size = out_size
61 |
62 | return grad_input
63 |
64 | @staticmethod
65 | def backward(ctx, gradgrad_input):
66 | kernel, = ctx.saved_tensors
67 |
68 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
69 |
70 | gradgrad_out = upfirdn2d_op.upfirdn2d(
71 | gradgrad_input,
72 | kernel,
73 | ctx.up_x,
74 | ctx.up_y,
75 | ctx.down_x,
76 | ctx.down_y,
77 | ctx.pad_x0,
78 | ctx.pad_x1,
79 | ctx.pad_y0,
80 | ctx.pad_y1,
81 | )
82 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
83 | gradgrad_out = gradgrad_out.view(
84 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
85 | )
86 |
87 | return gradgrad_out, None, None, None, None, None, None, None, None
88 |
89 |
90 | class UpFirDn2d(Function):
91 | @staticmethod
92 | def forward(ctx, input, kernel, up, down, pad):
93 | up_x, up_y = up
94 | down_x, down_y = down
95 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
96 |
97 | kernel_h, kernel_w = kernel.shape
98 | batch, channel, in_h, in_w = input.shape
99 | ctx.in_size = input.shape
100 |
101 | input = input.reshape(-1, in_h, in_w, 1)
102 |
103 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
104 |
105 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
106 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
107 | ctx.out_size = (out_h, out_w)
108 |
109 | ctx.up = (up_x, up_y)
110 | ctx.down = (down_x, down_y)
111 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
112 |
113 | g_pad_x0 = kernel_w - pad_x0 - 1
114 | g_pad_y0 = kernel_h - pad_y0 - 1
115 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
116 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
117 |
118 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
119 |
120 | out = upfirdn2d_op.upfirdn2d(
121 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
122 | )
123 | # out = out.view(major, out_h, out_w, minor)
124 | out = out.view(-1, channel, out_h, out_w)
125 |
126 | return out
127 |
128 | @staticmethod
129 | def backward(ctx, grad_output):
130 | kernel, grad_kernel = ctx.saved_tensors
131 |
132 | grad_input = None
133 |
134 | if ctx.needs_input_grad[0]:
135 | grad_input = UpFirDn2dBackward.apply(
136 | grad_output,
137 | kernel,
138 | grad_kernel,
139 | ctx.up,
140 | ctx.down,
141 | ctx.pad,
142 | ctx.g_pad,
143 | ctx.in_size,
144 | ctx.out_size,
145 | )
146 |
147 | return grad_input, None, None, None, None
148 |
149 |
150 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
151 | if not isinstance(up, abc.Iterable):
152 | up = (up, up)
153 |
154 | if not isinstance(down, abc.Iterable):
155 | down = (down, down)
156 |
157 | if len(pad) == 2:
158 | pad = (pad[0], pad[1], pad[0], pad[1])
159 |
160 | if input.device.type == "cpu":
161 | out = upfirdn2d_native(input, kernel, *up, *down, *pad)
162 |
163 | else:
164 | out = UpFirDn2d.apply(input, kernel, up, down, pad)
165 |
166 | return out
167 |
168 |
169 | def upfirdn2d_native(
170 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
171 | ):
172 | _, channel, in_h, in_w = input.shape
173 | input = input.reshape(-1, in_h, in_w, 1)
174 |
175 | _, in_h, in_w, minor = input.shape
176 | kernel_h, kernel_w = kernel.shape
177 |
178 | out = input.view(-1, in_h, 1, in_w, 1, minor)
179 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
180 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
181 |
182 | out = F.pad(
183 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
184 | )
185 | out = out[
186 | :,
187 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
188 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
189 | :,
190 | ]
191 |
192 | out = out.permute(0, 3, 1, 2)
193 | out = out.reshape(
194 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
195 | )
196 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
197 | out = F.conv2d(out, w)
198 | out = out.reshape(
199 | -1,
200 | minor,
201 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
202 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
203 | )
204 | out = out.permute(0, 2, 3, 1)
205 | out = out[:, ::down_y, ::down_x, :]
206 |
207 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
208 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
209 |
210 | return out.view(-1, channel, out_h, out_w)
211 |
--------------------------------------------------------------------------------
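Editor's note: the closed-form output size in UpFirDn2d.forward above can be sanity-checked against a plain PyTorch pipeline of zero-insertion upsampling, padding, FIR filtering, and strided downsampling, the same steps upfirdn2d_native performs. The sketch below is illustrative only and not part of the repository; all names are local to the example.

import torch
from torch.nn import functional as F

in_h = in_w = 8
up, down = 2, 2
pad0, pad1 = 2, 1
k = 4  # kernel side length

x = torch.randn(1, 1, in_h, in_w)

# upsample by zero insertion
x_up = torch.zeros(1, 1, in_h * up, in_w * up)
x_up[:, :, ::up, ::up] = x

# pad, filter, then downsample
x_pad = F.pad(x_up, [pad0, pad1, pad0, pad1])
x_filt = F.conv2d(x_pad, torch.ones(1, 1, k, k) / k ** 2)
out = x_filt[:, :, ::down, ::down]

# matches the closed form used in UpFirDn2d.forward
out_h = (in_h * up + pad0 + pad1 - k + down) // down
assert out.shape[-2:] == (out_h, out_h)

--------------------------------------------------------------------------------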
/StyleGAN2/ppl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from torch.nn import functional as F
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | import lpips
9 | from model import Generator
10 |
11 |
12 | def normalize(x):
13 | return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True))
14 |
15 |
16 | def slerp(a, b, t):
17 | a = normalize(a)
18 | b = normalize(b)
19 | d = (a * b).sum(-1, keepdim=True)
20 | p = t * torch.acos(d)
21 | c = normalize(b - d * a)
22 | d = a * torch.cos(p) + c * torch.sin(p)
23 |
24 | return normalize(d)
25 |
26 |
27 | def lerp(a, b, t):
28 | return a + (b - a) * t
29 |
30 |
31 | if __name__ == "__main__":
32 | device = "cuda"
33 |
34 | parser = argparse.ArgumentParser(description="Perceptual Path Length calculator")
35 |
36 |     parser.add_argument(
37 |         "--space", choices=["z", "w"], default="w", help="latent space in which PPL is computed"
38 |     )
39 | parser.add_argument(
40 | "--batch", type=int, default=64, help="batch size for the models"
41 | )
42 | parser.add_argument(
43 | "--n_sample",
44 | type=int,
45 | default=5000,
46 | help="number of the samples for calculating PPL",
47 | )
48 | parser.add_argument(
49 | "--size", type=int, default=256, help="output image sizes of the generator"
50 | )
51 | parser.add_argument(
52 | "--eps", type=float, default=1e-4, help="epsilon for numerical stability"
53 | )
54 | parser.add_argument(
55 | "--crop", action="store_true", help="apply center crop to the images"
56 | )
57 | parser.add_argument(
58 | "--sampling",
59 | default="end",
60 | choices=["end", "full"],
61 | help="set endpoint sampling method",
62 | )
63 | parser.add_argument(
64 | "ckpt", metavar="CHECKPOINT", help="path to the model checkpoints"
65 | )
66 |
67 | args = parser.parse_args()
68 |
69 | latent_dim = 512
70 |
71 | ckpt = torch.load(args.ckpt)
72 |
73 | g = Generator(args.size, latent_dim, 8).to(device)
74 | g.load_state_dict(ckpt["g_ema"])
75 | g.eval()
76 |
77 | percept = lpips.PerceptualLoss(
78 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
79 | )
80 |
81 | distances = []
82 |
83 | n_batch = args.n_sample // args.batch
84 | resid = args.n_sample - (n_batch * args.batch)
85 |     batch_sizes = [args.batch] * n_batch + ([resid] if resid > 0 else [])  # skip an empty final batch
86 |
87 | with torch.no_grad():
88 | for batch in tqdm(batch_sizes):
89 | noise = g.make_noise()
90 |
91 | inputs = torch.randn([batch * 2, latent_dim], device=device)
92 | if args.sampling == "full":
93 | lerp_t = torch.rand(batch, device=device)
94 | else:
95 | lerp_t = torch.zeros(batch, device=device)
96 |
97 |             if args.space == "w":
98 |                 latent = g.get_latent(inputs)
99 |                 latent_t0, latent_t1 = latent[::2], latent[1::2]
100 |                 latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None])
101 |                 latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
102 |                 latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape)
103 |             else:
104 |                 # PPL in z space: spherical interpolation of the Gaussian latents
105 |                 latent_t0, latent_t1 = inputs[::2], inputs[1::2]
106 |                 latent_e0 = slerp(latent_t0, latent_t1, lerp_t[:, None])
107 |                 latent_e1 = slerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps)
108 |                 latent_e = torch.stack([latent_e0, latent_e1], 1).view(*inputs.shape)
109 |
110 |             image, _ = g([latent_e], input_is_latent=args.space == "w", noise=noise)
111 |
112 |             if args.crop:
113 |                 c = image.shape[2] // 8
114 |                 image = image[:, :, c * 3 : c * 7, c * 2 : c * 6]
115 |
116 |             factor = image.shape[2] // 256
117 |
118 |             if factor > 1:
119 |                 image = F.interpolate(
120 |                     image, size=(256, 256), mode="bilinear", align_corners=False
121 |                 )
122 |
123 |             dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (
124 |                 args.eps ** 2
125 |             )
126 |             distances.append(dist.to("cpu").numpy())
127 |
128 |     distances = np.concatenate(distances, 0)
129 |
130 |     # NumPy >= 1.22 renames the `interpolation` argument to `method`
131 |     lo = np.percentile(distances, 1, interpolation="lower")
132 |     hi = np.percentile(distances, 99, interpolation="higher")
133 |     filtered_dist = np.extract(
134 |         np.logical_and(lo <= distances, distances <= hi), distances
135 |     )
136 |
137 |     print("ppl:", filtered_dist.mean())
138 |
--------------------------------------------------------------------------------
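Editor's note: the interleaved endpoint layout built in the loop above puts the t samples on even rows and the t + eps samples on odd rows, which is why image[::2] and image[1::2] pair up in the LPIPS call. A small standalone shape check, not part of the repository:

import torch

batch, latent_dim, eps = 4, 512, 1e-4
inputs = torch.randn(batch * 2, latent_dim)
lerp_t = torch.rand(batch)

t0, t1 = inputs[::2], inputs[1::2]
e0 = t0 + (t1 - t0) * lerp_t[:, None]          # endpoints at t
e1 = t0 + (t1 - t0) * (lerp_t[:, None] + eps)  # endpoints at t + eps
latent_e = torch.stack([e0, e1], 1).view(batch * 2, latent_dim)

# after generation, rows 0,2,4,... and 1,3,5,... are the eps-separated pairs
assert torch.equal(latent_e[::2], e0) and torch.equal(latent_e[1::2], e1)

--------------------------------------------------------------------------------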
/StyleGAN2/prepare_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from io import BytesIO
3 | import multiprocessing
4 | from functools import partial
5 |
6 | from PIL import Image
7 | import lmdb
8 | from tqdm import tqdm
9 | from torchvision import datasets
10 | from torchvision.transforms import functional as trans_fn
11 |
12 |
13 | def resize_and_convert(img, size, resample, quality=100):
14 | img = trans_fn.resize(img, size, resample)
15 | img = trans_fn.center_crop(img, size)
16 | buffer = BytesIO()
17 | img.save(buffer, format="jpeg", quality=quality)
18 | val = buffer.getvalue()
19 |
20 | return val
21 |
22 |
23 | def resize_multiple(
24 | img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100
25 | ):
26 | imgs = []
27 |
28 | for size in sizes:
29 | imgs.append(resize_and_convert(img, size, resample, quality))
30 |
31 | return imgs
32 |
33 |
34 | def resize_worker(img_file, sizes, resample):
35 | i, file = img_file
36 | img = Image.open(file)
37 | img = img.convert("RGB")
38 | out = resize_multiple(img, sizes=sizes, resample=resample)
39 |
40 | return i, out
41 |
42 |
43 | def prepare(
44 | env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS
45 | ):
46 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
47 |
48 | files = sorted(dataset.imgs, key=lambda x: x[0])
49 | files = [(i, file) for i, (file, label) in enumerate(files)]
50 | total = 0
51 |
52 | with multiprocessing.Pool(n_worker) as pool:
53 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
54 | for size, img in zip(sizes, imgs):
55 | key = f"{size}-{str(i).zfill(5)}".encode("utf-8")
56 |
57 | with env.begin(write=True) as txn:
58 | txn.put(key, img)
59 |
60 | total += 1
61 |
62 | with env.begin(write=True) as txn:
63 | txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
64 |
65 |
66 | if __name__ == "__main__":
67 | parser = argparse.ArgumentParser(description="Preprocess images for model training")
68 | parser.add_argument("--out", type=str, help="filename of the result lmdb dataset")
69 | parser.add_argument(
70 | "--size",
71 | type=str,
72 | default="128,256,512,1024",
73 | help="resolutions of images for the dataset",
74 | )
75 | parser.add_argument(
76 | "--n_worker",
77 | type=int,
78 | default=8,
79 | help="number of workers for preparing dataset",
80 | )
81 | parser.add_argument(
82 | "--resample",
83 | type=str,
84 | default="lanczos",
85 | help="resampling methods for resizing images",
86 | )
87 | parser.add_argument("path", type=str, help="path to the image dataset")
88 |
89 | args = parser.parse_args()
90 |
91 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
92 | resample = resample_map[args.resample]
93 |
94 | sizes = [int(s.strip()) for s in args.size.split(",")]
95 |
96 | print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
97 |
98 | imgset = datasets.ImageFolder(args.path)
99 |
100 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env:
101 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample)
102 |
--------------------------------------------------------------------------------
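Editor's note: a minimal sketch, not part of the repository, of reading one entry back from the LMDB that prepare() writes. Keys follow the "{size}-{index:05d}" pattern used above; "out_lmdb" is a placeholder path.

from io import BytesIO

import lmdb
from PIL import Image

with lmdb.open("out_lmdb", readonly=True, lock=False) as env:
    with env.begin(write=False) as txn:
        length = int(txn.get("length".encode("utf-8")).decode("utf-8"))
        key = f"256-{str(0).zfill(5)}".encode("utf-8")  # first image at size 256
        img = Image.open(BytesIO(txn.get(key)))

print(length, img.size)

--------------------------------------------------------------------------------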
/StyleGAN2/projector.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import os
4 |
5 | import torch
6 | from torch import optim
7 | from torch.nn import functional as F
8 | from torchvision import transforms
9 | from PIL import Image
10 | from tqdm import tqdm
11 |
12 | import lpips
13 | from model import Generator
14 |
15 |
16 | def noise_regularize(noises):
17 | loss = 0
18 |
19 | for noise in noises:
20 | size = noise.shape[2]
21 |
22 | while True:
23 | loss = (
24 | loss
25 | + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2)
26 | + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2)
27 | )
28 |
29 | if size <= 8:
30 | break
31 |
32 | noise = noise.reshape([-1, 1, size // 2, 2, size // 2, 2])
33 | noise = noise.mean([3, 5])
34 | size //= 2
35 |
36 | return loss
37 |
38 |
39 | def noise_normalize_(noises):
40 | for noise in noises:
41 | mean = noise.mean()
42 | std = noise.std()
43 |
44 | noise.data.add_(-mean).div_(std)
45 |
46 |
47 | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
48 | lr_ramp = min(1, (1 - t) / rampdown)
49 | lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
50 | lr_ramp = lr_ramp * min(1, t / rampup)
51 |
52 | return initial_lr * lr_ramp
53 |
54 |
55 | def latent_noise(latent, strength):
56 | noise = torch.randn_like(latent) * strength
57 |
58 | return latent + noise
59 |
60 |
61 | def make_image(tensor):
62 | return (
63 | tensor.detach()
64 | .clamp_(min=-1, max=1)
65 | .add(1)
66 | .div_(2)
67 | .mul(255)
68 | .type(torch.uint8)
69 | .permute(0, 2, 3, 1)
70 | .to("cpu")
71 | .numpy()
72 | )
73 |
74 |
75 | if __name__ == "__main__":
76 | device = "cuda"
77 |
78 | parser = argparse.ArgumentParser(
79 | description="Image projector to the generator latent spaces"
80 | )
81 | parser.add_argument(
82 | "--ckpt", type=str, required=True, help="path to the model checkpoint"
83 | )
84 | parser.add_argument(
85 | "--size", type=int, default=256, help="output image sizes of the generator"
86 | )
87 | parser.add_argument(
88 | "--lr_rampup",
89 | type=float,
90 | default=0.05,
91 | help="duration of the learning rate warmup",
92 | )
93 | parser.add_argument(
94 | "--lr_rampdown",
95 | type=float,
96 | default=0.25,
97 | help="duration of the learning rate decay",
98 | )
99 | parser.add_argument("--lr", type=float, default=0.1, help="learning rate")
100 | parser.add_argument(
101 | "--noise", type=float, default=0.05, help="strength of the noise level"
102 | )
103 | parser.add_argument(
104 | "--noise_ramp",
105 | type=float,
106 | default=0.75,
107 | help="duration of the noise level decay",
108 | )
109 | parser.add_argument("--step", type=int, default=1000, help="optimize iterations")
110 | parser.add_argument(
111 | "--noise_regularize",
112 | type=float,
113 | default=1e5,
114 | help="weight of the noise regularization",
115 | )
116 | parser.add_argument("--mse", type=float, default=0, help="weight of the mse loss")
117 | parser.add_argument(
118 | "--w_plus",
119 | action="store_true",
120 | help="allow to use distinct latent codes to each layers",
121 | )
122 | parser.add_argument(
123 | "files", metavar="FILES", nargs="+", help="path to image files to be projected"
124 | )
125 |
126 | args = parser.parse_args()
127 |
128 | n_mean_latent = 10000
129 |
130 | resize = min(args.size, 256)
131 |
132 | transform = transforms.Compose(
133 | [
134 | transforms.Resize(resize),
135 | transforms.CenterCrop(resize),
136 | transforms.ToTensor(),
137 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
138 | ]
139 | )
140 |
141 | imgs = []
142 |
143 | for imgfile in args.files:
144 | img = transform(Image.open(imgfile).convert("RGB"))
145 | imgs.append(img)
146 |
147 | imgs = torch.stack(imgs, 0).to(device)
148 |
149 | g_ema = Generator(args.size, 512, 8)
150 | g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False)
151 | g_ema.eval()
152 | g_ema = g_ema.to(device)
153 |
154 | with torch.no_grad():
155 | noise_sample = torch.randn(n_mean_latent, 512, device=device)
156 | latent_out = g_ema.style(noise_sample)
157 |
158 | latent_mean = latent_out.mean(0)
159 | latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5
160 |
161 | percept = lpips.PerceptualLoss(
162 | model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
163 | )
164 |
165 | noises_single = g_ema.make_noise()
166 | noises = []
167 | for noise in noises_single:
168 | noises.append(noise.repeat(imgs.shape[0], 1, 1, 1).normal_())
169 |
170 | latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(imgs.shape[0], 1)
171 |
172 | if args.w_plus:
173 | latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1)
174 |
175 | latent_in.requires_grad = True
176 |
177 | for noise in noises:
178 | noise.requires_grad = True
179 |
180 | optimizer = optim.Adam([latent_in] + noises, lr=args.lr)
181 |
182 | pbar = tqdm(range(args.step))
183 | latent_path = []
184 |
185 | for i in pbar:
186 | t = i / args.step
187 | lr = get_lr(t, args.lr)
188 | optimizer.param_groups[0]["lr"] = lr
189 | noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2
190 | latent_n = latent_noise(latent_in, noise_strength.item())
191 |
192 | img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises)
193 |
194 | batch, channel, height, width = img_gen.shape
195 |
196 | if height > 256:
197 | factor = height // 256
198 |
199 | img_gen = img_gen.reshape(
200 | batch, channel, height // factor, factor, width // factor, factor
201 | )
202 | img_gen = img_gen.mean([3, 5])
203 |
204 | p_loss = percept(img_gen, imgs).sum()
205 | n_loss = noise_regularize(noises)
206 | mse_loss = F.mse_loss(img_gen, imgs)
207 |
208 | loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss
209 |
210 | optimizer.zero_grad()
211 | loss.backward()
212 | optimizer.step()
213 |
214 | noise_normalize_(noises)
215 |
216 |         if (i + 1) % 100 == 0 or (i + 1) == args.step:  # always keep a final snapshot
217 | latent_path.append(latent_in.detach().clone())
218 |
219 | pbar.set_description(
220 | (
221 | f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
222 | f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
223 | )
224 | )
225 |
226 | img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises)
227 |
228 | filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt"
229 |
230 | img_ar = make_image(img_gen)
231 |
232 | result_file = {}
233 | for i, input_name in enumerate(args.files):
234 | noise_single = []
235 | for noise in noises:
236 | noise_single.append(noise[i : i + 1])
237 |
238 | result_file[input_name] = {
239 | "img": img_gen[i],
240 | "latent": latent_in[i],
241 | "noise": noise_single,
242 | }
243 |
244 | img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png"
245 | pil_img = Image.fromarray(img_ar[i])
246 | pil_img.save(img_name)
247 |
248 | torch.save(result_file, filename)
249 |
--------------------------------------------------------------------------------
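Editor's note: the reshape-and-mean in the optimization loop above is a block average-pooling downsample. A standalone equivalence check, illustrative only and not part of the repository:

import torch
from torch.nn import functional as F

img = torch.randn(2, 3, 512, 512)
factor = img.shape[2] // 256  # = 2, as in the loop above

# mean over the two factor axes is exactly average pooling with kernel = stride = factor
pooled = img.reshape(2, 3, 512 // factor, factor, 512 // factor, factor).mean([3, 5])
assert torch.allclose(pooled, F.avg_pool2d(img, factor), atol=1e-6)

--------------------------------------------------------------------------------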
/StyleGAN2/regan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.optim as optim
4 | import torch.nn as nn
5 |
6 |
7 | class Regan_training(nn.Module):
8 |
9 | def __init__(self, model, sparsity, train_on_sparse=False):
10 | super(Regan_training, self).__init__()
11 |
12 | self.model = model
13 | self.sparsity = sparsity
14 | self.train_on_sparse = train_on_sparse
15 | self.layers = []
16 | self.masks = []
17 |         # Collect the prunable parameters; the to_rgb layers are left
18 |         # dense so the RGB output convolutions are never pruned.
19 |         for name, w in self.model.named_parameters():
20 |             if "to_rgb" not in name:
21 |                 self.layers.append(w)
22 |
23 |         self.reset_masks()
24 |
25 | def get_latent(self, input):
26 | return self.model.style(input)
27 |
28 | def mean_latent(self, n_latent):
29 | latent_in = torch.randn(
30 | n_latent, self.model.style_dim, device=self.model.input.input.device
31 | )
32 | latent = self.model.style(latent_in).mean(0, keepdim=True)
33 |
34 | return latent
35 |
36 |     def reset_masks(self):
37 |         self.masks = []  # rebuild from scratch so repeated calls do not accumulate masks
38 |         for w in self.layers:
39 |             mask_w = torch.ones_like(w, dtype=torch.bool)
40 |             self.masks.append(mask_w)
41 |         return self.masks
42 |
43 | def update_masks(self):
44 |
45 | for i, w in enumerate(self.layers):
46 |             # mark the `sparsity` fraction of smallest-magnitude weights
47 |             q_w = torch.quantile(torch.abs(w), q=self.sparsity)
48 |             mask_w = torch.abs(w) < q_w
49 |             self.masks[i] = mask_w
50 |
51 | def turn_training_mode(self, mode):
52 | if mode == 'dense':
53 | self.train_on_sparse = False
54 | else:
55 | self.train_on_sparse = True
56 | self.update_masks()
57 |
58 | def apply_masks(self):
59 | for w, mask_w in zip(self.layers, self.masks):
60 |             w.data[mask_w] = 0       # zero the pruned weights
61 |             w.grad.data[mask_w] = 0  # and their gradients, so they stay pruned
62 |
63 | def forward(self, x, return_latents=False,
64 | inject_index=None,
65 | truncation=1,
66 | truncation_latent=None,
67 | input_is_latent=False,
68 | noise=None,
69 | randomize_noise=True,):
70 | return self.model(x, return_latents,
71 | inject_index,
72 | truncation,
73 | truncation_latent,
74 | input_is_latent,
75 | noise,
76 | randomize_noise)
--------------------------------------------------------------------------------
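Editor's note: a self-contained toy, not from the repository, that exercises the Regan_training masking API above on a small dense network. The phase interval and sparsity are illustrative; in this repository the wrapper is driven by the training scripts' own loops.

import torch
import torch.nn as nn
import torch.optim as optim

from regan import Regan_training  # assumes StyleGAN2/regan.py is on the import path

net = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 1))
model = Regan_training(net, sparsity=0.3)  # no "to_rgb" layers, so every parameter is managed
opt = optim.SGD(net.parameters(), lr=0.1)

interval = 10  # illustrative phase length
for step in range(40):
    if step % interval == 0:
        phase = 'dense' if (step // interval) % 2 == 0 else 'sparse'
        model.turn_training_mode(phase)  # switching to 'sparse' recomputes the masks

    loss = net(torch.randn(8, 16)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

    if model.train_on_sparse:
        model.apply_masks()  # keep pruned weights and their gradients at zero

zeros = sum(int((w == 0).sum()) for w in model.layers)
total = sum(w.numel() for w in model.layers)
print(f"fraction of zeroed weights: {zeros / total:.2f}")  # roughly the chosen sparsity

--------------------------------------------------------------------------------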
/figures/Table5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntellicentAI-Lab/Re-GAN/785135c820777149e8a4da941ee369b7d994f951/figures/Table5.png
--------------------------------------------------------------------------------
/figures/main_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IntellicentAI-Lab/Re-GAN/785135c820777149e8a4da941ee369b7d994f951/figures/main_figure.png
--------------------------------------------------------------------------------