├── .style.yapf
├── img
│   ├── model.jpeg
│   ├── SRGAN_Result.png
│   ├── SRGAN_Result2.png
│   └── SRGAN_Result3.png
├── .gitignore
├── requirements.txt
├── config.py
├── README.md
├── train.py
├── vgg.py
└── srgan.py

--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
[style]
column_limit = 160

--------------------------------------------------------------------------------
/img/model.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/SRGAN/HEAD/img/model.jpeg

--------------------------------------------------------------------------------
/img/SRGAN_Result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/SRGAN/HEAD/img/SRGAN_Result.png

--------------------------------------------------------------------------------
/img/SRGAN_Result2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/SRGAN/HEAD/img/SRGAN_Result2.png

--------------------------------------------------------------------------------
/img/SRGAN_Result3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensorlayer/SRGAN/HEAD/img/SRGAN_Result3.png

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
._*
*.pyc
.DS_Store
*.npz
sample/
samples/
checkpoint/
__pycache__/

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
git+https://github.com/tensorlayer/tensorlayerx.git
numpy>=1.16.1
easydict==1.9
opencv-python>=4.5.1.48

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
from easydict import EasyDict as edict
import json

config = edict()
config.TRAIN = edict()
config.TRAIN.batch_size = 16  # [16] use 8 if your GPU memory is small
config.TRAIN.lr_init = 1e-4
config.TRAIN.beta1 = 0.9

## initialize G
config.TRAIN.n_epoch_init = 100
# config.TRAIN.lr_decay_init = 0.1
# config.TRAIN.decay_every_init = int(config.TRAIN.n_epoch_init / 2)

## adversarial learning (SRGAN)
config.TRAIN.n_epoch = 2000
config.TRAIN.lr_decay = 0.1
config.TRAIN.decay_every = int(config.TRAIN.n_epoch / 2)

## train set location
config.TRAIN.hr_img_path = 'DIV2K/DIV2K_train_HR/'
config.TRAIN.lr_img_path = 'DIV2K/DIV2K_train_LR_bicubic/X4/'

config.VALID = edict()
## test set location
config.VALID.hr_img_path = 'DIV2K/DIV2K_valid_HR/'
config.VALID.lr_img_path = 'DIV2K/DIV2K_valid_LR_bicubic/X4/'

def log_config(filename, cfg):
    with open(filename, 'w') as f:
        f.write("================================================\n")
        f.write(json.dumps(cfg, indent=4))
        f.write("\n================================================\n")

--------------------------------------------------------------------------------
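`log_config` is defined here but never called by the scripts; a minimal sketch of using it to snapshot the settings next to a training run (hypothetical usage, not part of `train.py`):

```python
from config import config, log_config

# Writes the full TRAIN/VALID configuration as indented JSON, for reproducibility.
log_config('train_config.log', config)
```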
/README.md:
--------------------------------------------------------------------------------
## Super Resolution Examples

- Implementation of ["Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network"](https://arxiv.org/abs/1609.04802)

- For earlier versions, please check the [srgan releases](https://github.com/tensorlayer/srgan/releases) and [tensorlayer](https://github.com/tensorlayer/TensorLayer).

- For more computer vision applications, check [TLXCV](https://github.com/tensorlayer/TLXCV).


### SRGAN Architecture
![SRGAN architecture](img/model.jpeg)

### Prepare Data and Pre-trained VGG

1. Download the pretrained VGG19 weights from [here](https://drive.google.com/file/d/1CLw6Cn3yNI1N15HyX99_Zy9QnDcgP3q7/view?usp=sharing).
2. You need high-resolution images for training.
   - In this experiment, I used images from the [DIV2K - bicubic downscaling x4 competition](http://www.vision.ee.ethz.ch/ntire17/), so the hyper-parameters in `config.py` (like the number of epochs) are selected based on that dataset; if you switch to a larger dataset, you can reduce the number of epochs.
   - If you don't want to use the DIV2K dataset, you can also use [Yahoo MirFlickr25k](http://press.liacs.nl/mirflickr/mirdownload.html); simply download it with `train_hr_imgs = tl.files.load_flickr25k_dataset(tag=None)` in `train.py`.
   - If you want to use your own images, you can set the path to your image folder via `config.TRAIN.hr_img_path` in `config.py`.


### Run

🔥🔥🔥🔥🔥🔥 You need to install [TensorLayerX](https://github.com/tensorlayer/TensorLayerX#installation) first!

🔥🔥🔥🔥🔥🔥 Please install TensorLayerX from source:

```bash
pip install git+https://github.com/tensorlayer/tensorlayerx.git
```

#### Train
- Set your image folder in `config.py`. If you downloaded the [DIV2K - bicubic downscaling x4 competition](http://www.vision.ee.ethz.ch/ntire17/) dataset, you don't need to change it.
- Other links for DIV2K, in case you can't find it: [test\_LR\_bicubic_X4](https://data.vision.ee.ethz.ch/cvl/DIV2K/validation_release/DIV2K_test_LR_bicubic_X4.zip), [train_HR](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip), [train\_LR\_bicubic_X4](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_LR_bicubic_X4.zip), [valid_HR](https://data.vision.ee.ethz.ch/cvl/DIV2K/validation_release/DIV2K_valid_HR.zip), [valid\_LR\_bicubic_X4](https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip).

```python
config.TRAIN.img_path = "your_image_folder/"
```
Your directory structure should look like this:

```
srgan/
├── config.py
├── srgan.py
├── train.py
├── vgg.py
├── model
│   └── vgg19.npy
└── DIV2K
    ├── DIV2K_train_HR
    ├── DIV2K_train_LR_bicubic
    ├── DIV2K_valid_HR
    └── DIV2K_valid_LR_bicubic
```
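Before launching a long run, you can sanity-check that the paths in `config.py` resolve to actual images — a minimal sketch (an illustrative addition, using the same loader as `train.py`):

```python
import tensorlayerx as tlx
from config import config

# Should report a non-zero count once DIV2K_train_HR is in place.
imgs = tlx.vision.load_images(path=config.TRAIN.hr_img_path)
print('%d HR training images found' % len(imgs))
```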
- Start training.

```bash
python train.py
```

🔥 Modify one line in **train.py** to easily switch to any supported framework!

```python
import os
os.environ['TL_BACKEND'] = 'tensorflow'
# os.environ['TL_BACKEND'] = 'mindspore'
# os.environ['TL_BACKEND'] = 'paddle'
# os.environ['TL_BACKEND'] = 'pytorch'
```
🚧 PyTorch backend support is coming soon.


#### Evaluation

🔥 We have trained SRGAN on the DIV2K dataset.
🔥 Download the model weights as follows.

|              | SRGAN_g | SRGAN_d |
|------------- |---------|---------|
| TensorFlow   | [Baidu](https://pan.baidu.com/s/118uUg3oce_3NZQCIWHVjmA?pwd=p9li), [Google Drive](https://drive.google.com/file/d/1GlU9At-5XEDilgnt326fyClvZB_fsaFZ/view?usp=sharing) | [Baidu](https://pan.baidu.com/s/1DOpGzDJY5PyusKzaKqbLOg?pwd=g2iy), [Google Drive](https://drive.google.com/file/d/1RpOtVcVK-yxnVhNH4KSjnXHDvuU_pq3j/view?usp=sharing) |
| PaddlePaddle | [Baidu](https://pan.baidu.com/s/1ngBpleV5vQZQqNE_8djDIg?pwd=s8wc), [Google Drive](https://drive.google.com/file/d/1GRNt_ZsgorB19qvwN5gE6W9a_bIPLkg1/view?usp=sharing) | [Baidu](https://pan.baidu.com/s/1nSefLNRanFImf1DskSVpCg?pwd=befc), [Google Drive](https://drive.google.com/file/d/1Jf6W1ZPdgtmUSfrQ5mMZDB_hOCVU-zFo/view?usp=sharing) |
| MindSpore    | 🚧 Coming soon! | 🚧 Coming soon! |
| PyTorch      | 🚧 Coming soon! | 🚧 Coming soon! |

Download the weights files and put them under the folder `srgan/models/`.

Your directory structure should look like this:

```
srgan/
├── config.py
├── srgan.py
├── train.py
├── vgg.py
├── model
│   └── vgg19.npy
├── DIV2K
│   ├── DIV2K_train_HR
│   ├── DIV2K_train_LR_bicubic
│   ├── DIV2K_valid_HR
│   └── DIV2K_valid_LR_bicubic
└── models
    ├── g.npz  # You should rename the weights files.
    └── d.npz  # If you set os.environ['TL_BACKEND'] = 'tensorflow', rename srgan-g-tensorflow.npz to g.npz.
```
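For example, for the TensorFlow backend the renaming can be done like this (an illustrative sketch; the discriminator file name is assumed to follow the same pattern as the generator's):

```python
import os

# Rename the downloaded TensorFlow weights to the names train.py expects.
os.rename('srgan-g-tensorflow.npz', 'models/g.npz')
os.rename('srgan-d-tensorflow.npz', 'models/d.npz')
```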
- Start evaluation.

```bash
python train.py --mode=eval
```
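`evaluate()` only writes images. For a quick quantitative check, here is a minimal PSNR sketch over the saved outputs (an illustrative addition, assuming the default file names written by `train.py`):

```python
import cv2
import numpy as np

gen = cv2.imread('samples/valid_gen.png').astype(np.float64)
ref = cv2.imread('samples/valid_hr.png').astype(np.float64)
# The generated image can differ from the reference by a few pixels due to the
# 4x down/upscaling round trip, so resize the reference to match before comparing.
ref = cv2.resize(ref, (gen.shape[1], gen.shape[0]), interpolation=cv2.INTER_CUBIC)
mse = np.mean((gen - ref) ** 2)
print('PSNR: %.2f dB' % (10 * np.log10(255.0 ** 2 / mse)))
```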
Results will be saved under the folder `srgan/samples/`.

### Results

![SRGAN result](img/SRGAN_Result.png)
![SRGAN result](img/SRGAN_Result2.png)
![SRGAN result](img/SRGAN_Result3.png)
### Reference
* [1] [Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network](https://arxiv.org/abs/1609.04802)
* [2] [Is the deconvolution layer the same as a convolutional layer?](https://arxiv.org/abs/1609.07009)


### Citation
If you find this project useful, we would be grateful if you cite the TensorLayer paper:

```
@article{tensorlayer2017,
    author  = {Dong, Hao and Supratak, Akara and Mai, Luo and Liu, Fangde and Oehmichen, Axel and Yu, Simiao and Guo, Yike},
    journal = {ACM Multimedia},
    title   = {{TensorLayer: A Versatile Library for Efficient Deep Learning Development}},
    url     = {http://tensorlayer.org},
    year    = {2017}
}

@inproceedings{tensorlayer2021,
    title={TensorLayer 3.0: A Deep Learning Library Compatible With Multiple Backends},
    author={Lai, Cheng and Han, Jiarong and Dong, Hao},
    booktitle={2021 IEEE International Conference on Multimedia \& Expo Workshops (ICMEW)},
    pages={1--3},
    year={2021},
    organization={IEEE}
}
```

### Other Projects

- [Style Transfer](https://github.com/tensorlayer/adaptive-style-transfer)
- [Pose Estimation](https://github.com/tensorlayer/openpose)

### Discussion

- [TensorLayer Slack](https://join.slack.com/t/tensorlayer/shared_invite/enQtMjUyMjczMzU2Njg4LWI0MWU0MDFkOWY2YjQ4YjVhMzI5M2VlZmE4YTNhNGY1NjZhMzUwMmQ2MTc0YWRjMjQzMjdjMTg2MWQ2ZWJhYzc)
- [TensorLayer WeChat](https://github.com/tensorlayer/tensorlayer-chinese/blob/master/docs/wechat_group.md)

### License

- For academic and non-commercial use only.
- For commercial use, please contact tensorlayer@gmail.com.

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
os.environ['TL_BACKEND'] = 'tensorflow'  # Just modify this line to switch to any supported framework! PyTorch support is coming soon!
# os.environ['TL_BACKEND'] = 'mindspore'
# os.environ['TL_BACKEND'] = 'paddle'
# os.environ['TL_BACKEND'] = 'torch'
import time
import numpy as np
import tensorlayerx as tlx
from tensorlayerx.dataflow import Dataset, DataLoader
from srgan import SRGAN_g, SRGAN_d
from config import config
from tensorlayerx.vision.transforms import Compose, RandomCrop, Normalize, RandomFlipHorizontal, Resize, HWC2CHW
import vgg
from tensorlayerx.model import TrainOneStep
from tensorlayerx.nn import Module
import cv2

tlx.set_device('GPU')

###====================== HYPER-PARAMETERS ===========================###
batch_size = 8
n_epoch_init = config.TRAIN.n_epoch_init
n_epoch = config.TRAIN.n_epoch
# create folders to save result images and trained models
save_dir = "samples"
tlx.files.exists_or_mkdir(save_dir)
checkpoint_dir = "models"
tlx.files.exists_or_mkdir(checkpoint_dir)

# HR patches: random 384x384 crops with random horizontal flips.
hr_transform = Compose([
    RandomCrop(size=(384, 384)),
    RandomFlipHorizontal(),
])
# Rescale pixel values to [-1, 1] and convert HWC -> CHW.
nor = Compose([Normalize(mean=(127.5), std=(127.5), data_format='HWC'),
               HWC2CHW()])
# LR patches: the HR patch resized down to 96x96 (4x smaller).
lr_transform = Resize(size=(96, 96))

train_hr_imgs = tlx.vision.load_images(path=config.TRAIN.hr_img_path, n_threads=32)


class TrainData(Dataset):

    def __init__(self, hr_trans=hr_transform, lr_trans=lr_transform):
        self.train_hr_imgs = train_hr_imgs
        self.hr_trans = hr_trans
        self.lr_trans = lr_trans

    def __getitem__(self, index):
        img = self.train_hr_imgs[index]
        hr_patch = self.hr_trans(img)
        lr_patch = self.lr_trans(hr_patch)
        return nor(lr_patch), nor(hr_patch)

    def __len__(self):
        return len(self.train_hr_imgs)
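# Illustrative sanity check (an editorial addition, safe to delete): one sample
# should yield channels-first patches of 3x96x96 (LR) and 3x384x384 (HR).
#   _lr, _hr = TrainData()[0]
#   assert _lr.shape[-2:] == (96, 96) and _hr.shape[-2:] == (384, 384)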
class WithLoss_init(Module):
    def __init__(self, G_net, loss_fn):
        super(WithLoss_init, self).__init__()
        self.net = G_net
        self.loss_fn = loss_fn

    def forward(self, lr, hr):
        out = self.net(lr)
        loss = self.loss_fn(out, hr)
        return loss


class WithLoss_D(Module):
    def __init__(self, D_net, G_net, loss_fn):
        super(WithLoss_D, self).__init__()
        self.D_net = D_net
        self.G_net = G_net
        self.loss_fn = loss_fn

    def forward(self, lr, hr):
        fake_patchs = self.G_net(lr)
        logits_fake = self.D_net(fake_patchs)
        logits_real = self.D_net(hr)
        d_loss1 = self.loss_fn(logits_real, tlx.ones_like(logits_real))
        d_loss1 = tlx.ops.reduce_mean(d_loss1)
        d_loss2 = self.loss_fn(logits_fake, tlx.zeros_like(logits_fake))
        d_loss2 = tlx.ops.reduce_mean(d_loss2)
        d_loss = d_loss1 + d_loss2
        return d_loss


class WithLoss_G(Module):
    def __init__(self, D_net, G_net, vgg, loss_fn1, loss_fn2):
        super(WithLoss_G, self).__init__()
        self.D_net = D_net
        self.G_net = G_net
        self.vgg = vgg
        self.loss_fn1 = loss_fn1
        self.loss_fn2 = loss_fn2

    def forward(self, lr, hr):
        fake_patchs = self.G_net(lr)
        logits_fake = self.D_net(fake_patchs)
        feature_fake = self.vgg((fake_patchs + 1) / 2.)
        feature_real = self.vgg((hr + 1) / 2.)
        g_gan_loss = 1e-3 * self.loss_fn1(logits_fake, tlx.ones_like(logits_fake))
        g_gan_loss = tlx.ops.reduce_mean(g_gan_loss)
        mse_loss = self.loss_fn2(fake_patchs, hr)
        vgg_loss = 2e-6 * self.loss_fn2(feature_fake, feature_real)
        g_loss = mse_loss + vgg_loss + g_gan_loss
        return g_loss
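# The generator objective above combines three terms:
#   g_loss = MSE(G(lr), hr)                      # pixel-wise content loss
#          + 2e-6 * MSE(VGG(G(lr)), VGG(hr))     # VGG-feature (perceptual) loss
#          + 1e-3 * BCE(D(G(lr)), 1)             # adversarial loss
# VGG features come from the pool4 stage of the pretrained VGG19 built below,
# and patches are mapped from [-1, 1] back to [0, 1] before the VGG forward pass.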
G = SRGAN_g()
D = SRGAN_d()
VGG = vgg.VGG19(pretrained=True, end_with='pool4', mode='dynamic')
# Automatically init layer weight shapes from an input tensor.
# Calculating and filling in 'in_channels' for every layer is tedious,
# so just use 'init_build' with an input shape; 'in_channels' of each layer will be set automatically.
G.init_build(tlx.nn.Input(shape=(8, 3, 96, 96)))
D.init_build(tlx.nn.Input(shape=(8, 3, 384, 384)))


def train():
    G.set_train()
    D.set_train()
    VGG.set_eval()
    train_ds = TrainData()
    train_ds_img_nums = len(train_ds)
    train_ds = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)

    lr_v = tlx.optimizers.lr.StepDecay(learning_rate=0.05, step_size=1000, gamma=0.1, last_epoch=-1, verbose=True)
    g_optimizer_init = tlx.optimizers.Momentum(lr_v, 0.9)
    g_optimizer = tlx.optimizers.Momentum(lr_v, 0.9)
    d_optimizer = tlx.optimizers.Momentum(lr_v, 0.9)
    g_weights = G.trainable_weights
    d_weights = D.trainable_weights
    net_with_loss_init = WithLoss_init(G, loss_fn=tlx.losses.mean_squared_error)
    net_with_loss_D = WithLoss_D(D_net=D, G_net=G, loss_fn=tlx.losses.sigmoid_cross_entropy)
    net_with_loss_G = WithLoss_G(D_net=D, G_net=G, vgg=VGG, loss_fn1=tlx.losses.sigmoid_cross_entropy,
                                 loss_fn2=tlx.losses.mean_squared_error)

    trainforinit = TrainOneStep(net_with_loss_init, optimizer=g_optimizer_init, train_weights=g_weights)
    trainforG = TrainOneStep(net_with_loss_G, optimizer=g_optimizer, train_weights=g_weights)
    trainforD = TrainOneStep(net_with_loss_D, optimizer=d_optimizer, train_weights=d_weights)

    # initialize learning (G)
    n_step_epoch = round(train_ds_img_nums // batch_size)
    for epoch in range(n_epoch_init):
        for step, (lr_patch, hr_patch) in enumerate(train_ds):
            step_time = time.time()
            loss = trainforinit(lr_patch, hr_patch)
            print("Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, mse: {:.3f} ".format(
                epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, float(loss)))

    # adversarial learning (G, D)
    n_step_epoch = round(train_ds_img_nums // batch_size)
    for epoch in range(n_epoch):
        for step, (lr_patch, hr_patch) in enumerate(train_ds):
            step_time = time.time()
            loss_g = trainforG(lr_patch, hr_patch)
            loss_d = trainforD(lr_patch, hr_patch)
            print(
                "Epoch: [{}/{}] step: [{}/{}] time: {:.3f}s, g_loss: {:.3f}, d_loss: {:.3f}".format(
                    epoch, n_epoch, step, n_step_epoch, time.time() - step_time, float(loss_g), float(loss_d)))
        # dynamic learning rate update
        lr_v.step()

        if (epoch != 0) and (epoch % 10 == 0):
            G.save_weights(os.path.join(checkpoint_dir, 'g.npz'), format='npz_dict')
            D.save_weights(os.path.join(checkpoint_dir, 'd.npz'), format='npz_dict')
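# Training runs in two phases: n_epoch_init epochs of MSE-only pretraining of G
# (trainforinit), then n_epoch epochs of adversarial training in which G and D
# are updated alternately on every batch (trainforG / trainforD).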
def evaluate():
    ###====================== PRE-LOAD DATA ===========================###
    valid_hr_imgs = tlx.vision.load_images(path=config.VALID.hr_img_path)
    ###======================== LOAD WEIGHTS ==========================###
    G.load_weights(os.path.join(checkpoint_dir, 'g.npz'), format='npz_dict')
    G.set_eval()
    imid = 0  # 0: penguin, 81: butterfly, 53: bird, 64: castle
    valid_hr_img = valid_hr_imgs[imid]
    valid_lr_img = np.asarray(valid_hr_img)
    hr_size1 = [valid_lr_img.shape[0], valid_lr_img.shape[1]]
    # Create the LR input by downscaling the HR image 4x.
    valid_lr_img = cv2.resize(valid_lr_img, dsize=(hr_size1[1] // 4, hr_size1[0] // 4))
    valid_lr_img_tensor = (valid_lr_img / 127.5) - 1  # rescale to [-1, 1]

    valid_lr_img_tensor = np.asarray(valid_lr_img_tensor, dtype=np.float32)
    valid_lr_img_tensor = np.transpose(valid_lr_img_tensor, axes=[2, 0, 1])  # HWC -> CHW
    valid_lr_img_tensor = valid_lr_img_tensor[np.newaxis, :, :, :]  # add batch dimension
    valid_lr_img_tensor = tlx.ops.convert_to_tensor(valid_lr_img_tensor)
    size = [valid_lr_img.shape[0], valid_lr_img.shape[1]]

    out = tlx.ops.convert_to_numpy(G(valid_lr_img_tensor))
    out = np.asarray((out + 1) * 127.5, dtype=np.uint8)  # rescale back to [0, 255]
    out = np.transpose(out[0], axes=[1, 2, 0])  # CHW -> HWC
    print("LR size: %s / generated HR size: %s" % (size, out.shape))  # LR size: (339, 510, 3) / gen HR size: (1, 1356, 2040, 3)
    print("[*] save images")
    tlx.vision.save_image(out, file_name='valid_gen.png', path=save_dir)
    tlx.vision.save_image(valid_lr_img, file_name='valid_lr.png', path=save_dir)
    tlx.vision.save_image(valid_hr_img, file_name='valid_hr.png', path=save_dir)
    # Bicubic upscaling baseline for comparison.
    out_bicu = cv2.resize(valid_lr_img, dsize=[size[1] * 4, size[0] * 4], interpolation=cv2.INTER_CUBIC)
    tlx.vision.save_image(out_bicu, file_name='valid_hr_cubic.png', path=save_dir)


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()

    parser.add_argument('--mode', type=str, default='train', help='train, eval')

    args = parser.parse_args()

    tlx.global_flag['mode'] = args.mode

    if tlx.global_flag['mode'] == 'train':
        train()
    elif tlx.global_flag['mode'] == 'eval':
        evaluate()
    else:
        raise Exception("Unknown --mode")

--------------------------------------------------------------------------------
/vgg.py:
--------------------------------------------------------------------------------
#! /usr/bin/python
# -*- coding: utf-8 -*-
"""
VGG for ImageNet.

Introduction
----------------
VGG is a convolutional neural network model proposed by K. Simonyan and A. Zisserman
from the University of Oxford in the paper "Very Deep Convolutional Networks for
Large-Scale Image Recognition". The model achieves 92.7% top-5 test accuracy on ImageNet,
which is a dataset of over 14 million images belonging to 1000 classes.

Download Pre-trained Model
----------------------------
- Model weights in this example - vgg16_weights.npz : http://www.cs.toronto.edu/~frossard/post/vgg16/
- Model weights in this example - vgg19.npy : https://media.githubusercontent.com/media/tensorlayer/pretrained-models/master/models/
- Caffe VGG 16 model : https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-md
- Tool to convert the Caffe models to TensorFlow's : https://github.com/ethereon/caffe-tensorflow

Note
------
- For a simplified CNN layer see "Convolutional layer (Simplified)"
  in the read-the-docs website.
- When feeding other images to the model be sure to properly resize or crop them
  beforehand. Distorted images might end up being misclassified. One way of safely
  feeding images of multiple sizes is by doing center cropping.

"""

import os

import numpy as np

import tensorlayerx as tlx
from tensorlayerx import logging
from tensorlayerx.files import assign_weights, maybe_download_and_extract
from tensorlayerx.nn import (BatchNorm, Conv2d, Linear, Flatten, Input, Sequential, MaxPool2d)
from tensorlayerx.nn import Module

__all__ = [
    'VGG',
    'vgg16',
    'vgg19',
    'VGG16',
    'VGG19',
    # 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    # 'vgg19_bn', 'vgg19',
]

layer_names = [
    ['conv1_1', 'conv1_2'], 'pool1', ['conv2_1', 'conv2_2'], 'pool2',
    ['conv3_1', 'conv3_2', 'conv3_3', 'conv3_4'], 'pool3', ['conv4_1', 'conv4_2', 'conv4_3', 'conv4_4'], 'pool4',
    ['conv5_1', 'conv5_2', 'conv5_3', 'conv5_4'], 'pool5', 'flatten', 'fc1_relu', 'fc2_relu', 'outputs'
]

cfg = {
    'A': [[64], 'M', [128], 'M', [256, 256], 'M', [512, 512], 'M', [512, 512], 'M', 'F', 'fc1', 'fc2', 'O'],
    'B': [[64, 64], 'M', [128, 128], 'M', [256, 256], 'M', [512, 512], 'M', [512, 512], 'M', 'F', 'fc1', 'fc2', 'O'],
    'D':
        [
            [64, 64], 'M', [128, 128], 'M', [256, 256, 256], 'M', [512, 512, 512], 'M', [512, 512, 512], 'M', 'F',
            'fc1', 'fc2', 'O'
        ],
    'E':
        [
            [64, 64], 'M', [128, 128], 'M', [256, 256, 256, 256], 'M', [512, 512, 512, 512], 'M', [512, 512, 512, 512],
            'M', 'F', 'fc1', 'fc2', 'O'
        ],
}

mapped_cfg = {
    'vgg11': 'A',
    'vgg11_bn': 'A',
    'vgg13': 'B',
    'vgg13_bn': 'B',
    'vgg16': 'D',
    'vgg16_bn': 'D',
    'vgg19': 'E',
    'vgg19_bn': 'E'
}

model_urls = {
    'vgg16': 'https://git.openi.org.cn/attachments/760835b9-db71-4a00-8edd-d5ece4b6b522?type=0',
    'vgg19': 'https://git.openi.org.cn/attachments/503c8a6c-705f-4fb6-ba18-03d72b6a949a?type=0'
}

model_saved_name = {'vgg16': 'vgg16_weights.npz', 'vgg19': 'vgg19.npy'}
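# The cfg lists follow the original VGG paper variants: 'A' = VGG11, 'B' = VGG13,
# 'D' = VGG16, 'E' = VGG19. Numbers are conv output channels, 'M' marks max-pooling,
# 'F' flattening, and 'fc1'/'fc2'/'O' the fully connected head.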
class VGG(Module):

    def __init__(self, layer_type, batch_norm=False, end_with='outputs', name=None):
        super(VGG, self).__init__(name=name)
        self.end_with = end_with

        config = cfg[mapped_cfg[layer_type]]
        self.make_layer = make_layers(config, batch_norm, end_with)

    def forward(self, inputs):
        """
        inputs : tensor
            Shape [None, 3, 224, 224] (channels first), value range [0, 1].
        """

        # inputs = inputs * 255 - np.array([123.68, 116.779, 103.939], dtype=np.float32).reshape([1, 1, 1, 3])
        inputs = inputs * 255. - tlx.convert_to_tensor(np.array([123.68, 116.779, 103.939], dtype=np.float32).reshape(-1, 1, 1))
        out = self.make_layer(inputs)
        return out
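# Illustrative usage, mirroring train.py: build a truncated VGG19 up to pool4 as a
# feature extractor for the perceptual loss.
#   VGG = vgg19(pretrained=True, end_with='pool4', mode='dynamic')
#   feats = VGG(images)  # images: [batch, 3, H, W], values in [0, 1]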
def make_layers(config, batch_norm=False, end_with='outputs'):
    layer_list = []
    is_end = False
    for layer_group_idx, layer_group in enumerate(config):
        if isinstance(layer_group, list):
            for idx, layer in enumerate(layer_group):
                layer_name = layer_names[layer_group_idx][idx]
                n_filter = layer
                if idx == 0:
                    if layer_group_idx > 0:
                        in_channels = config[layer_group_idx - 2][-1]
                    else:
                        in_channels = 3
                else:
                    in_channels = layer_group[idx - 1]
                layer_list.append(
                    Conv2d(
                        out_channels=n_filter, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME',
                        in_channels=in_channels, name=layer_name, data_format='channels_first'
                    )
                )
                if batch_norm:
                    layer_list.append(BatchNorm(num_features=n_filter, data_format='channels_first'))
                if layer_name == end_with:
                    is_end = True
                    break
        else:
            layer_name = layer_names[layer_group_idx]
            if layer_group == 'M':
                layer_list.append(MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME', name=layer_name, data_format='channels_first'))
            elif layer_group == 'O':
                layer_list.append(Linear(out_features=1000, in_features=4096, name=layer_name))
            elif layer_group == 'F':
                layer_list.append(Flatten(name='flatten'))
            elif layer_group == 'fc1':
                layer_list.append(Linear(out_features=4096, act=tlx.ReLU, in_features=512 * 7 * 7, name=layer_name))
            elif layer_group == 'fc2':
                layer_list.append(Linear(out_features=4096, act=tlx.ReLU, in_features=4096, name=layer_name))
            if layer_name == end_with:
                is_end = True
        if is_end:
            break
    return Sequential(layer_list)


def restore_model(model, layer_type):
    logging.info("Restore pre-trained weights")
    # download weights
    maybe_download_and_extract(model_saved_name[layer_type], 'model', model_urls[layer_type])
    weights = []
    if layer_type == 'vgg16':
        npz = np.load(os.path.join('model', model_saved_name[layer_type]), allow_pickle=True)
        # get weight list
        for val in sorted(npz.items()):
            logging.info("  Loading weights %s in %s" % (str(val[1].shape), val[0]))
            weights.append(val[1])
            if len(model.all_weights) == len(weights):
                break
    elif layer_type == 'vgg19':
        npz = np.load(os.path.join('model', model_saved_name[layer_type]), allow_pickle=True, encoding='latin1').item()
        # get weight list
        for val in sorted(npz.items()):
            logging.info("  Loading %s in %s" % (str(val[1][0].shape), val[0]))
            logging.info("  Loading %s in %s" % (str(val[1][1].shape), val[0]))
            weights.extend(val[1])
            if len(model.all_weights) == len(weights):
                break
    # assign weight values
    if tlx.BACKEND != 'tensorflow':
        # convert conv kernels from the TensorFlow HWIO layout to OIHW
        for i in range(len(weights)):
            if len(weights[i].shape) == 4:
                weights[i] = np.transpose(weights[i], axes=[3, 2, 0, 1])
    assign_weights(weights, model)
    del weights


def vgg16(pretrained=False, end_with='outputs', mode='dynamic', name=None):
    """Pre-trained VGG16 model.

    Parameters
    ------------
    pretrained : boolean
        Whether to load pretrained weights. Default False.
    end_with : str
        The end point of the model. Default ``outputs``, i.e. the whole model.
    mode : str
        Model building mode, 'dynamic' or 'static'. Default 'dynamic'.
    name : None or str
        A unique layer name.

    Examples
    ---------
    Classify ImageNet classes with VGG16, see ``tutorial_models_vgg.py``
    with TensorLayer.
    TODO Modify the usage example according to the model storage location

    >>> # get the whole model, without pre-trained VGG parameters
    >>> vgg = vgg16()
    >>> # get the whole model, restore pre-trained VGG parameters
    >>> vgg = vgg16(pretrained=True)
    >>> # use for inferencing
    >>> output = vgg(img)
    >>> probs = tlx.ops.softmax(output)[0].numpy()

    """

    if mode == 'dynamic':
        model = VGG(layer_type='vgg16', batch_norm=False, end_with=end_with, name=name)
    elif mode == 'static':
        raise NotImplementedError
    else:
        raise Exception("No such mode %s" % mode)
    if pretrained:
        restore_model(model, layer_type='vgg16')
    return model


def vgg19(pretrained=False, end_with='outputs', mode='dynamic', name=None):
    """Pre-trained VGG19 model.

    Parameters
    ------------
    pretrained : boolean
        Whether to load pretrained weights. Default False.
    end_with : str
        The end point of the model. Default ``outputs``, i.e. the whole model.
    mode : str
        Model building mode, 'dynamic' or 'static'. Default 'dynamic'.
    name : None or str
        A unique layer name.

    Examples
    ---------
    Classify ImageNet classes with VGG19, see ``tutorial_models_vgg.py``
    with TensorLayerX.

    >>> # get the whole model, without pre-trained VGG parameters
    >>> vgg = vgg19()
    >>> # get the whole model, restore pre-trained VGG parameters
    >>> vgg = vgg19(pretrained=True)
    >>> # use for inferencing
    >>> output = vgg(img)
    >>> probs = tlx.ops.softmax(output)[0].numpy()

    """
    if mode == 'dynamic':
        model = VGG(layer_type='vgg19', batch_norm=False, end_with=end_with, name=name)
    elif mode == 'static':
        raise NotImplementedError
    else:
        raise Exception("No such mode %s" % mode)
    if pretrained:
        restore_model(model, layer_type='vgg19')
    return model


VGG16 = vgg16
VGG19 = vgg19

--------------------------------------------------------------------------------
/srgan.py:
--------------------------------------------------------------------------------
from tensorlayerx.nn import Module
import tensorlayerx as tlx
from tensorlayerx.nn import Conv2d, BatchNorm2d, Elementwise, SubpixelConv2d, UpSampling2d, Flatten, Sequential
from tensorlayerx.nn import Linear, MaxPool2d

W_init = tlx.initializers.TruncatedNormal(stddev=0.02)
G_init = tlx.initializers.TruncatedNormal(mean=1.0, stddev=0.02)


class ResidualBlock(Module):

    def __init__(self):
        super(ResidualBlock, self).__init__()
        self.conv1 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn1 = BatchNorm2d(num_features=64, act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
        self.conv2 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn2 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_first')

    def forward(self, x):
        z = self.conv1(x)
        z = self.bn1(z)
        z = self.conv2(z)
        z = self.bn2(z)
        x = x + z  # identity skip connection
        return x
class SRGAN_g(Module):
    """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network,
    with feature maps (n) and stride (s).
    """

    def __init__(self):
        super(SRGAN_g, self).__init__()
        self.conv1 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME', W_init=W_init,
            data_format='channels_first'
        )
        self.residual_block = self.make_layer()
        self.conv2 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn1 = BatchNorm2d(num_features=64, act=None, gamma_init=G_init, data_format='channels_first')
        self.conv3 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_first')
        self.subpixelconv1 = SubpixelConv2d(data_format='channels_first', scale=2, act=tlx.ReLU)
        self.conv4 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init, data_format='channels_first')
        self.subpixelconv2 = SubpixelConv2d(data_format='channels_first', scale=2, act=tlx.ReLU)
        self.conv5 = Conv2d(3, kernel_size=(1, 1), stride=(1, 1), act=tlx.Tanh, padding='SAME', W_init=W_init, data_format='channels_first')

    def make_layer(self):
        # 16 residual blocks, as in the SRGAN paper.
        layer_list = []
        for i in range(16):
            layer_list.append(ResidualBlock())
        return Sequential(layer_list)

    def forward(self, x):
        x = self.conv1(x)
        temp = x
        x = self.residual_block(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = x + temp  # long skip connection around all residual blocks
        x = self.conv3(x)
        x = self.subpixelconv1(x)
        x = self.conv4(x)
        x = self.subpixelconv2(x)
        x = self.conv5(x)

        return x
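# Upsampling detail: each SubpixelConv2d(scale=2) rearranges a 256-channel feature
# map into 64 channels at twice the spatial resolution (256 = 64 * 2**2), so the two
# conv + pixel-shuffle stages together produce the 4x super-resolution factor.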
class SRGAN_g2(Module):
    """ Generator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network,
    with feature maps (n) and stride (s).

    96x96 --> 384x384

    Uses resize convolution instead of sub-pixel convolution.
    """

    def __init__(self):
        super(SRGAN_g2, self).__init__()
        self.conv1 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first'
        )
        self.residual_block = self.make_layer()
        self.conv2 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn1 = BatchNorm2d(act=None, gamma_init=G_init, data_format='channels_first')
        self.upsample1 = UpSampling2d(data_format='channels_first', scale=(2, 2), method='bilinear')
        self.conv3 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn2 = BatchNorm2d(act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
        self.upsample2 = UpSampling2d(data_format='channels_first', scale=(4, 4), method='bilinear')
        self.conv4 = Conv2d(
            out_channels=32, kernel_size=(3, 3), stride=(1, 1), padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn3 = BatchNorm2d(act=tlx.ReLU, gamma_init=G_init, data_format='channels_first')
        self.conv5 = Conv2d(
            out_channels=3, kernel_size=(1, 1), stride=(1, 1), act=tlx.Tanh, padding='SAME', W_init=W_init
        )

    def make_layer(self):
        layer_list = []
        for i in range(16):
            layer_list.append(ResidualBlock())
        return Sequential(layer_list)

    def forward(self, x):
        x = self.conv1(x)
        temp = x
        x = self.residual_block(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = x + temp
        x = self.upsample1(x)
        x = self.conv3(x)
        x = self.bn2(x)
        x = self.upsample2(x)
        x = self.conv4(x)
        x = self.bn3(x)
        x = self.conv5(x)
        return x
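# SRGAN_g2 upsamples with bilinear UpSampling2d followed by Conv2d ("resize
# convolution") instead of sub-pixel convolution. Note that upsample1 (x2) followed
# by upsample2 (x4) appears to give x8 overall, which does not match the 4x
# (96 -> 384) pipeline described in the docstring; this class is not used by train.py.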
class SRGAN_d2(Module):
    """ Discriminator in Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network,
    with feature maps (n) and stride (s).
    """

    def __init__(self, ):
        super(SRGAN_d2, self).__init__()
        self.conv1 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
            W_init=W_init, data_format='channels_first'
        )
        self.conv2 = Conv2d(
            out_channels=64, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
            W_init=W_init, data_format='channels_first', b_init=None
        )
        self.bn1 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
        self.conv3 = Conv2d(
            out_channels=128, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
            W_init=W_init, data_format='channels_first', b_init=None
        )
        self.bn2 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
        self.conv4 = Conv2d(
            out_channels=128, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
            W_init=W_init, data_format='channels_first', b_init=None
        )
        self.bn3 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
        self.conv5 = Conv2d(
            out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
            W_init=W_init, data_format='channels_first', b_init=None
        )
        self.bn4 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
        self.conv6 = Conv2d(
            out_channels=256, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
            W_init=W_init, data_format='channels_first', b_init=None
        )
        self.bn5 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
        self.conv7 = Conv2d(
            out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
            W_init=W_init, data_format='channels_first', b_init=None
        )
        self.bn6 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
        self.conv8 = Conv2d(
            out_channels=512, kernel_size=(3, 3), stride=(2, 2), act=tlx.LeakyReLU(negative_slope=0.2), padding='SAME',
            W_init=W_init, data_format='channels_first', b_init=None
        )
        self.bn7 = BatchNorm2d(gamma_init=G_init, data_format='channels_first')
        self.flat = Flatten()
        self.dense1 = Linear(out_features=1024, act=tlx.LeakyReLU(negative_slope=0.2))
        self.dense2 = Linear(out_features=1)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = self.conv3(x)
        x = self.bn2(x)
        x = self.conv4(x)
        x = self.bn3(x)
        x = self.conv5(x)
        x = self.bn4(x)
        x = self.conv6(x)
        x = self.bn5(x)
        x = self.conv7(x)
        x = self.bn6(x)
        x = self.conv8(x)
        x = self.bn7(x)
        x = self.flat(x)
        x = self.dense1(x)
        x = self.dense2(x)
        logits = x
        n = tlx.sigmoid(x)
        return n, logits
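# Unlike SRGAN_d below, SRGAN_d2 returns both the sigmoid output and the raw
# logits. train.py uses SRGAN_d, whose single logits output feeds
# tlx.losses.sigmoid_cross_entropy directly.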
class SRGAN_d(Module):

    def __init__(self, dim=64):
        super(SRGAN_d, self).__init__()
        self.conv1 = Conv2d(
            out_channels=dim, kernel_size=(4, 4), stride=(2, 2), act=tlx.LeakyReLU, padding='SAME', W_init=W_init,
            data_format='channels_first'
        )
        self.conv2 = Conv2d(
            out_channels=dim * 2, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn1 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv3 = Conv2d(
            out_channels=dim * 4, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn2 = BatchNorm2d(num_features=dim * 4, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv4 = Conv2d(
            out_channels=dim * 8, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn3 = BatchNorm2d(num_features=dim * 8, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv5 = Conv2d(
            out_channels=dim * 16, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn4 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv6 = Conv2d(
            out_channels=dim * 32, kernel_size=(4, 4), stride=(2, 2), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn5 = BatchNorm2d(num_features=dim * 32, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv7 = Conv2d(
            out_channels=dim * 16, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn6 = BatchNorm2d(num_features=dim * 16, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv8 = Conv2d(
            out_channels=dim * 8, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn7 = BatchNorm2d(num_features=dim * 8, act=None, gamma_init=G_init, data_format='channels_first')
        self.conv9 = Conv2d(
            out_channels=dim * 2, kernel_size=(1, 1), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn8 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv10 = Conv2d(
            out_channels=dim * 2, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn9 = BatchNorm2d(num_features=dim * 2, act=tlx.LeakyReLU, gamma_init=G_init, data_format='channels_first')
        self.conv11 = Conv2d(
            out_channels=dim * 8, kernel_size=(3, 3), stride=(1, 1), act=None, padding='SAME', W_init=W_init,
            data_format='channels_first', b_init=None
        )
        self.bn10 = BatchNorm2d(num_features=dim * 8, gamma_init=G_init, data_format='channels_first')
        self.add = Elementwise(combine_fn=tlx.add, act=tlx.LeakyReLU)
        self.flat = Flatten()
        self.dense = Linear(out_features=1, W_init=W_init)

    def forward(self, x):

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.bn1(x)
        x = self.conv3(x)
        x = self.bn2(x)
        x = self.conv4(x)
        x = self.bn3(x)
        x = self.conv5(x)
        x = self.bn4(x)
        x = self.conv6(x)
        x = self.bn5(x)
        x = self.conv7(x)
        x = self.bn6(x)
        x = self.conv8(x)
        x = self.bn7(x)
        temp = x
        x = self.conv9(x)
        x = self.bn8(x)
        x = self.conv10(x)
        x = self.bn9(x)
        x = self.conv11(x)
        x = self.bn10(x)
        x = self.add([temp, x])
        x = self.flat(x)
        x = self.dense(x)

        return x
class Vgg19_simple_api(Module):

    def __init__(self):
        super(Vgg19_simple_api, self).__init__()
        """ conv1 """
        self.conv1 = Conv2d(out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.conv2 = Conv2d(out_channels=64, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.maxpool1 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME')
        """ conv2 """
        self.conv3 = Conv2d(out_channels=128, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.conv4 = Conv2d(out_channels=128, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.maxpool2 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME')
        """ conv3 """
        self.conv5 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.conv6 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.conv7 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.conv8 = Conv2d(out_channels=256, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.maxpool3 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME')
        """ conv4 """
        self.conv9 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.conv10 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.conv11 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.conv12 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.maxpool4 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME')  # (batch_size, 14, 14, 512)
        """ conv5 """
        self.conv13 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.conv14 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.conv15 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.conv16 = Conv2d(out_channels=512, kernel_size=(3, 3), stride=(1, 1), act=tlx.ReLU, padding='SAME')
        self.maxpool5 = MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding='SAME')  # (batch_size, 7, 7, 512)
        """ fc 6~8 """
        self.flat = Flatten()
        self.dense1 = Linear(out_features=4096, act=tlx.ReLU)
        self.dense2 = Linear(out_features=4096, act=tlx.ReLU)
        self.dense3 = Linear(out_features=1000, act=tlx.identity)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.maxpool1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.maxpool2(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)
        x = self.maxpool3(x)
        x = self.conv9(x)
        x = self.conv10(x)
        x = self.conv11(x)
        x = self.conv12(x)
        x = self.maxpool4(x)
        conv = x
        x = self.conv13(x)
        x = self.conv14(x)
        x = self.conv15(x)
        x = self.conv16(x)
        x = self.maxpool5(x)
        x = self.flat(x)
        x = self.dense1(x)
        x = self.dense2(x)
        x = self.dense3(x)

        return x, conv
--------------------------------------------------------------------------------
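As a closing illustration, here is a minimal, self-contained sketch of running the generator on its own (it mirrors the shapes and calls used in `train.py`; the weights path assumes you placed `g.npz` under `models/` as described in the README):

```python
import os
os.environ['TL_BACKEND'] = 'tensorflow'

import numpy as np
import tensorlayerx as tlx
from srgan import SRGAN_g

G = SRGAN_g()
G.init_build(tlx.nn.Input(shape=(1, 3, 96, 96)))  # fills in each layer's in_channels
G.load_weights(os.path.join('models', 'g.npz'), format='npz_dict')
G.set_eval()

# A random LR batch in [-1, 1], channels first; the output is 4x larger.
lr = tlx.ops.convert_to_tensor(np.random.uniform(-1, 1, (1, 3, 96, 96)).astype(np.float32))
sr = G(lr)
print(sr.shape)  # expected: (1, 3, 384, 384)
```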