├── .style.yapf
├── img
│   ├── model.jpeg
│   ├── SRGAN_Result.png
│   ├── SRGAN_Result2.png
│   └── SRGAN_Result3.png
├── .gitignore
├── requirements.txt
├── config.py
├── README.md
├── train.py
├── vgg.py
└── srgan.py
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | column_limit = 160
3 |
--------------------------------------------------------------------------------
/img/model.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/SRGAN/HEAD/img/model.jpeg
--------------------------------------------------------------------------------
/img/SRGAN_Result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/SRGAN/HEAD/img/SRGAN_Result.png
--------------------------------------------------------------------------------
/img/SRGAN_Result2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/SRGAN/HEAD/img/SRGAN_Result2.png
--------------------------------------------------------------------------------
/img/SRGAN_Result3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/SRGAN/HEAD/img/SRGAN_Result3.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ._*
2 | *.pyc
3 | .DS_Store
4 | *.npz
5 | sample/
6 | samples/
7 | checkpoint/
8 | __pycache__/
9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/tensorlayer/tensorlayerx.git
2 | numpy>=1.16.1
3 | easydict==1.9
4 | opencv-python>=4.5.1.48
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from easydict import EasyDict as edict
2 | import json
3 |
4 | config = edict()
5 | config.TRAIN = edict()
6 | config.TRAIN.batch_size = 16 # [16] use 8 if your GPU memory is small
7 | config.TRAIN.lr_init = 1e-4
8 | config.TRAIN.beta1 = 0.9
9 |
10 | ## initialize G
11 | config.TRAIN.n_epoch_init = 100
12 | # config.TRAIN.lr_decay_init = 0.1
13 | # config.TRAIN.decay_every_init = int(config.TRAIN.n_epoch_init / 2)
14 |
15 | ## adversarial learning (SRGAN)
16 | config.TRAIN.n_epoch = 2000
17 | config.TRAIN.lr_decay = 0.1
18 | config.TRAIN.decay_every = int(config.TRAIN.n_epoch / 2)
19 |
20 | ## train set location
21 | config.TRAIN.hr_img_path = 'DIV2K/DIV2K_train_HR/'
22 | config.TRAIN.lr_img_path = 'DIV2K/DIV2K_train_LR_bicubic/X4/'
23 |
24 | config.VALID = edict()
25 | ## test set location
26 | config.VALID.hr_img_path = 'DIV2K/DIV2K_valid_HR/'
27 | config.VALID.lr_img_path = 'DIV2K/DIV2K_valid_LR_bicubic/X4/'
28 |
29 | def log_config(filename, cfg):
30 | with open(filename, 'w') as f:
31 | f.write("================================================\n")
32 | f.write(json.dumps(cfg, indent=4))
33 | f.write("\n================================================\n")
34 |
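# Example usage (illustrative): log_config('samples/config.txt', config)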
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Super Resolution Examples
2 |
3 | - Implementation of ["Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"](https://arxiv.org/abs/1609.04802)
4 |
5 | - For earlier versions, please check the [srgan releases](https://github.com/tensorlayer/srgan/releases) and [tensorlayer](https://github.com/tensorlayer/TensorLayer).
6 |
7 | - For more computer vision applications, check [TLXCV](https://github.com/tensorlayer/TLXCV).
8 |
9 |
10 | ### SRGAN Architecture
11 |
12 |
13 | ![SRGAN architecture](img/model.jpeg)
14 |
24 | ### Prepare Data and Pre-trained VGG
25 |
26 | - 1. You need to download the pretrained VGG19 model weights from [here](https://drive.google.com/file/d/1CLw6Cn3yNI1N15HyX99_Zy9QnDcgP3q7/view?usp=sharing).
27 | - 2. You need high-resolution images for training.
28 |   - In this experiment, I used images from [DIV2K - bicubic downscaling x4 competition](http://www.vision.ee.ethz.ch/ntire17/), so the hyper-parameters in `config.py` (like the number of epochs) are selected based on that dataset; if you switch to a larger dataset, you can reduce the number of epochs.
29 |   - If you don't want to use the DIV2K dataset, you can also use [Yahoo MirFlickr25k](http://press.liacs.nl/mirflickr/mirdownload.html); simply download it with `train_hr_imgs = tl.files.load_flickr25k_dataset(tag=None)` in `train.py`.
30 |   - If you want to use your own images, set the path to your image folder via `config.TRAIN.hr_img_path` in `config.py` (see the sketch below).
31 |
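A minimal sketch of how `train.py` consumes these settings (the paths come from `config.py`; the shapes shown are illustrative):

```python
import tensorlayerx as tlx
from config import config

# load all high-resolution training images from the configured folder
train_hr_imgs = tlx.vision.load_images(path=config.TRAIN.hr_img_path)
print(len(train_hr_imgs))      # e.g. 800 images for the DIV2K train split
print(train_hr_imgs[0].shape)  # HWC arrays; LR inputs are produced by downscaling these
```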
32 |
33 |
34 | ### Run
35 |
36 | 🔥🔥🔥🔥🔥🔥 You need to install [TensorLayerX](https://github.com/tensorlayer/TensorLayerX#installation) first!
37 |
38 | 🔥🔥🔥🔥🔥🔥 Please install TensorLayerX from source:
39 |
40 | ```bash
41 | pip install git+https://github.com/tensorlayer/tensorlayerx.git
42 | ```
43 |
44 | #### Train
45 | - Set your image folder in `config.py`. If you download the [DIV2K - bicubic downscaling x4 competition](http://www.vision.ee.ethz.ch/ntire17/) dataset, you don't need to change it.
46 | - Other links for DIV2K, in case you can't find it: [test\_LR\_bicubic_X4](https://data.vision.ee.ethz.ch/cvl/DIV2K/validation_release/DIV2K_test_LR_bicubic_X4.zip), [train_HR](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip), [train\_LR\_bicubic_X4](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip), [valid_HR](https://data.vision.ee.ethz.ch/cvl/DIV2K/validation_release/DIV2K_valid_HR.zip), [valid\_LR\_bicubic_X4](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip).
47 |
48 | ```python
49 | config.TRAIN.hr_img_path = "your_image_folder/"
50 | ```
51 | Your directory structure should look like this:
52 |
53 | ```
54 | srgan/
55 | ├── config.py
56 | ├── srgan.py
57 | ├── train.py
58 | ├── vgg.py
59 | ├── model
60 | │   └── vgg19.npy
61 | └── DIV2K
62 |     ├── DIV2K_train_HR
63 |     ├── DIV2K_train_LR_bicubic
64 |     ├── DIV2K_valid_HR
65 |     └── DIV2K_valid_LR_bicubic
66 |
67 | ```
68 |
69 | - Start training.
70 |
71 | ```bash
72 | python train.py
73 | ```
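During training, generator and discriminator checkpoints (`g.npz`, `d.npz`) are written to `models/` every 10 epochs (see `train.py`).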
74 |
75 | 🔥 Modify one line of code in **train.py** to easily switch to any framework!
76 |
77 | ```python
78 | import os
79 | os.environ['TL_BACKEND'] = 'tensorflow'
80 | # os.environ['TL_BACKEND'] = 'mindspore'
81 | # os.environ['TL_BACKEND'] = 'paddle'
82 | # os.environ['TL_BACKEND'] = 'pytorch'
83 | ```
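Note that `TL_BACKEND` must be set before `tensorlayerx` is imported, which is why these lines appear at the very top of `train.py`.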
84 | 🚧 We will support PyTorch as a backend soon.
85 |
86 |
87 | #### Evaluation
88 |
89 | 🔥 We have trained SRGAN on the DIV2K dataset.
90 | 🔥 Download the model weights as follows.
91 |
92 | | | SRGAN_g | SRGAN_d |
93 | |------------- |---------|---------|
94 | | TensorFlow | [Baidu](https://pan.baidu.com/s/118uUg3oce_3NZQCIWHVjmA?pwd=p9li), [Googledrive](https://drive.google.com/file/d/1GlU9At-5XEDilgnt326fyClvZB_fsaFZ/view?usp=sharing) |[Baidu](https://pan.baidu.com/s/1DOpGzDJY5PyusKzaKqbLOg?pwd=g2iy), [Googledrive](https://drive.google.com/file/d/1RpOtVcVK-yxnVhNH4KSjnXHDvuU_pq3j/view?usp=sharing) |
95 | | PaddlePaddle | [Baidu](https://pan.baidu.com/s/1ngBpleV5vQZQqNE_8djDIg?pwd=s8wc), [Googledrive](https://drive.google.com/file/d/1GRNt_ZsgorB19qvwN5gE6W9a_bIPLkg1/view?usp=sharing) | [Baidu](https://pan.baidu.com/s/1nSefLNRanFImf1DskSVpCg?pwd=befc), [Googledrive](https://drive.google.com/file/d/1Jf6W1ZPdgtmUSfrQ5mMZDB_hOCVU-zFo/view?usp=sharing) |
96 | | MindSpore | 🚧Coming soon! | 🚧Coming soon! |
97 | | PyTorch | 🚧Coming soon! | 🚧Coming soon! |
98 |
99 |
100 | Download the weights files and put them under the folder `srgan/models/`.
101 |
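For example, with the TensorFlow backend you could rename the downloaded files like this (a minimal sketch; it assumes the generator file is named `srgan-g-tensorflow.npz` as noted below, and that the discriminator file follows the same pattern):

```python
import shutil

backend = 'tensorflow'  # should match os.environ['TL_BACKEND'] in train.py
shutil.move(f'srgan-g-{backend}.npz', 'models/g.npz')
shutil.move(f'srgan-d-{backend}.npz', 'models/d.npz')
```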
102 | Your directory structure should look like this:
103 |
104 | ```
105 | srgan/
106 | ├── config.py
107 | ├── srgan.py
108 | ├── train.py
109 | ├── vgg.py
110 | ├── model
111 | │   └── vgg19.npy
112 | ├── DIV2K
113 | │   ├── DIV2K_train_HR
114 | │   ├── DIV2K_train_LR_bicubic
115 | │   ├── DIV2K_valid_HR
116 | │   └── DIV2K_valid_LR_bicubic
117 | └── models
118 |     ├── g.npz  # You should rename the downloaded weights file.
119 |     └── d.npz  # e.g. if you set os.environ['TL_BACKEND'] = 'tensorflow', rename srgan-g-tensorflow.npz to g.npz.
120 |
121 | ```
122 |
123 | - Start evaluation.
124 | ```bash
125 | python train.py --mode=eval
126 | ```
127 |
128 | Results will be saved under the folder `srgan/samples/`.
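Evaluation upscales one validation image with the trained generator and also writes a bicubic-interpolated version (`valid_hr_cubic.png`) for comparison; see `evaluate()` in `train.py`.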
129 |
130 | ### Results
131 |
132 |
133 | ![SRGAN result](img/SRGAN_Result.png)
134 |
135 | ![SRGAN result](img/SRGAN_Result2.png)
136 |
137 | ![SRGAN result](img/SRGAN_Result3.png)
138 |
139 | ### Reference
140 | * [1] [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802)
141 | * [2] [Is the deconvolution layer the same as a convolutional layer?](https://arxiv.org/abs/1609.07009)
142 |
143 |
144 |
145 | ### Citation
146 | If you find this project useful, we would be grateful if you cite the TensorLayer paper:
147 |
148 | ```
149 | @article{tensorlayer2017,
150 | author = {Dong, Hao and Supratak, Akara and Mai, Luo and Liu, Fangde and Oehmichen, Axel and Yu, Simiao and Guo, Yike},
151 | journal = {ACM Multimedia},
152 | title = {{TensorLayer: A Versatile Library for Efficient Deep Learning Development}},
153 | url = {http://tensorlayer.org},
154 | year = {2017}
155 | }
156 |
157 | @inproceedings{tensorlayer2021,
158 | title={TensorLayer 3.0: A Deep Learning Library Compatible With Multiple Backends},
159 | author={Lai, Cheng and Han, Jiarong and Dong, Hao},
160 | booktitle={2021 IEEE International Conference on Multimedia \& Expo Workshops (ICMEW)},
161 | pages={1--3},
162 | year={2021},
163 | organization={IEEE}
164 | }
165 | ```
166 |
167 | ### Other Projects
168 |
169 | - [Style Transfer](https://github.com/tensorlayer/adaptive-style-transfer)
170 | - [Pose Estimation](https://github.com/tensorlayer/openpose)
171 |
172 | ### Discussion
173 |
174 | - [TensorLayer Slack](https://join.slack.com/t/tensorlayer/shared_invite/enQtMjUyMjczMzU2Njg4LWI0MWU0MDFkOWY2YjQ4YjVhMzI5M2VlZmE4YTNhNGY1NjZhMzUwMmQ2MTc0YWRjMjQzMjdjMTg2MWQ2ZWJhYzc)
175 | - [TensorLayer WeChat](https://github.com/tensorlayer/tensorlayer-chinese/blob/master/docs/wechat_group.md)
176 |
177 | ### License
178 |
179 | - For academic and non-commercial use only.
180 | - For commercial use, please contact tensorlayer@gmail.com.
181 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ['TL_BACKEND'] = 'tensorflow'  # Just modify this line to easily switch to any framework! PyTorch support is coming soon!
3 | # os.environ['TL_BACKEND'] = 'mindspore'
4 | # os.environ['TL_BACKEND'] = 'paddle'
5 | # os.environ['TL_BACKEND'] = 'torch'
6 | import time
7 | import numpy as np
8 | import tensorlayerx as tlx
9 | from tensorlayerx.dataflow import Dataset, DataLoader
10 | from srgan import SRGAN_g, SRGAN_d
11 | from config import config
12 | from tensorlayerx.vision.transforms import Compose, RandomCrop, Normalize, RandomFlipHorizontal, Resize, HWC2CHW
13 | import vgg
14 | from tensorlayerx.model import TrainOneStep
15 | from tensorlayerx.nn import Module
16 | import cv2
17 | tlx.set_device('GPU')
18 |
19 | ###====================== HYPER-PARAMETERS ===========================###
20 | batch_size = 8  # note: used instead of config.TRAIN.batch_size; 8 suits smaller GPU memory
21 | n_epoch_init = config.TRAIN.n_epoch_init
22 | n_epoch = config.TRAIN.n_epoch
23 | # create folders to save result images and trained models
24 | save_dir = "samples"
25 | tlx.files.exists_or_mkdir(save_dir)
26 | checkpoint_dir = "models"
27 | tlx.files.exists_or_mkdir(checkpoint_dir)
28 |
29 | hr_transform = Compose([
30 | RandomCrop(size=(384, 384)),
31 | RandomFlipHorizontal(),
32 | ])
33 | nor = Compose([Normalize(mean=(127.5), std=(127.5), data_format='HWC'),
34 | HWC2CHW()])
35 | lr_transform = Resize(size=(96, 96))
36 |
37 | train_hr_imgs = tlx.vision.load_images(path=config.TRAIN.hr_img_path, n_threads = 32)
38 |
39 | class TrainData(Dataset):
40 |
41 | def __init__(self, hr_trans=hr_transform, lr_trans=lr_transform):
42 | self.train_hr_imgs = train_hr_imgs
43 | self.hr_trans = hr_trans
44 | self.lr_trans = lr_trans
45 |
46 | def __getitem__(self, index):
47 | img = self.train_hr_imgs[index]
48 | hr_patch = self.hr_trans(img)
49 | lr_patch = self.lr_trans(hr_patch)
50 | return nor(lr_patch), nor(hr_patch)
51 |
52 | def __len__(self):
53 | return len(self.train_hr_imgs)
54 |
55 |
56 | class WithLoss_init(Module):
57 | def __init__(self, G_net, loss_fn):
58 | super(WithLoss_init, self).__init__()
59 | self.net = G_net
60 | self.loss_fn = loss_fn
61 |
62 | def forward(self, lr, hr):
63 | out = self.net(lr)
64 | loss = self.loss_fn(out, hr)
65 | return loss
66 |
67 |
68 | class WithLoss_D(Module):
69 | def __init__(self, D_net, G_net, loss_fn):
70 | super(WithLoss_D, self).__init__()
71 | self.D_net = D_net
72 | self.G_net = G_net
73 | self.loss_fn = loss_fn
74 |
75 | def forward(self, lr, hr):
76 | fake_patchs = self.G_net(lr)
77 | logits_fake = self.D_net(fake_patchs)
78 | logits_real = self.D_net(hr)
79 | d_loss1 = self.loss_fn(logits_real, tlx.ones_like(logits_real))
80 | d_loss1 = tlx.ops.reduce_mean(d_loss1)
81 | d_loss2 = self.loss_fn(logits_fake, tlx.zeros_like(logits_fake))
82 | d_loss2 = tlx.ops.reduce_mean(d_loss2)
83 | d_loss = d_loss1 + d_loss2
84 | return d_loss
85 |
86 |
87 | class WithLoss_G(Module):
88 | def __init__(self, D_net, G_net, vgg, loss_fn1, loss_fn2):
89 | super(WithLoss_G, self).__init__()
90 | self.D_net = D_net
91 | self.G_net = G_net
92 | self.vgg = vgg
93 | self.loss_fn1 = loss_fn1
94 | self.loss_fn2 = loss_fn2
95 |
96 | def forward(self, lr, hr):
97 | fake_patchs = self.G_net(lr)
98 | logits_fake = self.D_net(fake_patchs)
99 | feature_fake = self.vgg((fake_patchs + 1) / 2.)
100 | feature_real = self.vgg((hr + 1) / 2.)
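        # total generator loss (SRGAN): pixel-wise MSE + 2e-6 * VGG feature (perceptual) loss + 1e-3 * adversarial loss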
101 | g_gan_loss = 1e-3 * self.loss_fn1(logits_fake, tlx.ones_like(logits_fake))
102 | g_gan_loss = tlx.ops.reduce_mean(g_gan_loss)
103 | mse_loss = self.loss_fn2(fake_patchs, hr)
104 | vgg_loss = 2e-6 * self.loss_fn2(feature_fake, feature_real)
105 | g_loss = mse_loss + vgg_loss + g_gan_loss
106 | return g_loss
107 |
108 |
109 | G = SRGAN_g()
110 | D = SRGAN_d()
111 | VGG = vgg.VGG19(pretrained=True, end_with='pool4', mode='dynamic')
112 | # Automatically initialize layer weight shapes from an input tensor.
113 | # Working out and filling in 'in_channels' for every layer by hand is tedious,
114 | # so we call 'init_build' with an input shape and 'in_channels' is set automatically for each layer.
115 | G.init_build(tlx.nn.Input(shape=(8, 3, 96, 96)))
116 | D.init_build(tlx.nn.Input(shape=(8, 3, 384, 384)))
117 |
118 |
119 | def train():
120 | G.set_train()
121 | D.set_train()
122 | VGG.set_eval()
123 | train_ds = TrainData()
124 | train_ds_img_nums = len(train_ds)
125 | train_ds = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)
126 |
127 | lr_v = tlx.optimizers.lr.StepDecay(learning_rate=0.05, step_size=1000, gamma=0.1, last_epoch=-1, verbose=True)
128 | g_optimizer_init = tlx.optimizers.Momentum(lr_v, 0.9)
129 | g_optimizer = tlx.optimizers.Momentum(lr_v, 0.9)
130 | d_optimizer = tlx.optimizers.Momentum(lr_v, 0.9)
131 | g_weights = G.trainable_weights
132 | d_weights = D.trainable_weights
133 | net_with_loss_init = WithLoss_init(G, loss_fn=tlx.losses.mean_squared_error)
134 | net_with_loss_D = WithLoss_D(D_net=D, G_net=G, loss_fn=tlx.losses.sigmoid_cross_entropy)
135 | net_with_loss_G = WithLoss_G(D_net=D, G_net=G, vgg=VGG, loss_fn1=tlx.losses.sigmoid_cross_entropy,
136 | loss_fn2=tlx.losses.mean_squared_error)
137 |
138 | trainforinit = TrainOneStep(net_with_loss_init, optimizer=g_optimizer_init, train_weights=g_weights)
139 | trainforG = TrainOneStep(net_with_loss_G, optimizer=g_optimizer, train_weights=g_weights)
140 | trainforD = TrainOneStep(net_with_loss_D, optimizer=d_optimizer, train_weights=d_weights)
141 |
142 | # initialize learning (G)
143 | n_step_epoch = round(train_ds_img_nums // batch_size)
144 | for epoch in range(n_epoch_init):
145 | for step, (lr_patch, hr_patch) in enumerate(train_ds):
146 | step_time = time.time()
147 | loss = trainforinit(lr_patch, hr_patch)
148 | print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format(
149 | epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, float(loss)))
150 |
151 | # adversarial learning (G, D)
152 | n_step_epoch = round(train_ds_img_nums // batch_size)
153 | for epoch in range(n_epoch):
154 | for step, (lr_patch, hr_patch) in enumerate(train_ds):
155 | step_time = time.time()
156 | loss_g = trainforG(lr_patch, hr_patch)
157 | loss_d = trainforD(lr_patch, hr_patch)
158 | print(
159 | "Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss:{:.3f}, d_loss: {:.3f}".format(
160 | epoch, n_epoch, step, n_step_epoch, time.time() - step_time, float(loss_g), float(loss_d)))
161 | # dynamic learning rate update
162 | lr_v.step()
163 |
164 | if (epoch != 0) and (epoch % 10 == 0):
165 | G.save_weights(os.path.join(checkpoint_dir, 'g.npz'), format='npz_dict')
166 | D.save_weights(os.path.join(checkpoint_dir, 'd.npz'), format='npz_dict')
167 |
168 | def evaluate():
169 | ###====================== PRE-LOAD DATA ===========================###
170 |     valid_hr_imgs = tlx.vision.load_images(path=config.VALID.hr_img_path)
171 | ###========================LOAD WEIGHTS ============================###
172 | G.load_weights(os.path.join(checkpoint_dir, 'g.npz'), format='npz_dict')
173 | G.set_eval()
174 |     imid = 0  # validation image index, e.g. 0: penguin, 81: butterfly, 53: bird, 64: castle
175 | valid_hr_img = valid_hr_imgs[imid]
176 | valid_lr_img = np.asarray(valid_hr_img)
177 | hr_size1 = [valid_lr_img.shape[0], valid_lr_img.shape[1]]
178 | valid_lr_img = cv2.resize(valid_lr_img, dsize=(hr_size1[1] // 4, hr_size1[0] // 4))
179 | valid_lr_img_tensor = (valid_lr_img / 127.5) - 1 # rescale to [-1, 1]
180 |
181 |
182 |     valid_lr_img_tensor = np.asarray(valid_lr_img_tensor, dtype=np.float32)
183 |     valid_lr_img_tensor = np.transpose(valid_lr_img_tensor, axes=[2, 0, 1])
184 |     valid_lr_img_tensor = valid_lr_img_tensor[np.newaxis, :, :, :]
185 |     valid_lr_img_tensor = tlx.ops.convert_to_tensor(valid_lr_img_tensor)
186 | size = [valid_lr_img.shape[0], valid_lr_img.shape[1]]
187 |
188 | out = tlx.ops.convert_to_numpy(G(valid_lr_img_tensor))
189 | out = np.asarray((out + 1) * 127.5, dtype=np.uint8)
190 | out = np.transpose(out[0], axes=[1, 2, 0])
191 |     print("LR size: %s / generated HR size: %s" % (size, out.shape))  # e.g. LR size: [339, 510] / generated HR size: (1356, 2040, 3)
192 | print("[*] save images")
193 | tlx.vision.save_image(out, file_name='valid_gen.png', path=save_dir)
194 | tlx.vision.save_image(valid_lr_img, file_name='valid_lr.png', path=save_dir)
195 | tlx.vision.save_image(valid_hr_img, file_name='valid_hr.png', path=save_dir)
196 |     out_bicu = cv2.resize(valid_lr_img, dsize=(size[1] * 4, size[0] * 4), interpolation=cv2.INTER_CUBIC)
197 | tlx.vision.save_image(out_bicu, file_name='valid_hr_cubic.png', path=save_dir)
198 |
199 |
200 | if __name__ == '__main__':
201 | import argparse
202 |
203 | parser = argparse.ArgumentParser()
204 |
205 | parser.add_argument('--mode', type=str, default='train', help='train, eval')
206 |
207 | args = parser.parse_args()
208 |
209 | tlx.global_flag['mode'] = args.mode
210 |
211 | if tlx.global_flag['mode'] == 'train':
212 | train()
213 | elif tlx.global_flag['mode'] == 'eval':
214 | evaluate()
215 | else:
216 |         raise Exception("Unknown --mode")
217 |
--------------------------------------------------------------------------------
/vgg.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | """
4 | VGG for ImageNet.
5 |
6 | Introduction
7 | ----------------
8 | VGG is a convolutional neural network model proposed by K. Simonyan and A. Zisserman
9 | from the University of Oxford in the paper "Very Deep Convolutional Networks for
10 | Large-Scale Image Recognition". The model achieves 92.7% top-5 test accuracy on ImageNet,
11 | which is a dataset of over 14 million images belonging to 1000 classes.
12 |
13 | Download Pre-trained Model
14 | ----------------------------
15 | - Model weights in this example - vgg16_weights.npz : http://www.cs.toronto.edu/~frossard/post/vgg16/
16 | - Model weights in this example - vgg19.npy : https://media.githubusercontent.com/media/tensorlayer/pretrained-models/master/models/
17 | - Caffe VGG 16 model : https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-md
18 | - Tool to convert the Caffe models to TensorFlow's : https://github.com/ethereon/caffe-tensorflow
19 |
20 | Note
21 | ------
22 | - For simplified CNN layer see "Convolutional layer (Simplified)"
23 | in read the docs website.
24 | - When feeding other images to the model be sure to properly resize or crop them
25 | beforehand. Distorted images might end up being misclassified. One way of safely
26 | feeding images of multiple sizes is by doing center cropping.
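
Example
--------
A minimal sketch (assumes `img` is an HWC uint8 numpy array and the weights are under `model/`):

>>> import numpy as np
>>> import tensorlayerx as tlx
>>> x = np.transpose(img.astype(np.float32) / 255., (2, 0, 1))[None, ...]  # NCHW, values in [0, 1]
>>> vgg = VGG19(pretrained=True, end_with='pool4', mode='dynamic')
>>> feats = vgg(tlx.ops.convert_to_tensor(x))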
27 |
28 | """
29 |
30 | import os
31 |
32 | import numpy as np
33 |
34 | import tensorlayerx as tlx
35 | from tensorlayerx import logging
36 | from tensorlayerx.files import assign_weights, maybe_download_and_extract
37 | from tensorlayerx.nn import (BatchNorm, Conv2d, Linear, Flatten, Input, Sequential, MaxPool2d)
38 | from tensorlayerx.nn import Module
39 |
40 | __all__ = [
41 | 'VGG',
42 | 'vgg16',
43 | 'vgg19',
44 | 'VGG16',
45 | 'VGG19',
46 | # 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
47 | # 'vgg19_bn', 'vgg19',
48 | ]
49 |
50 | layer_names = [
51 | ['conv1_1', 'conv1_2'], 'pool1', ['conv2_1', 'conv2_2'], 'pool2',
52 | ['conv3_1', 'conv3_2', 'conv3_3', 'conv3_4'], 'pool3', ['conv4_1', 'conv4_2', 'conv4_3', 'conv4_4'], 'pool4',
53 | ['conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'], 'pool5', 'flatten', 'fc1_relu', 'fc2_relu', 'outputs'
54 | ]
55 |
56 | cfg = {
57 | 'A': [[64], 'M', [128], 'M', [256, 256], 'M', [512, 512], 'M', [512, 512], 'M', 'F', 'fc1', 'fc2', 'O'],
58 | 'B': [[64, 64], 'M', [128, 128], 'M', [256, 256], 'M', [512, 512], 'M', [512, 512], 'M', 'F', 'fc1', 'fc2', 'O'],
59 | 'D':
60 | [
61 | [64, 64], 'M', [128, 128], 'M', [256, 256, 256], 'M', [512, 512, 512], 'M', [512, 512, 512], 'M', 'F',
62 | 'fc1', 'fc2', 'O'
63 | ],
64 | 'E':
65 | [
66 | [64, 64], 'M', [128, 128], 'M', [256, 256, 256, 256], 'M', [512, 512, 512, 512], 'M', [512, 512, 512, 512],
67 | 'M', 'F', 'fc1', 'fc2', 'O'
68 | ],
69 | }
70 |
71 | mapped_cfg = {
72 | 'vgg11': 'A',
73 | 'vgg11_bn': 'A',
74 | 'vgg13': 'B',
75 | 'vgg13_bn': 'B',
76 | 'vgg16': 'D',
77 | 'vgg16_bn': 'D',
78 | 'vgg19': 'E',
79 | 'vgg19_bn': 'E'
80 | }
81 |
82 | model_urls = {
83 | 'vgg16': 'https://git.openi.org.cn/attachments/760835b9-db71-4a00-8edd-d5ece4b6b522?type=0',
84 | 'vgg19': 'https://git.openi.org.cn/attachments/503c8a6c-705f-4fb6-ba18-03d72b6a949a?type=0'
85 | }
86 |
87 | model_saved_name = {'vgg16': 'vgg16_weights.npz', 'vgg19': 'vgg19.npy'}
88 |
89 |
90 | class VGG(Module):
91 |
92 | def __init__(self, layer_type, batch_norm=False, end_with='outputs', name=None):
93 | super(VGG, self).__init__(name=name)
94 | self.end_with = end_with
95 |
96 | config = cfg[mapped_cfg[layer_type]]
97 | self.make_layer = make_layers(config, batch_norm, end_with)
98 |
99 | def forward(self, inputs):
100 | """
101 |         inputs : tensor
102 |             Shape [None, 3, 224, 224] (channels-first), value range [0, 1].
103 | """
104 |
105 | # inputs = inputs * 255 - np.array([123.68, 116.779, 103.939], dtype=np.float32).reshape([1, 1, 1, 3])
106 | inputs = inputs * 255. - tlx.convert_to_tensor(np.array([123.68, 116.779, 103.939], dtype=np.float32).reshape(-1,1,1))
107 | out = self.make_layer(inputs)
108 | return out
109 |
110 |
111 | def make_layers(config, batch_norm=False, end_with='outputs'):
112 | layer_list = []
113 | is_end = False
114 | for layer_group_idx, layer_group in enumerate(config):
115 | if isinstance(layer_group, list):
116 | for idx, layer in enumerate(layer_group):
117 | layer_name = layer_names[layer_group_idx][idx]
118 | n_filter = layer
119 | if idx == 0:
120 | if layer_group_idx > 0:
121 | in_channels = config[layer_group_idx - 2][-1]
122 | else:
123 | in_channels = 3
124 | else:
125 | in_channels = layer_group[idx - 1]
126 | layer_list.append(
127 | Conv2d(
128 | out_channels=n_filter, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME',
129 | in_channels=in_channels, name=layer_name, data_format='channels_first'
130 | )
131 | )
132 | if batch_norm:
133 | layer_list.append(BatchNorm(num_features=n_filter, data_format='channels_first'))
134 | if layer_name == end_with:
135 | is_end = True
136 | break
137 | else:
138 | layer_name = layer_names[layer_group_idx]
139 | if layer_group == 'M':
140 | layer_list.append(MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME', name=layer_name, data_format='channels_first'))
141 | elif layer_group == 'O':
142 | layer_list.append(Linear(out_features=1000, in_features=4096, name=layer_name))
143 | elif layer_group == 'F':
144 | layer_list.append(Flatten(name='flatten'))
145 | elif layer_group == 'fc1':
146 | layer_list.append(Linear(out_features=4096, act=tlx.ReLU, in_features=512 * 7 * 7, name=layer_name))
147 | elif layer_group == 'fc2':
148 | layer_list.append(Linear(out_features=4096, act=tlx.ReLU, in_features=4096, name=layer_name))
149 | if layer_name == end_with:
150 | is_end = True
151 | if is_end:
152 | break
153 | return Sequential(layer_list)
154 |
155 | def restore_model(model, layer_type):
156 | logging.info("Restore pre-trained weights")
157 | # download weights
158 | maybe_download_and_extract(model_saved_name[layer_type], 'model', model_urls[layer_type])
159 | weights = []
160 | if layer_type == 'vgg16':
161 | npz = np.load(os.path.join('model', model_saved_name[layer_type]), allow_pickle=True)
162 | # get weight list
163 | for val in sorted(npz.items()):
164 | logging.info(" Loading weights %s in %s" % (str(val[1].shape), val[0]))
165 | weights.append(val[1])
166 | if len(model.all_weights) == len(weights):
167 | break
168 | elif layer_type == 'vgg19':
169 | npz = np.load(os.path.join('model', model_saved_name[layer_type]), allow_pickle=True, encoding='latin1').item()
170 | # get weight list
171 | for val in sorted(npz.items()):
172 | logging.info(" Loading %s in %s" % (str(val[1][0].shape), val[0]))
173 | logging.info(" Loading %s in %s" % (str(val[1][1].shape), val[0]))
174 | weights.extend(val[1])
175 | if len(model.all_weights) == len(weights):
176 | break
177 | # assign weight values
178 | if tlx.BACKEND != 'tensorflow':
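        # pretrained kernels are stored in TensorFlow HWIO layout; transpose to OIHW for channels-first backends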
179 | for i in range(len(weights)):
180 | if len(weights[i].shape) == 4:
181 | weights[i] = np.transpose(weights[i], axes=[3, 2, 0, 1])
182 | assign_weights(weights, model)
183 | del weights
184 |
185 | def vgg16(pretrained=False, end_with='outputs', mode='dynamic', name=None):
186 | """Pre-trained VGG16 model.
187 |
188 | Parameters
189 | ------------
190 | pretrained : boolean
191 | Whether to load pretrained weights. Default False.
192 | end_with : str
193 |         The end point of the model. Default ``outputs``, i.e. the whole model.
194 | mode : str.
195 | Model building mode, 'dynamic' or 'static'. Default 'dynamic'.
196 | name : None or str
197 | A unique layer name.
198 |
199 | Examples
200 | ---------
201 |     Classify ImageNet classes with VGG16, see ``tutorial_models_vgg.py``.
202 | With TensorLayer
203 | TODO Modify the usage example according to the model storage location
204 |
205 | >>> # get the whole model, without pre-trained VGG parameters
206 | >>> vgg = vgg16()
207 | >>> # get the whole model, restore pre-trained VGG parameters
208 | >>> vgg = vgg16(pretrained=True)
209 | >>> # use for inferencing
210 | >>> output = vgg(img)
211 | >>> probs = tlx.ops.softmax(output)[0].numpy()
212 |
213 | """
214 |
215 | if mode == 'dynamic':
216 | model = VGG(layer_type='vgg16', batch_norm=False, end_with=end_with, name=name)
217 | elif mode == 'static':
218 | raise NotImplementedError
219 | else:
220 | raise Exception("No such mode %s" % mode)
221 | if pretrained:
222 | restore_model(model, layer_type='vgg16')
223 | return model
224 |
225 |
226 | def vgg19(pretrained=False, end_with='outputs', mode='dynamic', name=None):
227 | """Pre-trained VGG19 model.
228 |
229 | Parameters
230 | ------------
231 | pretrained : boolean
232 | Whether to load pretrained weights. Default False.
233 | end_with : str
234 |         The end point of the model. Default ``outputs``, i.e. the whole model.
235 | mode : str.
236 | Model building mode, 'dynamic' or 'static'. Default 'dynamic'.
237 | name : None or str
238 | A unique layer name.
239 |
240 | Examples
241 | ---------
242 |     Classify ImageNet classes with VGG19, see ``tutorial_models_vgg.py``.
243 | With TensorLayerx
244 |
245 | >>> # get the whole model, without pre-trained VGG parameters
246 | >>> vgg = vgg19()
247 | >>> # get the whole model, restore pre-trained VGG parameters
248 | >>> vgg = vgg19(pretrained=True)
249 | >>> # use for inferencing
250 | >>> output = vgg(img)
251 | >>> probs = tlx.ops.softmax(output)[0].numpy()
252 |
253 | """
254 | if mode == 'dynamic':
255 | model = VGG(layer_type='vgg19', batch_norm=False, end_with=end_with, name=name)
256 | elif mode == 'static':
257 | raise NotImplementedError
258 | else:
259 | raise Exception("No such mode %s" % mode)
260 | if pretrained:
261 | restore_model(model, layer_type='vgg19')
262 | return model
263 |
264 |
265 | VGG16 = vgg16
266 | VGG19 = vgg19
267 |
268 |
--------------------------------------------------------------------------------
/srgan.py:
--------------------------------------------------------------------------------
1 | from tensorlayerx.nn import Module
2 | import tensorlayerx as tlx
3 | from tensorlayerx.nn import Conv2d, BatchNorm2d, Elementwise, SubpixelConv2d, UpSampling2d, Flatten, Sequential
4 | from tensorlayerx.nn import Linear, MaxPool2d
5 |
6 | W_init = tlx.initializers.TruncatedNormal(stddev=0.02)
7 | G_init = tlx.initializers.TruncatedNormal(mean=1.0, stddev=0.02)
8 |
9 |
10 | class ResidualBlock(Module):
11 |
12 | def __init__(self):
13 | super(ResidualBlock, self).__init__()
14 | self.conv1 = Conv2d(
15 | out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
16 | data_format='channels_first', b_init=None
17 | )
18 | self.bn1 = BatchNorm2d(num_features=64, act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
19 | self.conv2 = Conv2d(
20 | out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
21 | data_format='channels_first', b_init=None
22 | )
23 | self.bn2 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_first')
24 |
25 | def forward(self, x):
26 | z = self.conv1(x)
27 | z = self.bn1(z)
28 | z = self.conv2(z)
29 | z = self.bn2(z)
30 | x = x + z
31 | return x
32 |
33 |
34 | class SRGAN_g(Module):
35 | """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
36 | feature maps (n) and stride (s) feature maps (n) and stride (s)
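
    A minimal usage sketch (mirrors train.py; `lr_batch` is an assumed channels-first tensor scaled to [-1, 1]):

    >>> G = SRGAN_g()
    >>> G.init_build(tlx.nn.Input(shape=(8, 3, 96, 96)))
    >>> sr = G(lr_batch)  # (8, 3, 96, 96) -> (8, 3, 384, 384), 4x super-resolution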
37 | """
38 |
39 | def __init__(self):
40 | super(SRGAN_g, self).__init__()
41 | self.conv1 = Conv2d(
42 | out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME', W_init=W_init,
43 | data_format='channels_first'
44 | )
45 | self.residual_block = self.make_layer()
46 | self.conv2 = Conv2d(
47 | out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
48 | data_format='channels_first', b_init=None
49 | )
50 | self.bn1 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_first')
51 | self.conv3 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_first')
52 | self.subpiexlconv1 = SubpixelConv2d(data_format='channels_first', scale=2, act=tlx.ReLU)
53 | self.conv4 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_first')
54 | self.subpiexlconv2 = SubpixelConv2d(data_format='channels_first', scale=2, act=tlx.ReLU)
55 | self.conv5 = Conv2d(3, kernel_size=(1, 1), stride=(1, 1), act=tlx.Tanh, padding='SAME', W_init=W_init, data_format='channels_first')
56 |
57 | def make_layer(self):
58 | layer_list = []
59 | for i in range(16):
60 | layer_list.append(ResidualBlock())
61 | return Sequential(layer_list)
62 |
63 | def forward(self, x):
64 | x = self.conv1(x)
65 | temp = x
66 | x = self.residual_block(x)
67 | x = self.conv2(x)
68 | x = self.bn1(x)
69 | x = x + temp
70 | x = self.conv3(x)
71 | x = self.subpiexlconv1(x)
72 | x = self.conv4(x)
73 | x = self.subpiexlconv2(x)
74 | x = self.conv5(x)
75 |
76 | return x
77 |
78 |
79 | class SRGAN_g2(Module):
80 | """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network
81 | feature maps (n) and stride (s) feature maps (n) and stride (s)
82 |
83 | 96x96 --> 384x384
84 |
85 | Use Resize Conv
86 | """
87 |
88 | def __init__(self):
89 | super(SRGAN_g2, self).__init__()
90 | self.conv1 = Conv2d(
91 | out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
92 | data_format='channels_first'
93 | )
94 | self.residual_block = self.make_layer()
95 | self.conv2 = Conv2d(
96 | out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
97 | data_format='channels_first', b_init=None
98 | )
99 | self.bn1 = BatchNorm2d(act=None, gamma_init=G_init, data_format='channels_first')
100 | self.upsample1 = UpSampling2d(data_format='channels_first', scale=(2, 2), method='bilinear')
101 | self.conv3 = Conv2d(
102 | out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
103 | data_format='channels_first', b_init=None
104 | )
105 | self.bn2 = BatchNorm2d(act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
106 | self.upsample2 = UpSampling2d(data_format='channels_first', scale=(4, 4), method='bilinear')
107 | self.conv4 = Conv2d(
108 | out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
109 | data_format='channels_first', b_init=None
110 | )
111 | self.bn3 = BatchNorm2d(act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
112 | self.conv5 = Conv2d(
113 | out_channels=3, kernel_size=(1, 1), stride=(1, 1), act=tlx.Tanh, padding='SAME', W_init=W_init
114 | )
115 |
116 | def make_layer(self):
117 | layer_list = []
118 | for i in range(16):
119 | layer_list.append(ResidualBlock())
120 | return Sequential(layer_list)
121 |
122 | def forward(self, x):
123 | x = self.conv1(x)
124 | temp = x
125 | x = self.residual_block(x)
126 | x = self.conv2(x)
127 | x = self.bn1(x)
128 | x = x + temp
129 | x = self.upsample1(x)
130 | x = self.conv3(x)
131 | x = self.bn2(x)
132 | x = self.upsample2(x)
133 | x = self.conv4(x)
134 | x = self.bn3(x)
135 | x = self.conv5(x)
136 | return x
137 |
138 |
139 | class SRGAN_d2(Module):
140 |     """ Discriminator in "Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network".
141 |     Layer sizes follow the paper's feature-map (n) and stride (s) notation.
142 | """
143 |
144 | def __init__(self, ):
145 | super(SRGAN_d2, self).__init__()
146 | self.conv1 = Conv2d(
147 | out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
148 | W_init=W_init, data_format='channels_first'
149 | )
150 | self.conv2 = Conv2d(
151 | out_channels=64, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
152 | W_init=W_init, data_format='channels_first', b_init=None
153 | )
154 | self.bn1 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
155 | self.conv3 = Conv2d(
156 | out_channels=128, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
157 | W_init=W_init, data_format='channels_first', b_init=None
158 | )
159 | self.bn2 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
160 | self.conv4 = Conv2d(
161 | out_channels=128, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
162 | W_init=W_init, data_format='channels_first', b_init=None
163 | )
164 | self.bn3 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
165 | self.conv5 = Conv2d(
166 | out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
167 | W_init=W_init, data_format='channels_first', b_init=None
168 | )
169 | self.bn4 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
170 | self.conv6 = Conv2d(
171 | out_channels=256, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
172 | W_init=W_init, data_format='channels_first', b_init=None
173 | )
174 | self.bn5 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
175 | self.conv7 = Conv2d(
176 | out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
177 | W_init=W_init, data_format='channels_first', b_init=None
178 | )
179 | self.bn6 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
180 | self.conv8 = Conv2d(
181 | out_channels=512, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
182 | W_init=W_init, data_format='channels_first', b_init=None
183 | )
184 | self.bn7 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
185 | self.flat = Flatten()
186 | self.dense1 = Linear(out_features=1024, act=tlx.LeakyReLU(negative_slope=0.2))
187 | self.dense2 = Linear(out_features=1)
188 |
189 | def forward(self, x):
190 | x = self.conv1(x)
191 | x = self.conv2(x)
192 | x = self.bn1(x)
193 | x = self.conv3(x)
194 | x = self.bn2(x)
195 | x = self.conv4(x)
196 | x = self.bn3(x)
197 | x = self.conv5(x)
198 | x = self.bn4(x)
199 | x = self.conv6(x)
200 | x = self.bn5(x)
201 | x = self.conv7(x)
202 | x = self.bn6(x)
203 | x = self.conv8(x)
204 | x = self.bn7(x)
205 | x = self.flat(x)
206 | x = self.dense1(x)
207 | x = self.dense2(x)
208 | logits = x
209 | n = tlx.sigmoid(x)
210 | return n, logits
211 |
212 |
213 | class SRGAN_d(Module):
214 |
215 | def __init__(self, dim=64):
216 | super(SRGAN_d, self).__init__()
217 | self.conv1 = Conv2d(
218 | out_channels=dim, kernel_size=(4, 4), stride=(2, 2), act=tlx.LeakyReLU, padding='SAME', W_init=W_init,
219 | data_format='channels_first'
220 | )
221 | self.conv2 = Conv2d(
222 | out_channels=dim * 2, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
223 | data_format='channels_first', b_init=None
224 | )
225 | self.bn1 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
226 | self.conv3 = Conv2d(
227 | out_channels=dim * 4, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
228 | data_format='channels_first', b_init=None
229 | )
230 | self.bn2 = BatchNorm2d(num_features=dim * 4, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
231 | self.conv4 = Conv2d(
232 | out_channels=dim * 8, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
233 | data_format='channels_first', b_init=None
234 | )
235 | self.bn3 = BatchNorm2d(num_features=dim * 8, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
236 | self.conv5 = Conv2d(
237 | out_channels=dim * 16, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
238 | data_format='channels_first', b_init=None
239 | )
240 | self.bn4 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
241 | self.conv6 = Conv2d(
242 | out_channels=dim * 32, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
243 | data_format='channels_first', b_init=None
244 | )
245 | self.bn5 = BatchNorm2d(num_features=dim * 32, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
246 | self.conv7 = Conv2d(
247 | out_channels=dim * 16, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
248 | data_format='channels_first', b_init=None
249 | )
250 | self.bn6 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
251 | self.conv8 = Conv2d(
252 | out_channels=dim * 8, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
253 | data_format='channels_first', b_init=None
254 | )
255 | self.bn7 = BatchNorm2d(num_features=dim * 8, act=None, gamma_init=G_init, data_format='channels_first')
256 | self.conv9 = Conv2d(
257 | out_channels=dim * 2, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
258 | data_format='channels_first', b_init=None
259 | )
260 | self.bn8 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
261 | self.conv10 = Conv2d(
262 | out_channels=dim * 2, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
263 | data_format='channels_first', b_init=None
264 | )
265 | self.bn9 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
266 | self.conv11 = Conv2d(
267 | out_channels=dim * 8, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
268 | data_format='channels_first', b_init=None
269 | )
270 | self.bn10 = BatchNorm2d(num_features=dim * 8, gamma_init=G_init, data_format='channels_first')
271 | self.add = Elementwise(combine_fn=tlx.add, act=tlx.LeakyReLU)
272 | self.flat = Flatten()
273 | self.dense = Linear(out_features=1, W_init=W_init)
274 |
275 | def forward(self, x):
276 |
277 | x = self.conv1(x)
278 | x = self.conv2(x)
279 | x = self.bn1(x)
280 | x = self.conv3(x)
281 | x = self.bn2(x)
282 | x = self.conv4(x)
283 | x = self.bn3(x)
284 | x = self.conv5(x)
285 | x = self.bn4(x)
286 | x = self.conv6(x)
287 | x = self.bn5(x)
288 | x = self.conv7(x)
289 | x = self.bn6(x)
290 | x = self.conv8(x)
291 | x = self.bn7(x)
292 | temp = x
293 | x = self.conv9(x)
294 | x = self.bn8(x)
295 | x = self.conv10(x)
296 | x = self.bn9(x)
297 | x = self.conv11(x)
298 | x = self.bn10(x)
299 | x = self.add([temp, x])
300 | x = self.flat(x)
301 | x = self.dense(x)
302 |
303 | return x
304 |
305 |
306 | class Vgg19_simple_api(Module):
307 |
308 | def __init__(self):
309 | super(Vgg19_simple_api, self).__init__()
310 | """ conv1 """
311 | self.conv1 = Conv2d(out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
312 | self.conv2 = Conv2d(out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
313 | self.maxpool1 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME')
314 | """ conv2 """
315 | self.conv3 = Conv2d(out_channels=128, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
316 | self.conv4 = Conv2d(out_channels=128, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
317 | self.maxpool2 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME')
318 | """ conv3 """
319 | self.conv5 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
320 | self.conv6 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
321 | self.conv7 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
322 | self.conv8 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
323 | self.maxpool3 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME')
324 | """ conv4 """
325 | self.conv9 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
326 | self.conv10 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
327 | self.conv11 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
328 | self.conv12 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
329 | self.maxpool4 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME') # (batch_size, 14, 14, 512)
330 | """ conv5 """
331 | self.conv13 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
332 | self.conv14 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
333 | self.conv15 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
334 | self.conv16 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
335 | self.maxpool5 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME') # (batch_size, 7, 7, 512)
336 | """ fc 6~8 """
337 | self.flat = Flatten()
338 | self.dense1 = Linear(out_features=4096, act=tlx.ReLU)
339 | self.dense2 = Linear(out_features=4096, act=tlx.ReLU)
340 | self.dense3 = Linear(out_features=1000, act=tlx.identity)
341 |
342 | def forward(self, x):
343 | x = self.conv1(x)
344 | x = self.conv2(x)
345 | x = self.maxpool1(x)
346 | x = self.conv3(x)
347 | x = self.conv4(x)
348 | x = self.maxpool2(x)
349 | x = self.conv5(x)
350 | x = self.conv6(x)
351 | x = self.conv7(x)
352 | x = self.conv8(x)
353 | x = self.maxpool3(x)
354 | x = self.conv9(x)
355 | x = self.conv10(x)
356 | x = self.conv11(x)
357 | x = self.conv12(x)
358 | x = self.maxpool4(x)
359 | conv = x
360 | x = self.conv13(x)
361 | x = self.conv14(x)
362 | x = self.conv15(x)
363 | x = self.conv16(x)
364 | x = self.maxpool5(x)
365 | x = self.flat(x)
366 | x = self.dense1(x)
367 | x = self.dense2(x)
368 | x = self.dense3(x)
369 |
370 | return x, conv
371 |
--------------------------------------------------------------------------------