├── README.md ├── code ├── Situation1.ipynb ├── Situation2.ipynb ├── Situation2_test_speed.ipynb ├── Situation3.ipynb ├── Situation3.py ├── Situation3_test_speed.ipynb ├── models.py └── utils.py ├── imgs ├── Conv2d_MyConv2D.png ├── S1.jpg ├── S2.jpg ├── S2_speed.png ├── S3.jpg ├── S3_speed.png ├── VGG16.png ├── metanet.png ├── monitor.png ├── transform_net.png ├── transform_net2.png ├── transform_net3.png └── weights_diverge.png └── keras_version ├── 3_situation.py ├── README.md ├── a.jpg ├── c.jpg ├── demo.sh ├── demo3.py ├── picasso.jpg ├── res ├── all.jpg ├── c2.jpg ├── c2_1.jpg ├── c2_2.jpg ├── c2_3.jpg ├── c2_5.jpg ├── c4.jpg ├── c4_1.jpg ├── c4_2.jpg ├── c4_3.jpg ├── c4_5.jpg ├── style1.jpg ├── style2.jpg ├── style3.jpg └── style5.jpg ├── result.jpg ├── train_content └── gitkeep ├── train_style └── gitkeep ├── util.py └── util.pyc /README.md: -------------------------------------------------------------------------------- 1 | # StyleTransferTrilogy 2 | 3 | # 风格迁移三部曲 4 | 5 | 风格迁移是一个很有意思的任务,通过风格迁移可以使一张图片保持本身内容大致不变的情况下呈现出另外一张图片的风格。本文会介绍以下三种风格迁移方式以及对应的代码实现: 6 | 7 | * 固定风格固定内容的普通风格迁移([A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576)) 8 | * 固定风格任意内容的快速风格迁移([Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)) 9 | * 任意风格任意内容的极速风格迁移([Meta Networks for Neural Style Transfer](https://arxiv.org/abs/1709.04111)) 10 | 11 | 本文所使用的环境是 pytorch 0.4.0,如果你使用了其他的版本,稍作修改即可正确运行。 12 | 13 | # 固定风格固定内容的普通风格迁移 14 | 15 | 最早的风格迁移就是在固定风格、固定内容的情况下做的风格迁移,这是最慢的方法,也是最经典的方法。 16 | 17 | 最原始的风格迁移的思路很简单,把图片当做可以训练的变量,通过优化图片来降低与内容图片的内容差异以及降低与风格图片的风格差异,迭代训练多次以后,生成的图片就会与内容图片的内容一致,同时也会与风格图片的风格一致。 18 | 19 | ## VGG16 20 | 21 | VGG16 是一个很经典的模型,它通过堆叠 3x3 的卷积层和池化层,在 ImageNet 上获得了不错的成绩。我们使用在 ImageNet 上经过预训练的 VGG16 模型可以对图像提取出有用的特征,这些特征可以帮助我们去衡量两个图像的内容差异和风格差异。 22 | 23 | 在进行风格迁移任务时,我们只需要提取其中几个比较重要的层,所以我们对 pytorch 自带的预训练 VGG16 模型稍作了一些修改: 24 | 25 | ```py 26 | class VGG(nn.Module): 27 | 28 | def __init__(self, features): 29 | super(VGG, self).__init__() 30 | self.features = features 31 | self.layer_name_mapping = { 32 | '3': "relu1_2", 33 | '8': "relu2_2", 34 | '15': "relu3_3", 35 | '22': "relu4_3" 36 | } 37 | for p in self.parameters(): 38 | p.requires_grad = False 39 | 40 | def forward(self, x): 41 | outs = [] 42 | for name, module in self.features._modules.items(): 43 | x = module(x) 44 | if name in self.layer_name_mapping: 45 | outs.append(x) 46 | return outs 47 | 48 | vgg16 = models.vgg16(pretrained=True) 49 | vgg16 = VGG(vgg16.features[:23]).to(device).eval() 50 | ``` 51 | 52 | 经过修改的 VGG16 可以输出 relu1_2,relu2_2,relu3_3,relu4_3 这几个特定层的特征图。下面这两句代码就是它的用法: 53 | 54 | ```py 55 | features = vgg16(input_img) 56 | content_features = vgg16(content_img) 57 | ``` 58 | 59 | 举个例子,当我们使用 vgg16 对 `input_img` 计算特征时,它会返回四个矩阵给 features,假设 `input_img` 的尺寸是 `[1, 3, 512, 512]`(四个维度分别代表 batch, channels, height, width),那么它返回的四个矩阵的尺寸就是这样的: 60 | 61 | * relu1_2 `[1, 64, 512, 512]` 62 | * relu2_2 `[1, 128, 256, 256]` 63 | * relu3_3 `[1, 256, 128, 128]` 64 | * relu4_3 `[1, 512, 64, 64]` 65 | 66 | ## 内容 67 | 68 | 我们进行风格迁移的时候,必须保证生成的图像与内容图像的内容一致性,不然风格迁移就变成艺术创作了。那么如何衡量两张图片的内容差异呢?很简单,通过 VGG16 输出的特征图来衡量图片的内容差异。 69 | 70 | ![](imgs/VGG16.png) 71 | 72 | 提示:在本方法中没有 Image Transform Net,为了表述方便,我们使用了第二篇论文中的图。 73 | 74 | 这里使用的损失函数是: 75 | 76 | ![equation](https://latex.codecogs.com/svg.latex?$$\Large\ell^{\phi,j}_{feat}(\hat{y},y)=\frac{1}{C_jH_jW_j}||\phi_j(\hat{y})-\phi_j(y)||^2_2$$) 77 | 78 | 79 | 其中: 80 | 81 | * ![equation](https://latex.codecogs.com/svg.latex?\hat{y})是输入图像(也就是生成的图像) 82 | * 
![equation](https://latex.codecogs.com/svg.latex?y)是内容图像 83 | * ![equation](https://latex.codecogs.com/svg.latex?\phi) 代表 VGG16 84 | * ![equation](https://latex.codecogs.com/svg.latex?\j) 在这里是 relu3_3 85 | * ![equation](https://latex.codecogs.com/svg.latex?\phi_j(x))指的是 x 图像输入到 VGG 以后的第 j 层的特征图 86 | * ![equation](https://latex.codecogs.com/svg.latex?C_j\times&space;H_j\times&space;W_j)是第 j 层输出的特征图的尺寸 87 | 88 | 根据生成图像和内容图像在 ![equation](https://latex.codecogs.com/svg.latex?$\text{relu3\_3}$) 输出的特征图的均方误差(MeanSquaredError)来优化生成的图像与内容图像之间的内容一致性。 89 | 90 | 91 | 那么写成代码就是这样的: 92 | 93 | ```py 94 | content_loss = F.mse_loss(features[2], content_features[2]) * content_weight 95 | ``` 96 | 97 | 因为我们这里使用的是经过在 ImageNet 预训练过的 VGG16 提取的特征图,所以它能提取出图像的高级特征,通过优化生成图像和内容图像特征图的 mse,可以迫使生成图像的内容与内容图像在 VGG16 的 relu3\_3 上输出相似的结果,因此生成图像和内容图像在内容上是一致的。 98 | 99 | ## 风格 100 | 101 | ### Gram 矩阵 102 | 103 | 那么如何衡量输入图像与风格图像之间的内容差异呢?这里就需要提出一个新的公式,Gram 矩阵: 104 | 105 | ![equation](https://latex.codecogs.com/svg.latex?$$\Large{G^\phi_j(x)_{c,c'}=\frac{1}{C_jH_jW_j}&space;\sum_{h=1}^{H_j}&space;\sum_{w=1}^{W_j}&space;\phi_j(x)_{h,w,c}\phi_j(x)_{h,w,c'}}$$) 106 | 107 | 其中: 108 | 109 | * ![equation](https://latex.codecogs.com/svg.latex?\hat{y})是输入图像(也就是生成的图像) 110 | * ![equation](https://latex.codecogs.com/svg.latex?y)是风格图像 111 | * ![equation](https://latex.codecogs.com/svg.latex?C_j\times&space;H_j\times&space;W_j)是第 j 层输出的特征图的尺寸。 112 | * ![equation](https://latex.codecogs.com/svg.latex?$G^\phi_j(x)$)指的是 x 图像的第 j 层特征图对应的 Gram 矩阵,比如 64 个卷积核对应的卷积层输出的特征图的 Gram 矩阵的尺寸是 ![equation](https://latex.codecogs.com/svg.latex?$(64,64)$)。 113 | * ![equation](https://latex.codecogs.com/svg.latex?$G^\phi_j(x)_{c,c'}$) 指的是 Gram 矩阵第 ![equation](https://latex.codecogs.com/svg.latex?$(c,c')$) 坐标对应的值。 114 | * ![equation](https://latex.codecogs.com/svg.latex?$\phi_j(x)$)指的是 x 图像输入到 VGG 以后的第 j 层的特征图,![equation](https://latex.codecogs.com/svg.latex?$\phi_j(x)_{h,w,c}$) 指的是特征图 ![equation](https://latex.codecogs.com/svg.latex?$(h,w,c)$)坐标对应的值。 115 | 116 | Gram 矩阵的计算方法其实很简单,Gram 矩阵的 ![equation](https://latex.codecogs.com/svg.latex?$(c,c')$) 坐标对应的值,就是特征图的第 ![equation](https://latex.codecogs.com/svg.latex?$c$) 张和第 ![equation](https://latex.codecogs.com/svg.latex?$c'$) 张图对应元素相乘,然后全部加起来并且除以 ![equation](https://latex.codecogs.com/svg.latex?C_j\times&space;H_j\times&space;W_j) 的结果。根据公式我们可以很容易推断出 Gram 矩阵是对称矩阵。 117 | 118 | 具体到代码,我们可以写出下面的函数: 119 | 120 | ```py 121 | def gram_matrix(y): 122 | (b, ch, h, w) = y.size() 123 | features = y.view(b, ch, w * h) 124 | features_t = features.transpose(1, 2) 125 | gram = features.bmm(features_t) / (ch * h * w) 126 | return gram 127 | ``` 128 | 129 | 参考链接: 130 | 131 | [https://github.com/pytorch/examples/blob/0.4/fast_neural_style/neural_style/utils.py#L21-L26](https://github.com/pytorch/examples/blob/0.4/fast_neural_style/neural_style/utils.py#L21-L26) 132 | 133 | 假设我们输入了一个 `[1, 3, 512, 512]` 的图像,下面就是各个矩阵的尺寸: 134 | 135 | * relu1_2 `[1, 64, 512, 512]`,gram `[1, 64, 64]` 136 | * relu2_2 `[1, 128, 256, 256]`,gram `[1, 128, 128]` 137 | * relu3_3 `[1, 256, 128, 128]`,gram `[1, 256, 256]` 138 | * relu4_3 `[1, 512, 64, 64]`,gram `[1, 512, 512]` 139 | 140 | ### 风格损失 141 | 142 | 根据生成图像和风格图像在 relu1_2、relu2_2、relu3_3、relu4_3 输出的特征图的 Gram 矩阵之间的均方误差(MeanSquaredError)来优化生成的图像与风格图像之间的风格差异: 143 | 144 | ![equation](https://latex.codecogs.com/svg.latex?$$\Large\ell^{\phi,j}_{style}(\hat{y},y)=||G^\phi_j(\hat{y})-G^\phi_j(y)||^2_F$$) 145 | 146 | 其中: 147 | 148 | * ![equation](https://latex.codecogs.com/svg.latex?\hat{y})是输入图像(也就是生成的图像) 149 | * 
![equation](https://latex.codecogs.com/svg.latex?$y$)是风格图像 150 | * ![equation](https://latex.codecogs.com/svg.latex?$G^\phi_j(x)$)指的是 x 图像的第 j 层特征图对应的 Gram 矩阵 151 | 152 | 那么写成代码就是下面这样: 153 | 154 | ```py 155 | style_grams = [gram_matrix(x) for x in style_features] 156 | 157 | style_loss = 0 158 | grams = [gram_matrix(x) for x in features] 159 | for a, b in zip(grams, style_grams): 160 | style_loss += F.mse_loss(a, b) * style_weight 161 | ``` 162 | 163 | ## 训练 164 | 165 | 那么风格迁移的目标就很简单了,直接将两个 loss 按权值加起来,然后对图片优化 loss,即可优化出既有内容图像的内容,也有风格图像的风格的图片。代码如下: 166 | 167 | ```py 168 | input_img = content_img.clone() 169 | optimizer = optim.LBFGS([input_img.requires_grad_()]) 170 | style_weight = 1e6 171 | content_weight = 1 172 | 173 | run = [0] 174 | while run[0] <= 300: 175 | def f(): 176 | optimizer.zero_grad() 177 | features = vgg16(input_img) 178 | 179 | content_loss = F.mse_loss(features[2], content_features[2]) * content_weight 180 | style_loss = 0 181 | grams = [gram_matrix(x) for x in features] 182 | for a, b in zip(grams, style_grams): 183 | style_loss += F.mse_loss(a, b) * style_weight 184 | 185 | loss = style_loss + content_loss 186 | 187 | if run[0] % 50 == 0: 188 | print('Step {}: Style Loss: {:4f} Content Loss: {:4f}'.format( 189 | run[0], style_loss.item(), content_loss.item())) 190 | run[0] += 1 191 | 192 | loss.backward() 193 | return loss 194 | 195 | optimizer.step(f) 196 | ``` 197 | 198 | 此处使用了 LBFGS,所以 loss 需要包装在一个函数里,代码参考了: 199 | [https://pytorch.org/tutorials/advanced/neural\_style\_tutorial.html](https://pytorch.org/tutorials/advanced/neural_style_tutorial.html) 200 | 201 | ## 效果 202 | 203 | 最终效果如图所示: 204 | 205 | ![](imgs/S1.jpg) 206 | 207 | 可以看到生成的图像既有风格图像的风格,也有内容图像的内容,很完美。不过生成一幅256x256 的图像在 1080ti 上需要18.6s,这个时间挺长的,谈不上实时性,因此我们可以来看看第二篇论文中的方法。 208 | 209 | # 固定风格任意内容的快速风格迁移 210 | 211 | 有了上面的铺垫,理解固定风格任意内容的快速风格迁移就简单很多了。思路很简单,就是先搭建一个转换网络,然后通过优化转换网络的权值来实现快速风格迁移。由于这个转换网络可以接受任意图像,所以这是任意内容的风格迁移。 212 | 213 | ## 模型 214 | 215 | 模型结构很简单,分为三个部分: 216 | 217 | * 降维,三层卷积层,逐渐提升通道数为128,并且通过 stride 把特征图的宽高缩小为原来的八分之一 218 | * 5个 ResidualBlock 堆叠 219 | * 升维,三层卷积层,逐渐降低通道数为3,并且通过 nn.Upsample 把特征图的宽高还原为原来的大小 220 | 221 | 先降维再升维是为了减少计算量,中间的 5 个 Residual 结构可以学习如何在原图上添加少量内容,改变原图的风格。下面让我们来看看代码。 222 | 223 | ### ConvLayer 224 | 225 | ```py 226 | def ConvLayer(in_channels, out_channels, kernel_size=3, stride=1, 227 | upsample=None, instance_norm=True, relu=True): 228 | layers = [] 229 | if upsample: 230 | layers.append(nn.Upsample(mode='nearest', scale_factor=upsample)) 231 | layers.append(nn.ReflectionPad2d(kernel_size // 2)) 232 | layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride)) 233 | if instance_norm: 234 | layers.append(nn.InstanceNorm2d(out_channels)) 235 | if relu: 236 | layers.append(nn.ReLU()) 237 | return layers 238 | ``` 239 | 240 | 首先我们实现了一个函数,ConvLayer,它包含: 241 | 242 | * [nn.Upsample](https://pytorch.org/docs/stable/nn.html#upsample)(可选) 243 | * [nn.ReflectionPad2d](https://pytorch.org/docs/stable/nn.html#reflectionpad2d) 244 | * [nn.Conv2d](https://pytorch.org/docs/stable/nn.html#conv2d) 245 | * [nn.InstanceNorm2d](https://pytorch.org/docs/stable/nn.html#instancenorm2d)(可选) 246 | * [nn.ReLU](https://pytorch.org/docs/stable/nn.html#relu)(可选) 247 | 248 | 因为每个卷积层前后都可能会用到这些层,为了简化代码,我们将它写成一个函数,返回这些层用于搭建模型。 249 | 250 | ### ResidualBlock 251 | 252 | ```py 253 | class ResidualBlock(nn.Module): 254 | def __init__(self, channels): 255 | super(ResidualBlock, self).__init__() 256 | self.conv = nn.Sequential( 257 | *ConvLayer(channels, channels, kernel_size=3, stride=1), 258 | 
*ConvLayer(channels, channels, kernel_size=3, stride=1, relu=False) 259 | ) 260 | 261 | def forward(self, x): 262 | return self.conv(x) + x 263 | ``` 264 | 265 | 这里写的就不是函数,而是一个类,因为它内部包含许多层,而且并不是简单的自上而下的结构(Sequential),而是有了跨层的连接(`self.conv(x) + x`),所以我们需要继承 nn.Module,实现 forward 函数,才能实现跨层连接。 266 | 267 | ### TransformNet 268 | 269 | 最后这个模型就很简单了,照着论文里给出的表格搭建即可。我们这里为了实验方便,添加了 base 参数,当 `base=8` 时,卷积核的个数是按 `8, 16, 32` 递增的,当 `base=32` 时,卷积核个数是按 `32, 64, 128` 递增的。有了这个参数,我们可以按需增加模型规模,base 越大,图像质量越好。 270 | 271 | ![](imgs/transform_net.png) 272 | 273 | ```py 274 | class TransformNet(nn.Module): 275 | def __init__(self, base=32): 276 | super(TransformNet, self).__init__() 277 | self.downsampling = nn.Sequential( 278 | *ConvLayer(3, base, kernel_size=9), 279 | *ConvLayer(base, base*2, kernel_size=3, stride=2), 280 | *ConvLayer(base*2, base*4, kernel_size=3, stride=2), 281 | ) 282 | self.residuals = nn.Sequential(*[ResidualBlock(base*4) for i in range(5)]) 283 | self.upsampling = nn.Sequential( 284 | *ConvLayer(base*4, base*2, kernel_size=3, upsample=2), 285 | *ConvLayer(base*2, base, kernel_size=3, upsample=2), 286 | *ConvLayer(base, 3, kernel_size=9, instance_norm=False, relu=False), 287 | ) 288 | 289 | def forward(self, X): 290 | y = self.downsampling(X) 291 | y = self.residuals(y) 292 | y = self.upsampling(y) 293 | return y 294 | ``` 295 | 296 | ## 数据 297 | 298 | 训练的时候,我们使用了 COCO train 2014、val2014、test2014, 一共有 164k 图像,实际上原论文只用了训练集(80k)。图像宽高都是256。 299 | 300 | > We resize each of the 80k training images to 256 × 256 and train our networks with a batch size of 4 for 40,000 iterations, giving roughly two epochs over the training data. 301 | 302 | ```py 303 | batch_size = 4 304 | width = 256 305 | 306 | data_transform = transforms.Compose([ 307 | transforms.Resize(width), 308 | transforms.CenterCrop(width), 309 | transforms.ToTensor(), 310 | tensor_normalizer, 311 | ]) 312 | 313 | dataset = torchvision.datasets.ImageFolder('/home/ypw/COCO/', transform=data_transform) 314 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) 315 | ``` 316 | 317 | 返回: 318 | 319 | ``` 320 | Dataset ImageFolder 321 | Number of datapoints: 164062 322 | Root Location: /home/ypw/COCO/ 323 | Transforms (if any): Compose( 324 | Resize(size=256, interpolation=PIL.Image.BILINEAR) 325 | CenterCrop(size=(256, 256)) 326 | ToTensor() 327 | Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 328 | ) 329 | Target Transforms (if any): None 330 | ``` 331 | 332 | 其中的 `tensor_normalizer` 是为了使用 pytorch 自带的预训练模型,在官方文档中提到了要进行预处理:[https://pytorch.org/docs/master/torchvision/models.html](https://pytorch.org/docs/master/torchvision/models.html) 333 | 334 | ```py 335 | cnn_normalization_mean = [0.485, 0.456, 0.406] 336 | cnn_normalization_std = [0.229, 0.224, 0.225] 337 | tensor_normalizer = transforms.Normalize(mean=cnn_normalization_mean, std=cnn_normalization_std) 338 | ``` 339 | 340 | ## 训练 341 | 342 | ### 超参数 343 | 344 | 虽然[官方开源](https://github.com/jcjohnson/fast-neural-style/blob/master/doc/training.md)给出的 `style_weight` 是 5,但是我这里测试得并不理想,可能是不同的预训练权值、不同的预处理方式造成的差异,设置为 1e5 是比较理想的。 345 | 346 | > We use Adam [51] with a learning rate of 1 × 10−3. 347 | 348 | 优化器使用了论文中提到的 Adam 1e-3。 349 | 350 | > The output images are regularized with total variation regularization with a strength of between ![equation](https://latex.codecogs.com/svg.latex?$1\times10^{-6}$) and ![equation](https://latex.codecogs.com/svg.latex?$1\times10^{-4}$), chosen via cross-validation per style target. 
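Putting the settings of this subsection together, the training setup for this fixed-style / arbitrary-content situation looks roughly like the sketch below (the reasons for the `tv_weight` and `batch_size` values are discussed right after; `content_weight` is not stated explicitly here, so it is assumed to stay at 1 as in the first situation; the actual code lives in `code/Situation2.ipynb`):

```py
# Sketch of the hyper-parameters and optimizer for the fixed-style transfer model.
# `content_weight = 1` is an assumption (not stated explicitly in the text); the rest follows it.
style_weight = 1e5      # the officially suggested value 5 did not work well here
content_weight = 1      # assumed, same as in the first situation
tv_weight = 1e-6        # lower end of the range suggested in the paper
batch_size = 4

transform_net = TransformNet(base=32).to(device)
optimizer = optim.Adam(transform_net.parameters(), 1e-3)
```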
351 | 352 | `tv_weight` 感觉没有太大变化,所以按论文中给出的参考设置了 1e-6。 353 | 354 | > train our networks with a batch size of 4 for 40,000 iterations 355 | 356 | `batch_size` 按论文设置为了4。 357 | 358 | 由于我这里使用的图片变多了,所以为了保持和官方的训练 step 一致(40k),训练代数(epoch)设置为了1。 359 | 360 | ### TotalVariation 361 | 362 | > Total Variation Regularization. To encourage spatial smoothness in the output image ![equation](https://latex.codecogs.com/svg.latex?\hat{y}), we follow prior work on feature inversion [6,20] and super- resolution [48,49] and make use of total variation regularizer ![equation](https://latex.codecogs.com/svg.latex?$\ell_{TV}(\hat{y})$). 363 | 364 | 论文中提到了一个 TV Loss,这是为了平滑图像。它的计算方法很简单: 365 | 366 | ![equation](https://latex.codecogs.com/svg.latex?$$\Large{V_\text{aniso}(y)=\sum_{i,j}|y_{i+1,j}-y_{i,j}|+|y_{i,j+1}-y_{i,j}|}$$) 367 | 368 | 将图像水平和垂直平移一个像素,与原图相减,然后计算绝对值的和,就是 TotalVariation。 369 | 370 | 参考链接:[https://en.wikipedia.org/wiki/Total_variation_denoising](https://en.wikipedia.org/wiki/Total_variation_denoising) 371 | 372 | ### 代码 373 | 374 | 由于代码太长,这里只贴一些关键代码: 375 | 376 | ```py 377 | for batch, (content_images, _) in pbar: 378 | optimizer.zero_grad() 379 | 380 | # 使用风格模型预测风格迁移图像 381 | content_images = content_images.to(device) 382 | transformed_images = transform_net(content_images) 383 | transformed_images = transformed_images.clamp(-3, 3) 384 | 385 | # 使用 vgg16 计算特征 386 | content_features = vgg16(content_images) 387 | transformed_features = vgg16(transformed_images) 388 | 389 | # content loss 390 | content_loss = content_weight * F.mse_loss(transformed_features[1], content_features[1]) 391 | 392 | # total variation loss 393 | y = transformed_images 394 | tv_loss = tv_weight * (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + 395 | torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))) 396 | 397 | # style loss 398 | style_loss = 0. 
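        # Note: `style_grams` is computed once before entering the training loop -- the style
        # image is fixed, so its Gram matrices (`style_grams = [gram_matrix(x) for x in style_features]`,
        # exactly as in the first situation) do not need to be recomputed for every batch.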
399 | transformed_grams = [gram_matrix(x) for x in transformed_features] 400 | for transformed_gram, style_gram in zip(transformed_grams, style_grams): 401 | style_loss += style_weight * F.mse_loss(transformed_gram, 402 | style_gram.expand_as(transformed_gram)) 403 | 404 | # 加起来 405 | loss = style_loss + content_loss + tv_loss 406 | 407 | loss.backward() 408 | optimizer.step() 409 | ``` 410 | 411 | 通过对 loss 的优化,进而约束模型输出与内容图像的内容相似、与风格图像风格相似的图像,从而得到一个可以较快速度输出风格迁移图像的模型。 412 | 413 | ## 效果 414 | 415 | 最终效果如图所示: 416 | 417 | ![](imgs/S2.jpg) 418 | 419 | 可以看到对于任意内容图片,转换网络都能转换为固定风格的图像。根据下面这段代码进行的测速,1080ti 可以在4.82秒内完成 1000 张图像的风格迁移,相当于207fps,可以说是具有了实时性: 420 | 421 | ![](imgs/S2_speed.png) 422 | 423 | 但是整个模型的训练时间需要1小时54分钟,如果我们想做任意风格图像的风格迁移,这个时间几乎是不可接受的,因此让我们来看看第三篇论文的思路。 424 | 425 | # 任意风格任意内容的极速风格迁移 426 | 427 | 首先我们先对三种情况进行总结: 428 | 429 | ## 情况1 430 | 431 | ![equation](https://latex.codecogs.com/svg.latex?$$\large{\min_I\left(\lambda_c||\mathbf{CP}(I;w_f)-\mathbf{CP}(I_c;w_f)||^2_2+\lambda_s||\mathbf{SP}(I;w_f)-\mathbf{SP}(I_s;w_f)||^2_2\right)}$$) 432 | 433 | 其中: 434 | 435 | * ![equation](https://latex.codecogs.com/svg.latex?$\mathbf{CP}$) 是内容损失函数 436 | * ![equation](https://latex.codecogs.com/svg.latex?$\mathbf{SP}$) 是风格损失函数 437 | * ![equation](https://latex.codecogs.com/svg.latex?$\lambda_c$) 是内容权重 438 | * ![equation](https://latex.codecogs.com/svg.latex?$\lambda_s$) 是风格权重 439 | * ![equation](https://latex.codecogs.com/svg.latex?$w_f$) 是VGG16的固定权值 440 | * ![equation](https://latex.codecogs.com/svg.latex?$I_s$) 是风格图像 441 | * ![equation](https://latex.codecogs.com/svg.latex?$I_c$) 是内容图像 442 | * ![equation](https://latex.codecogs.com/svg.latex?$I$) 是输入图像 443 | 444 | 那么通过对输入图像 ![equation](https://latex.codecogs.com/svg.latex?$I$) 进行训练,我们能够得到固定风格、固定内容的风格迁移图像。 445 | 446 | ## 情况2 447 | 448 | ![equation](https://latex.codecogs.com/svg.latex?$$\large{\min_w\sum_{I_c}\left(\lambda_c||\mathbf{CP}(I_w;w_f)-\mathbf{CP}(I_c;w_f)||^2_2+\lambda_s||\mathbf{SP}(I_w;w_f)-\mathbf{SP}(I_s;w_f)||^2_2\right)}$$) 449 | 450 | 其中: 451 | 452 | * ![equation](https://latex.codecogs.com/svg.latex?$I_w$) 是生成图像,![equation](https://latex.codecogs.com/svg.latex?$I_w=\mathcal{N}(I_c;w)$,$\mathcal{N}$) 是图像转换网络 453 | 454 | 通过对权值的优化,我们可以得到一个快速风格迁移模型,它能够对任何内容图像进行风格转换,输出同一种风格的风格迁移图像。 455 | 456 | ## 情况3 457 | 458 | ![equation](https://latex.codecogs.com/svg.latex?$$\large{\min_\theta\sum_{I_c,I_s}\left(\lambda_c||\mathbf{CP}(I_{w_\theta};w_f)-\mathbf{CP}(I_c;w_f)||^2_2+\lambda_s||\mathbf{SP}(I_{w_\theta};w_f)-\mathbf{SP}(I_s;w_f)||^2_2\right)}$$) 459 | 460 | * ![equation](https://latex.codecogs.com/svg.latex?$\theta$) 是 ![equation](https://latex.codecogs.com/svg.latex?$Meta\mathcal{N}$) 的权值 461 | * ![equation](https://latex.codecogs.com/svg.latex?$w_\theta$) 是转换网络的权值,![equation](https://latex.codecogs.com/svg.latex?$w_\theta=Meta\mathcal{N}(I_s;\theta)$),所以我们可以说转换网络的权值是 MetaNet 通过风格图像生成的。 462 | * ![equation](https://latex.codecogs.com/svg.latex?$I_{w_\theta}$) 是转换网络生成的图像,![equation](https://latex.codecogs.com/svg.latex?$I_{w_\theta}=\mathcal{N}(I_c;w_\theta)$) 463 | 464 | 总的来说就是风格图像输入 ![equation](https://latex.codecogs.com/svg.latex?$Meta\mathcal{N}$) 得到转换网络 ![equation](https://latex.codecogs.com/svg.latex?$\mathcal{N}$),转换网络可以将任意内容图像进行转换。通过输入大量风格图像和内容图像 ![equation](https://latex.codecogs.com/svg.latex?$\sum_{I_c,I_s}$),可以训练出能够产出期望权值的 ![equation](https://latex.codecogs.com/svg.latex?$Meta\mathcal{N}$)。该模型可以输入任意风格图像,输出情况2中的迁移模型,进而实现任意风格任意内容的风格迁移。 465 | 466 | ## 转换网络(TransformNet) 467 | 468 | ![](imgs/transform_net2.png) 469 | 470 | 
论文中的转换网络很有意思,粉色部分的权重是由 MetaNet 生成的,而灰色部分的权重则与 MetaNet 一起训练。由于这个模型的需求比较个性化,我们的代码需要一些技巧,下面让我们详细展开讨论。 471 | 472 | ### MyConv2D 473 | 474 | 转换网络的结构还是与之前的一样,但是为了调用方便,我们需要实现一个新的类,这个类和卷积层类似,但是权值和偏置都需要是常量。这是因为权值已经是 MetaNet 的输出,如果赋值为 TransformNet 的权值,那么这个计算图就断了,这不符合我们的预期,我们应该让 MetaNet 的输出继续参与计算图,直到计算出 loss,不然 MetaNet 的权值将不会更新。因此我们事先了一个新的类,MyConv2D。 475 | 476 | 为了体现两者的差异,我们使用 TensorBoard 进行了可视化: 477 | 478 | ![](imgs/Conv2d_MyConv2D.png) 479 | 480 | 从上图中可以看到,nn.Conv2d 内部有两个参数( Paramter),这是可以参与训练参数,也就是说在 `loss.backward()` 的时候会计算对应的梯度。而 MyConv2D 里面的权值和偏置都是常量(Constant),不会计算相应的梯度。 481 | 482 | 代码如下: 483 | 484 | ```py 485 | class MyConv2D(nn.Module): 486 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1): 487 | super(MyConv2D, self).__init__() 488 | self.weight = torch.zeros((out_channels, in_channels, kernel_size, kernel_size)).to(device) 489 | self.bias = torch.zeros(out_channels).to(device) 490 | 491 | self.in_channels = in_channels 492 | self.out_channels = out_channels 493 | self.kernel_size = (kernel_size, kernel_size) 494 | self.stride = (stride, stride) 495 | 496 | def forward(self, x): 497 | return F.conv2d(x, self.weight, self.bias, self.stride) 498 | ``` 499 | 500 | ### ConvLayer 501 | 502 | 为了区分以下两种情况: 503 | 504 | * 权值是是可训练的参数 505 | * 权值由 MetaNet 给出 506 | 507 | 我们写出了下面的代码: 508 | 509 | ```py 510 | def ConvLayer(in_channels, out_channels, kernel_size=3, stride=1, 511 | upsample=None, instance_norm=True, relu=True, trainable=False): 512 | ...... 513 | if trainable: 514 | layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride)) 515 | else: 516 | layers.append(MyConv2D(in_channels, out_channels, kernel_size, stride)) 517 | ...... 518 | return layers 519 | ``` 520 | 521 | 很简单,当权值由 MetaNet 给出时,它是不参与训练的,我们设置 trainable=False,然后使用 MyConv2D 层。 522 | 523 | ### TransformNet 524 | 525 | 下面就直接贴代码了,模型结构按照上面论文中的图去搭就好。 526 | 527 | ```py 528 | class TransformNet(nn.Module): 529 | def __init__(self, base=8): 530 | super(TransformNet, self).__init__() 531 | self.base = base 532 | self.downsampling = nn.Sequential( 533 | *ConvLayer(3, base, kernel_size=9, trainable=True), 534 | *ConvLayer(base, base*2, kernel_size=3, stride=2), 535 | *ConvLayer(base*2, base*4, kernel_size=3, stride=2), 536 | ) 537 | self.residuals = nn.Sequential(*[ResidualBlock(base*4) for i in range(5)]) 538 | self.upsampling = nn.Sequential( 539 | *ConvLayer(base*4, base*2, kernel_size=3, upsample=2), 540 | *ConvLayer(base*2, base, kernel_size=3, upsample=2), 541 | *ConvLayer(base, 3, kernel_size=9, instance_norm=False, relu=False, trainable=True), 542 | ) 543 | 544 | def forward(self, X): 545 | y = self.downsampling(X) 546 | y = self.residuals(y) 547 | y = self.upsampling(y) 548 | return y 549 | .... 
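    # Note: in the full code/models.py version the class also defines helper methods that are
    # elided here (the "...."), such as get_param_dict(), which reports how many weights each
    # MyConv2D layer needs (see the table below), and set_weights(), which copies the weights
    # produced by MetaNet into those layers.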
550 | ``` 551 | 552 | `TransformNet(32)` 每一层对应的权重数量如下: 553 | 554 | ``` 555 | defaultdict(int, 556 | {'downsampling.5': 18496, 557 | 'downsampling.9': 73856, 558 | 'residuals.0.conv.1': 147584, 559 | 'residuals.0.conv.5': 147584, 560 | 'residuals.1.conv.1': 147584, 561 | 'residuals.1.conv.5': 147584, 562 | 'residuals.2.conv.1': 147584, 563 | 'residuals.2.conv.5': 147584, 564 | 'residuals.3.conv.1': 147584, 565 | 'residuals.3.conv.5': 147584, 566 | 'residuals.4.conv.1': 147584, 567 | 'residuals.4.conv.5': 147584, 568 | 'upsampling.2': 73792, 569 | 'upsampling.7': 18464}) 570 | ``` 571 | 572 | 通过 TensorBoard,我们可以对模型结构进行可视化: 573 | 574 | ![](imgs/transform_net3.png) 575 | 576 | ## MetaNet 577 | 578 | 那么我们怎么样才能获得 TransformNet 的权值呢?当然是输入风格图像的特征。 579 | 580 | 那么我们知道风格图像经过 VGG16 输出的 relu1_2、relu2_2、relu3_3、relu4_3 尺寸是很大的,假设图像的尺寸是 `(256, 256)`,那么卷积层输出的尺寸分别是 `(64, 256, 256)、(128, 128, 128)、(256, 64, 64)、(512, 32, 32)`,即使取其 Gram 矩阵,`(64, 64)、(128, 128)、(256, 256)、(512, 512)` 也是非常大的。我们举个例子,假设使用 `512*512` 个特征来生成 147584 个权值(residual 层),那么这层全连接层的 w 就是 512x512x147584=38688260096 个,假设 w 的格式是 float32,那么光是一个 w 就有 144GB 这么大,这几乎是不可实现的。那么第三篇论文就提到了一个方法,只计算每一个卷积核输出的内容的均值和标准差。 581 | 582 | > We compute the mean and stand deviations of two feature maps of the style image and the transferred image as style features. 583 | 584 | 只计算均值和标准差,不计算 Gram 矩阵,这里的特征就变为了 (64+128+256+512)x2=1920 维,明显小了很多。但是我们稍加计算即可知道,1920x(18496+73856+147584x10+73792+18464)=3188060160,假设是 float32,那么权值至少有 11.8GB,显然无法在一块 1080ti 上实现 MetaNet。那么作者又提出了一个想法,使用分组全连接层。 585 | 586 | > The dimension of hidden vector is 1792 without specification. The hidden features are connected with the filters of each conv layer of the network in a group manner to decrease the parameter size, which means a 128 dimensional hidden vector for each conv layer. 
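The 1920-dimensional mean/std feature mentioned above is what the `mean_std` function used throughout the code computes. A minimal sketch is shown below (the author's actual implementation lives in `code/utils.py` and may differ in details such as the concatenation order):

```py
def mean_std(features):
    """Concatenate per-channel mean and std of each VGG feature map into one style vector."""
    stats = []
    for x in features:                             # x: [batch, C, H, W]
        x = x.view(x.shape[0], x.shape[1], -1)     # [batch, C, H*W]
        stats.append(x.mean(dim=-1))               # per-channel mean, [batch, C]
        stats.append(x.std(dim=-1))                # per-channel std,  [batch, C]
    return torch.cat(stats, dim=-1)                # [batch, (64+128+256+512)*2] = [batch, 1920]
```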
587 | 588 | 意思就是隐藏层全连接层使用14x128=1792个神经元,这个14对应的就是 TransformNet 里面的每一层卷积层(downsampling2层,residual10层,upsampling2层),然后每一层卷积层的权值只连接其中的一小片128,那么整体结构参考下图: 589 | 590 | ![](imgs/metanet.png) 591 | 592 | 如果看不清可以点击查看原图。 593 | 594 | 在经过重重努力之后,模型大小终于限制在 1GB 以内了。当 `base=32` 时,保存为 pth 文件的模型大小为 870MB。 595 | 596 | 下面是代码: 597 | 598 | ```py 599 | class MetaNet(nn.Module): 600 | def __init__(self, param_dict): 601 | super(MetaNet, self).__init__() 602 | self.param_num = len(param_dict) 603 | self.hidden = nn.Linear(1920, 128*self.param_num) 604 | self.fc_dict = {} 605 | for i, (name, params) in enumerate(param_dict.items()): 606 | self.fc_dict[name] = i 607 | setattr(self, 'fc{}'.format(i+1), nn.Linear(128, params)) 608 | 609 | def forward(self, mean_std_features): 610 | hidden = F.relu(self.hidden(mean_std_features)) 611 | filters = {} 612 | for name, i in self.fc_dict.items(): 613 | fc = getattr(self, 'fc{}'.format(i+1)) 614 | filters[name] = fc(hidden[:,i*128:(i+1)*128]) 615 | return filters 616 | ``` 617 | 618 | 直接 print 模型: 619 | 620 | ``` 621 | MetaNet( 622 | (hidden): Linear(in_features=1920, out_features=1792, bias=True) 623 | (fc1): Linear(in_features=128, out_features=18496, bias=True) 624 | (fc2): Linear(in_features=128, out_features=73856, bias=True) 625 | (fc3): Linear(in_features=128, out_features=147584, bias=True) 626 | (fc4): Linear(in_features=128, out_features=147584, bias=True) 627 | (fc5): Linear(in_features=128, out_features=147584, bias=True) 628 | (fc6): Linear(in_features=128, out_features=147584, bias=True) 629 | (fc7): Linear(in_features=128, out_features=147584, bias=True) 630 | (fc8): Linear(in_features=128, out_features=147584, bias=True) 631 | (fc9): Linear(in_features=128, out_features=147584, bias=True) 632 | (fc10): Linear(in_features=128, out_features=147584, bias=True) 633 | (fc11): Linear(in_features=128, out_features=147584, bias=True) 634 | (fc12): Linear(in_features=128, out_features=147584, bias=True) 635 | (fc13): Linear(in_features=128, out_features=73792, bias=True) 636 | (fc14): Linear(in_features=128, out_features=18464, bias=True) 637 | ) 638 | ``` 639 | 640 | ## 数据 641 | 642 | > There are about 120k images in MS- COCO trainval set and about 80k images in the test set of WikiArt. 643 | 644 | 要想训练这么大的模型,那么就必须要海量的风格图像和内容图像。原论文依旧选择了 COCO 作为内容数据集。而风格数据集选择了 [WikiArt](https://www.kaggle.com/c/painter-by-numbers/data),该数据集包含大量艺术作品,很适合作为风格迁移的风格图片。 645 | 646 | > During training, each content image or style image is resized to keep the smallest dimension in the range [256, 480], and randomly cropped regions of size 256 × 256. 647 | 648 | 论文提到图像要先缩放到 [256, 480] 的尺寸,然后再随机裁剪为 256 × 256。 649 | 650 | 代码如下: 651 | 652 | ```py 653 | data_transform = transforms.Compose([ 654 | transforms.RandomResizedCrop(width, scale=(256/480, 1), ratio=(1, 1)), 655 | transforms.ToTensor(), 656 | tensor_normalizer 657 | ]) 658 | 659 | style_dataset = torchvision.datasets.ImageFolder('/home/ypw/WikiArt/', transform=data_transform) 660 | content_dataset = torchvision.datasets.ImageFolder('/home/ypw/COCO/', transform=data_transform) 661 | ``` 662 | 663 | ## 训练 664 | 665 | ### 超参数 666 | 667 | > The weight of content loss is 1 while the weight of style loss is 250. 668 | 669 | 虽然论文里给出的 `style_weight` 是 250,但是我这里测试得并不理想,可能是不同的预训练模型、不同的预处理方式造成的差异,设置为 25 是比较理想的。 670 | 671 | > We use Adam (Kingma and Ba 2014) with fixed learning rate 0.001 for 600k iterations without weight decay. 
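In this third situation the optimizer has to update both MetaNet and the trainable (gray) layers of TransformNet at the same time. The setup below mirrors what `code/Situation3.ipynb` does:

```py
# Collect every parameter with requires_grad=True (frozen VGG16 contributes none)
# and hand them all to a single Adam optimizer, as in code/Situation3.ipynb.
trainable_params = {}
for model in [vgg16, transform_net, metanet]:
    for name, param in model.named_parameters():
        if param.requires_grad:
            trainable_params[name] = param

optimizer = optim.Adam(trainable_params.values(), 1e-3)
```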
672 | 673 | 优化器使用了论文中提到的 Adam 1e-3。 674 | 675 | > The transferred images are regularized with total variations loss with a strength of 10. 676 | 677 | 因为这篇论文的作者用的是 caffe,VGG16 的预训练权值与 pytorch 差异比较大,所以我这里的 `tv_weight` 没有设置为论文中的10,而是选择了 1e-4。 678 | 679 | > The batch size of content images is 8 and the meta network is trained for 20 iterations before changing the style image. 680 | 681 | 这里的 batch_size 很有意思,每次来8张内容图片,但是每当训练20个 batch 之后,换一张风格图片。这样做的目的是为了保证 TransformNet 能在每张风格图像上都收敛一段时间,切换图像又能保证 MetaNet 能够适应所有的风格图像。 682 | 683 | ### 代码 684 | 685 | 由于代码太长,这里也只贴一些关键代码: 686 | 687 | ```py 688 | for batch, (content_images, _) in pbar: 689 | # 每 20 个 batch 随机挑选一张新的风格图像,计算其特征 690 | if batch % 20 == 0: 691 | style_image = random.choice(style_dataset)[0].unsqueeze(0).to(device) 692 | style_features = vgg16(style_image) 693 | style_mean_std = mean_std(style_features) 694 | 695 | # 检查纯色 696 | x = content_images.cpu().numpy() 697 | if (x.min(-1).min(-1) == x.max(-1).max(-1)).any(): 698 | continue 699 | 700 | optimizer.zero_grad() 701 | 702 | # 使用风格图像生成风格模型 703 | weights = metanet(mean_std(style_features)) 704 | transform_net.set_weights(weights, 0) 705 | 706 | # 使用风格模型预测风格迁移图像 707 | content_images = content_images.to(device) 708 | transformed_images = transform_net(content_images) 709 | 710 | # 使用 vgg16 计算特征 711 | content_features = vgg16(content_images) 712 | transformed_features = vgg16(transformed_images) 713 | transformed_mean_std = mean_std(transformed_features) 714 | 715 | # content loss 716 | content_loss = content_weight * F.mse_loss(transformed_features[2], content_features[2]) 717 | 718 | # style loss 719 | style_loss = style_weight * F.mse_loss(transformed_mean_std, 720 | style_mean_std.expand_as(transformed_mean_std)) 721 | 722 | # total variation loss 723 | y = transformed_images 724 | tv_loss = tv_weight * (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + 725 | torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))) 726 | 727 | # 求和 728 | loss = content_loss + style_loss + tv_loss 729 | 730 | loss.backward() 731 | optimizer.step() 732 | ``` 733 | 734 | 这里有几点问题值得思考: 735 | 736 | 1. 如果内容图像是纯色的,那么权值会直接 nan,原因不明,为了避免这个问题,需要检查纯色,然后 continue 来避免 nan。 737 | 2. 
权值会逐渐增大,目前没有比较好的解决方案。 738 | 739 | ![](imgs/weights_diverge.png) 740 | 741 | ## 效果 742 | 743 | 最终效果如图所示: 744 | 745 | ![](imgs/S3.jpg) 746 | 747 | 可以看到对于任意内容图片,转换网络都能转换为固定风格的图像。 748 | 749 | 根据下面这段代码进行的测速,1080ti 可以在8.48秒内对 1000 张风格图像产出风格迁移模型,相当于117fps。而风格迁移模型转换的速度也很快,达到了4.59秒,相当于217fps。假设我们每一帧都用不同的风格,转换1000张图片也只需要13.1秒,相当于76fps,可以说做到了实时任意风格任意内容的极速风格迁移。 750 | 751 | ![](imgs/S3_speed.png) 752 | 753 | # 总结 754 | 755 | 我们使用 pytorch 实现了以下三种风格迁移: 756 | 757 | * 固定风格固定内容的普通风格迁移([A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576)) 758 | * 固定风格任意内容的快速风格迁移([Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)) 759 | * 任意风格任意内容的极速风格迁移([Meta Networks for Neural Style Transfer](https://arxiv.org/abs/1709.04111)) 760 | 761 | 首先第一篇论文打破了以往的思维定式:只有权值可以训练。它通过对图像进行训练实现了风格迁移。然后第二篇论文就比较正常,通过训练一个模型来实现风格迁移。第三篇论文就很神奇了,通过模型来生成权值,进而实现任意风格的风格迁移。不得不感谢这些走在科技前沿的科研工作者,给了我们许多新奇的思路。 762 | 763 | -------------------------------------------------------------------------------- /code/Situation2_test_speed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "ExecuteTime": { 8 | "end_time": "2018-07-12T10:18:08.541385Z", 9 | "start_time": "2018-07-12T10:18:08.213839Z" 10 | } 11 | }, 12 | "outputs": [], 13 | "source": [ 14 | "import torch\n", 15 | "import torch.nn as nn\n", 16 | "import torch.nn.functional as F\n", 17 | "\n", 18 | "def ConvLayer(in_channels, out_channels, kernel_size=3, stride=1, \n", 19 | " upsample=None, instance_norm=True, relu=True):\n", 20 | " layers = []\n", 21 | " if upsample:\n", 22 | " layers.append(nn.Upsample(mode='nearest', scale_factor=upsample))\n", 23 | " layers.append(nn.ReflectionPad2d(kernel_size // 2))\n", 24 | " layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride))\n", 25 | " if instance_norm:\n", 26 | " layers.append(nn.InstanceNorm2d(out_channels))\n", 27 | " if relu:\n", 28 | " layers.append(nn.ReLU())\n", 29 | " return layers\n", 30 | "\n", 31 | "class ResidualBlock(nn.Module):\n", 32 | " def __init__(self, channels):\n", 33 | " super(ResidualBlock, self).__init__()\n", 34 | " self.conv = nn.Sequential(\n", 35 | " *ConvLayer(channels, channels, kernel_size=3, stride=1), \n", 36 | " *ConvLayer(channels, channels, kernel_size=3, stride=1, relu=False)\n", 37 | " )\n", 38 | "\n", 39 | " def forward(self, x):\n", 40 | " return self.conv(x) + x\n", 41 | "\n", 42 | "class TransformNet(nn.Module):\n", 43 | " def __init__(self, base=32):\n", 44 | " super(TransformNet, self).__init__()\n", 45 | " self.downsampling = nn.Sequential(\n", 46 | " *ConvLayer(3, base, kernel_size=9), \n", 47 | " *ConvLayer(base, base*2, kernel_size=3, stride=2), \n", 48 | " *ConvLayer(base*2, base*4, kernel_size=3, stride=2), \n", 49 | " )\n", 50 | " self.residuals = nn.Sequential(*[ResidualBlock(base*4) for i in range(5)])\n", 51 | " self.upsampling = nn.Sequential(\n", 52 | " *ConvLayer(base*4, base*2, kernel_size=3, upsample=2),\n", 53 | " *ConvLayer(base*2, base, kernel_size=3, upsample=2),\n", 54 | " *ConvLayer(base, 3, kernel_size=9, instance_norm=False, relu=False),\n", 55 | " )\n", 56 | " \n", 57 | " def forward(self, X):\n", 58 | " y = self.downsampling(X)\n", 59 | " y = self.residuals(y)\n", 60 | " y = self.upsampling(y)\n", 61 | " return y" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 2, 67 | "metadata": { 68 | "ExecuteTime": { 69 | "end_time": 
"2018-07-12T10:18:14.609117Z", 70 | "start_time": "2018-07-12T10:18:08.543466Z" 71 | } 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", 76 | "transform_net = TransformNet(32).to(device)\n", 77 | "transform_net.load_state_dict(torch.load('transform_net.pth'))" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 3, 83 | "metadata": { 84 | "ExecuteTime": { 85 | "end_time": "2018-07-12T10:18:14.616276Z", 86 | "start_time": "2018-07-12T10:18:14.611411Z" 87 | } 88 | }, 89 | "outputs": [], 90 | "source": [ 91 | "X = torch.rand((1, 3, 256, 256)).to(device)" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 4, 97 | "metadata": { 98 | "ExecuteTime": { 99 | "end_time": "2018-07-12T10:18:19.447909Z", 100 | "start_time": "2018-07-12T10:18:14.617999Z" 101 | } 102 | }, 103 | "outputs": [ 104 | { 105 | "name": "stdout", 106 | "output_type": "stream", 107 | "text": [ 108 | "CPU times: user 4.58 s, sys: 236 ms, total: 4.81 s\n", 109 | "Wall time: 4.82 s\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "%%time \n", 115 | "for i in range(1000):\n", 116 | " out = transform_net(X)\n", 117 | " del out" 118 | ] 119 | }, 120 | { 121 | "cell_type": "code", 122 | "execution_count": null, 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [] 126 | } 127 | ], 128 | "metadata": { 129 | "kernelspec": { 130 | "display_name": "Python 3", 131 | "language": "python", 132 | "name": "python3" 133 | }, 134 | "language_info": { 135 | "codemirror_mode": { 136 | "name": "ipython", 137 | "version": 3 138 | }, 139 | "file_extension": ".py", 140 | "mimetype": "text/x-python", 141 | "name": "python", 142 | "nbconvert_exporter": "python", 143 | "pygments_lexer": "ipython3", 144 | "version": "3.6.5" 145 | }, 146 | "toc": { 147 | "nav_menu": {}, 148 | "number_sections": true, 149 | "sideBar": true, 150 | "skip_h1_title": false, 151 | "toc_cell": false, 152 | "toc_position": {}, 153 | "toc_section_display": "block", 154 | "toc_window_display": false 155 | }, 156 | "varInspector": { 157 | "cols": { 158 | "lenName": "40", 159 | "lenType": 16, 160 | "lenVar": 40 161 | }, 162 | "kernels_config": { 163 | "python": { 164 | "delete_cmd_postfix": "", 165 | "delete_cmd_prefix": "del ", 166 | "library": "var_list.py", 167 | "varRefreshCmd": "print(var_dic_list())" 168 | }, 169 | "r": { 170 | "delete_cmd_postfix": ") ", 171 | "delete_cmd_prefix": "rm(", 172 | "library": "var_list.r", 173 | "varRefreshCmd": "cat(var_dic_list()) " 174 | } 175 | }, 176 | "position": { 177 | "height": "441px", 178 | "left": "934px", 179 | "right": "20px", 180 | "top": "120px", 181 | "width": "333px" 182 | }, 183 | "types_to_exclude": [ 184 | "module", 185 | "function", 186 | "builtin_function_or_method", 187 | "instance", 188 | "_Feature" 189 | ], 190 | "window_display": true 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 2 195 | } 196 | -------------------------------------------------------------------------------- /code/Situation3.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# 导入必要的库" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 4, 13 | "metadata": { 14 | "ExecuteTime": { 15 | "end_time": "2018-07-13T11:38:00.577786Z", 16 | "start_time": "2018-07-13T11:38:00.567386Z" 17 | } 18 | }, 19 | "outputs": [], 20 | "source": [ 21 | "import os\n", 22 | "\n", 23 | "# 
os.environ['CUDA_VISIBLE_DEVICES'] = '4'\n", 24 | "\n", 25 | "import torch\n", 26 | "import torch.nn as nn\n", 27 | "import torch.nn.functional as F\n", 28 | "import torch.optim as optim\n", 29 | "\n", 30 | "import random\n", 31 | "from PIL import Image\n", 32 | "import matplotlib.pyplot as plt\n", 33 | "\n", 34 | "import torchvision\n", 35 | "import torchvision.transforms as transforms\n", 36 | "import torchvision.models as models\n", 37 | "import shutil\n", 38 | "from glob import glob\n", 39 | "\n", 40 | "from tensorboardX import SummaryWriter\n", 41 | "\n", 42 | "import numpy as np\n", 43 | "import multiprocessing\n", 44 | "\n", 45 | "import copy\n", 46 | "from tqdm import tqdm\n", 47 | "from collections import defaultdict\n", 48 | "\n", 49 | "import horovod.torch as hvd\n", 50 | "import torch.utils.data.distributed\n", 51 | "\n", 52 | "from utils import *\n", 53 | "from models import *\n", 54 | "import time\n", 55 | "\n", 56 | "from pprint import pprint\n", 57 | "display = pprint\n", 58 | "\n", 59 | "hvd.init()\n", 60 | "torch.cuda.set_device(hvd.local_rank())\n", 61 | "\n", 62 | "device = torch.device(\"cuda:%s\" %hvd.local_rank() if torch.cuda.is_available() else \"cpu\")" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 5, 68 | "metadata": { 69 | "ExecuteTime": { 70 | "end_time": "2018-07-13T11:38:00.744883Z", 71 | "start_time": "2018-07-13T11:38:00.737156Z" 72 | } 73 | }, 74 | "outputs": [ 75 | { 76 | "name": "stdout", 77 | "output_type": "stream", 78 | "text": [ 79 | "model_name: metanet_base32_style25_tv1e-07_l21e-05_taghvd, rank: 0\n" 80 | ] 81 | } 82 | ], 83 | "source": [ 84 | "is_hvd = False\n", 85 | "tag = 'nohvd'\n", 86 | "base = 32\n", 87 | "style_weight = 50\n", 88 | "content_weight = 1\n", 89 | "tv_weight = 1e-6\n", 90 | "epochs = 22\n", 91 | "\n", 92 | "batch_size = 8\n", 93 | "width = 256\n", 94 | "\n", 95 | "verbose_hist_batch = 100\n", 96 | "verbose_image_batch = 800\n", 97 | "\n", 98 | "model_name = f'metanet_base{base}_style{style_weight}_tv{tv_weight}_tag{tag}'\n", 99 | "print(f'model_name: {model_name}, rank: {hvd.rank()}')" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": null, 105 | "metadata": { 106 | "ExecuteTime": { 107 | "end_time": "2018-07-12T13:04:56.541964Z", 108 | "start_time": "2018-07-12T13:04:56.535774Z" 109 | } 110 | }, 111 | "outputs": [], 112 | "source": [ 113 | "def rmrf(path):\n", 114 | " try:\n", 115 | " shutil.rmtree(path)\n", 116 | " except:\n", 117 | " pass\n", 118 | "\n", 119 | "for f in glob('runs/*/.AppleDouble'):\n", 120 | " rmrf(f)\n", 121 | "\n", 122 | "rmrf('runs/' + model_name)" 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "# 搭建模型" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 3, 135 | "metadata": { 136 | "ExecuteTime": { 137 | "end_time": "2018-07-13T10:38:51.871437Z", 138 | "start_time": "2018-07-13T10:38:43.789881Z" 139 | } 140 | }, 141 | "outputs": [], 142 | "source": [ 143 | "vgg16 = models.vgg16(pretrained=True)\n", 144 | "vgg16 = VGG(vgg16.features[:23]).to(device).eval()" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 4, 150 | "metadata": { 151 | "ExecuteTime": { 152 | "end_time": "2018-07-13T10:38:51.925705Z", 153 | "start_time": "2018-07-13T10:38:51.874457Z" 154 | } 155 | }, 156 | "outputs": [ 157 | { 158 | "data": { 159 | "text/plain": [ 160 | "defaultdict(int,\n", 161 | " {'downsampling.5': 18496,\n", 162 | " 'downsampling.9': 73856,\n", 163 | " 
'residuals.0.conv.1': 147584,\n", 164 | " 'residuals.0.conv.5': 147584,\n", 165 | " 'residuals.1.conv.1': 147584,\n", 166 | " 'residuals.1.conv.5': 147584,\n", 167 | " 'residuals.2.conv.1': 147584,\n", 168 | " 'residuals.2.conv.5': 147584,\n", 169 | " 'residuals.3.conv.1': 147584,\n", 170 | " 'residuals.3.conv.5': 147584,\n", 171 | " 'residuals.4.conv.1': 147584,\n", 172 | " 'residuals.4.conv.5': 147584,\n", 173 | " 'upsampling.2': 73792,\n", 174 | " 'upsampling.7': 18464})" 175 | ] 176 | }, 177 | "execution_count": 4, 178 | "metadata": {}, 179 | "output_type": "execute_result" 180 | } 181 | ], 182 | "source": [ 183 | "transform_net = TransformNet(base).to(device)\n", 184 | "transform_net.get_param_dict()" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 7, 190 | "metadata": { 191 | "ExecuteTime": { 192 | "end_time": "2018-07-13T10:38:54.307510Z", 193 | "start_time": "2018-07-13T10:38:51.954926Z" 194 | }, 195 | "code_folding": [], 196 | "scrolled": true 197 | }, 198 | "outputs": [], 199 | "source": [ 200 | "metanet = MetaNet(transform_net.get_param_dict()).to(device)" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "# 载入数据集\n", 208 | "\n", 209 | "> During training, each content image or style image is resized to keep the smallest dimension in the range [256, 480], and randomly cropped regions of size 256 × 256.\n", 210 | "\n", 211 | "## 载入 COCO 数据集和 WikiArt 数据集\n", 212 | "\n", 213 | "> The batch size of content images is 8 and the meta network is trained for 20 iterations before changing the style image." 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 6, 219 | "metadata": { 220 | "ExecuteTime": { 221 | "end_time": "2018-07-13T11:38:09.383610Z", 222 | "start_time": "2018-07-13T11:38:08.371037Z" 223 | } 224 | }, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "Dataset ImageFolder\n", 231 | " Number of datapoints: 23806\n", 232 | " Root Location: /home/ypw/WikiArt/\n", 233 | " Transforms (if any): Compose(\n", 234 | " RandomResizedCrop(size=(256, 256), scale=(0.5333, 1), ratio=(1, 1), interpolation=PIL.Image.BILINEAR)\n", 235 | " ToTensor()\n", 236 | " Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", 237 | " )\n", 238 | " Target Transforms (if any): None\n", 239 | "--------------------\n", 240 | "Dataset ImageFolder\n", 241 | " Number of datapoints: 164062\n", 242 | " Root Location: /home/ypw/COCO/\n", 243 | " Transforms (if any): Compose(\n", 244 | " RandomResizedCrop(size=(256, 256), scale=(0.5333, 1), ratio=(1, 1), interpolation=PIL.Image.BILINEAR)\n", 245 | " ToTensor()\n", 246 | " Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", 247 | " )\n", 248 | " Target Transforms (if any): None\n" 249 | ] 250 | } 251 | ], 252 | "source": [ 253 | "data_transform = transforms.Compose([\n", 254 | " transforms.RandomResizedCrop(width, scale=(256/480, 1), ratio=(1, 1)), \n", 255 | " transforms.ToTensor(), \n", 256 | " tensor_normalizer\n", 257 | "])\n", 258 | "\n", 259 | "style_dataset = torchvision.datasets.ImageFolder('/home/ypw/WikiArt/', transform=data_transform)\n", 260 | "content_dataset = torchvision.datasets.ImageFolder('/home/ypw/COCO/', transform=data_transform)\n", 261 | "\n", 262 | "if is_hvd:\n", 263 | " train_sampler = torch.utils.data.distributed.DistributedSampler(\n", 264 | " content_dataset, num_replicas=hvd.size(), rank=hvd.rank())\n", 265 | " content_data_loader = 
torch.utils.data.DataLoader(content_dataset, batch_size=batch_size, \n", 266 | " num_workers=multiprocessing.cpu_count(),sampler=train_sampler)\n", 267 | "else:\n", 268 | " content_data_loader = torch.utils.data.DataLoader(content_dataset, batch_size=batch_size, \n", 269 | " shuffle=True, num_workers=multiprocessing.cpu_count())\n", 270 | "\n", 271 | "if not is_hvd or hvd.rank() == 0:\n", 272 | " print(style_dataset)\n", 273 | " print('-'*20)\n", 274 | " print(content_dataset)" 275 | ] 276 | }, 277 | { 278 | "cell_type": "markdown", 279 | "metadata": {}, 280 | "source": [ 281 | "# 测试 infer" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 8, 287 | "metadata": { 288 | "ExecuteTime": { 289 | "end_time": "2018-07-13T09:42:47.537476Z", 290 | "start_time": "2018-07-13T09:42:47.379446Z" 291 | } 292 | }, 293 | "outputs": [ 294 | { 295 | "name": "stdout", 296 | "output_type": "stream", 297 | "text": [ 298 | "features:\n", 299 | "[torch.Size([4, 64, 256, 256]),\n", 300 | " torch.Size([4, 128, 128, 128]),\n", 301 | " torch.Size([4, 256, 64, 64]),\n", 302 | " torch.Size([4, 512, 32, 32])]\n", 303 | "weights:\n", 304 | "[torch.Size([4, 18496]),\n", 305 | " torch.Size([4, 73856]),\n", 306 | " torch.Size([4, 147584]),\n", 307 | " torch.Size([4, 147584]),\n", 308 | " torch.Size([4, 147584]),\n", 309 | " torch.Size([4, 147584]),\n", 310 | " torch.Size([4, 147584]),\n", 311 | " torch.Size([4, 147584]),\n", 312 | " torch.Size([4, 147584]),\n", 313 | " torch.Size([4, 147584]),\n", 314 | " torch.Size([4, 147584]),\n", 315 | " torch.Size([4, 147584]),\n", 316 | " torch.Size([4, 73792]),\n", 317 | " torch.Size([4, 18464])]\n", 318 | "transformed_images:\n", 319 | "torch.Size([4, 3, 256, 256])\n" 320 | ] 321 | } 322 | ], 323 | "source": [ 324 | "metanet.eval()\n", 325 | "transform_net.eval()\n", 326 | "\n", 327 | "rands = torch.rand(4, 3, 256, 256).to(device)\n", 328 | "features = vgg16(rands);\n", 329 | "weights = metanet(mean_std(features));\n", 330 | "transform_net.set_weights(weights)\n", 331 | "transformed_images = transform_net(torch.rand(4, 3, 256, 256).to(device));\n", 332 | "\n", 333 | "if not is_hvd or hvd.rank() == 0:\n", 334 | " print('features:')\n", 335 | " display([x.shape for x in features])\n", 336 | " \n", 337 | " print('weights:')\n", 338 | " display([x.shape for x in weights.values()])\n", 339 | "\n", 340 | " print('transformed_images:')\n", 341 | " display(transformed_images.shape)" 342 | ] 343 | }, 344 | { 345 | "cell_type": "markdown", 346 | "metadata": {}, 347 | "source": [ 348 | "# 初始化一些变量" 349 | ] 350 | }, 351 | { 352 | "cell_type": "code", 353 | "execution_count": null, 354 | "metadata": { 355 | "ExecuteTime": { 356 | "end_time": "2018-07-12T13:05:07.481869Z", 357 | "start_time": "2018-07-12T13:05:07.398188Z" 358 | } 359 | }, 360 | "outputs": [], 361 | "source": [ 362 | "visualization_style_image = random.choice(style_dataset)[0].unsqueeze(0).to(device)\n", 363 | "visualization_content_images = torch.stack([random.choice(content_dataset)[0] for i in range(4)]).to(device)" 364 | ] 365 | }, 366 | { 367 | "cell_type": "code", 368 | "execution_count": null, 369 | "metadata": { 370 | "ExecuteTime": { 371 | "end_time": "2018-07-12T13:05:08.288833Z", 372 | "start_time": "2018-07-12T13:05:07.483858Z" 373 | } 374 | }, 375 | "outputs": [], 376 | "source": [ 377 | "if not is_hvd or hvd.rank() == 0:\n", 378 | " for f in glob('runs/*/.AppleDouble'):\n", 379 | " rmrf(f)\n", 380 | "\n", 381 | " rmrf('runs/' + model_name)\n", 382 | " writer = 
SummaryWriter('runs/'+model_name)\n", 383 | "else:\n", 384 | " writer = SummaryWriter('/tmp/'+model_name)\n", 385 | "\n", 386 | "visualization_style_image = random.choice(style_dataset)[0].unsqueeze(0).to(device)\n", 387 | "visualization_content_images = torch.stack([random.choice(content_dataset)[0] for i in range(4)]).to(device)\n", 388 | "\n", 389 | "writer.add_image('content_image', recover_tensor(visualization_content_images), 0)\n", 390 | "writer.add_graph(transform_net, (rands, ))\n", 391 | "\n", 392 | "del rands, features, weights, transformed_images" 393 | ] 394 | }, 395 | { 396 | "cell_type": "code", 397 | "execution_count": null, 398 | "metadata": { 399 | "ExecuteTime": { 400 | "end_time": "2018-07-12T13:05:08.334236Z", 401 | "start_time": "2018-07-12T13:05:08.329306Z" 402 | }, 403 | "scrolled": false 404 | }, 405 | "outputs": [], 406 | "source": [ 407 | "trainable_params = {}\n", 408 | "trainable_param_shapes = {}\n", 409 | "for model in [vgg16, transform_net, metanet]:\n", 410 | " for name, param in model.named_parameters():\n", 411 | " if param.requires_grad:\n", 412 | " trainable_params[name] = param\n", 413 | " trainable_param_shapes[name] = param.shape" 414 | ] 415 | }, 416 | { 417 | "cell_type": "markdown", 418 | "metadata": {}, 419 | "source": [ 420 | "# 开始训练" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": null, 426 | "metadata": { 427 | "ExecuteTime": { 428 | "end_time": "2018-07-12T13:05:09.472661Z", 429 | "start_time": "2018-07-12T13:05:08.337215Z" 430 | } 431 | }, 432 | "outputs": [], 433 | "source": [ 434 | "optimizer = optim.Adam(trainable_params.values(), 1e-3)\n", 435 | "\n", 436 | "if is_hvd:\n", 437 | " optimizer = hvd.DistributedOptimizer(optimizer, \n", 438 | " named_parameters=trainable_params.items())\n", 439 | " params = transform_net.state_dict()\n", 440 | " params.update(metanet.state_dict())\n", 441 | " hvd.broadcast_parameters(params, root_rank=0)" 442 | ] 443 | }, 444 | { 445 | "cell_type": "code", 446 | "execution_count": null, 447 | "metadata": { 448 | "ExecuteTime": { 449 | "end_time": "2018-07-12T13:06:43.549811Z", 450 | "start_time": "2018-07-12T13:05:09.476595Z" 451 | }, 452 | "code_folding": [], 453 | "scrolled": false 454 | }, 455 | "outputs": [], 456 | "source": [ 457 | "n_batch = len(content_data_loader)\n", 458 | "metanet.train()\n", 459 | "transform_net.train()\n", 460 | "\n", 461 | "for epoch in range(epochs):\n", 462 | " smoother = defaultdict(Smooth)\n", 463 | " with tqdm(enumerate(content_data_loader), total=n_batch) as pbar:\n", 464 | " for batch, (content_images, _) in pbar:\n", 465 | " n_iter = epoch*n_batch + batch\n", 466 | " \n", 467 | " # 每 20 个 batch 随机挑选一张新的风格图像,计算其特征\n", 468 | " if batch % 20 == 0:\n", 469 | " style_image = random.choice(style_dataset)[0].unsqueeze(0).to(device)\n", 470 | " style_features = vgg16(style_image)\n", 471 | " style_mean_std = mean_std(style_features)\n", 472 | " \n", 473 | " # 检查纯色\n", 474 | " x = content_images.cpu().numpy()\n", 475 | " if (x.min(-1).min(-1) == x.max(-1).max(-1)).any():\n", 476 | " continue\n", 477 | " \n", 478 | " optimizer.zero_grad()\n", 479 | " \n", 480 | " # 使用风格图像生成风格模型\n", 481 | " weights = metanet(mean_std(style_features))\n", 482 | " transform_net.set_weights(weights, 0)\n", 483 | " \n", 484 | " # 使用风格模型预测风格迁移图像\n", 485 | " content_images = content_images.to(device)\n", 486 | " transformed_images = transform_net(content_images)\n", 487 | "\n", 488 | " # 使用 vgg16 计算特征\n", 489 | " content_features = vgg16(content_images)\n", 490 | " 
transformed_features = vgg16(transformed_images)\n", 491 | " transformed_mean_std = mean_std(transformed_features)\n", 492 | " \n", 493 | " # content loss\n", 494 | " content_loss = content_weight * F.mse_loss(transformed_features[2], content_features[2])\n", 495 | " \n", 496 | " # style loss\n", 497 | " style_loss = style_weight * F.mse_loss(transformed_mean_std, \n", 498 | " style_mean_std.expand_as(transformed_mean_std))\n", 499 | " \n", 500 | " # total variation loss\n", 501 | " y = transformed_images\n", 502 | " tv_loss = tv_weight * (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + \n", 503 | " torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])))\n", 504 | " \n", 505 | " # 求和\n", 506 | " loss = content_loss + style_loss + tv_loss \n", 507 | " \n", 508 | " loss.backward()\n", 509 | " optimizer.step()\n", 510 | " \n", 511 | " smoother['content_loss'] += content_loss.item()\n", 512 | " smoother['style_loss'] += style_loss.item()\n", 513 | " smoother['tv_loss'] += tv_loss.item()\n", 514 | " smoother['loss'] += loss.item()\n", 515 | " \n", 516 | " max_value = max([x.max().item() for x in weights.values()])\n", 517 | " \n", 518 | " writer.add_scalar('loss/loss', loss, n_iter)\n", 519 | " writer.add_scalar('loss/content_loss', content_loss, n_iter)\n", 520 | " writer.add_scalar('loss/style_loss', style_loss, n_iter)\n", 521 | " writer.add_scalar('loss/total_variation', tv_loss, n_iter)\n", 522 | " writer.add_scalar('loss/max', max_value, n_iter)\n", 523 | " \n", 524 | " s = 'Epoch: {} '.format(epoch+1)\n", 525 | " s += 'Content: {:.2f} '.format(smoother['content_loss'])\n", 526 | " s += 'Style: {:.1f} '.format(smoother['style_loss'])\n", 527 | " s += 'Loss: {:.2f} '.format(smoother['loss'])\n", 528 | " s += 'Max: {:.2f}'.format(max_value)\n", 529 | " \n", 530 | " if (batch + 1) % verbose_image_batch == 0:\n", 531 | " transform_net.eval()\n", 532 | " visualization_transformed_images = transform_net(visualization_content_images)\n", 533 | " transform_net.train()\n", 534 | " visualization_transformed_images = torch.cat([style_image, visualization_transformed_images])\n", 535 | " writer.add_image('debug', recover_tensor(visualization_transformed_images), n_iter)\n", 536 | " del visualization_transformed_images\n", 537 | " \n", 538 | " if (batch + 1) % verbose_hist_batch == 0:\n", 539 | " for name, param in weights.items():\n", 540 | " writer.add_histogram('transform_net.'+name, param.clone().cpu().data.numpy(), \n", 541 | " n_iter, bins='auto')\n", 542 | " \n", 543 | " for name, param in transform_net.named_parameters():\n", 544 | " writer.add_histogram('transform_net.'+name, param.clone().cpu().data.numpy(), \n", 545 | " n_iter, bins='auto')\n", 546 | " \n", 547 | " for name, param in metanet.named_parameters():\n", 548 | " l = name.split('.')\n", 549 | " l.remove(l[-1])\n", 550 | " writer.add_histogram('metanet.'+'.'.join(l), param.clone().cpu().data.numpy(), \n", 551 | " n_iter, bins='auto')\n", 552 | "\n", 553 | " pbar.set_description(s)\n", 554 | " \n", 555 | " del transformed_images, weights\n", 556 | " \n", 557 | " if not is_hvd or hvd.rank() == 0:\n", 558 | " torch.save(metanet.state_dict(), 'checkpoints/{}_{}.pth'.format(model_name, epoch+1))\n", 559 | " torch.save(transform_net.state_dict(), \n", 560 | " 'checkpoints/{}_transform_net_{}.pth'.format(model_name, epoch+1))\n", 561 | " \n", 562 | " torch.save(metanet.state_dict(), 'models/{}.pth'.format(model_name))\n", 563 | " torch.save(transform_net.state_dict(), 'models/{}_transform_net.pth'.format(model_name))" 564 | ] 565 
| }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": null, 569 | "metadata": {}, 570 | "outputs": [], 571 | "source": [] 572 | } 573 | ], 574 | "metadata": { 575 | "kernelspec": { 576 | "display_name": "Python 3", 577 | "language": "python", 578 | "name": "python3" 579 | }, 580 | "language_info": { 581 | "codemirror_mode": { 582 | "name": "ipython", 583 | "version": 3 584 | }, 585 | "file_extension": ".py", 586 | "mimetype": "text/x-python", 587 | "name": "python", 588 | "nbconvert_exporter": "python", 589 | "pygments_lexer": "ipython3", 590 | "version": "3.6.5" 591 | }, 592 | "toc": { 593 | "nav_menu": {}, 594 | "number_sections": true, 595 | "sideBar": true, 596 | "skip_h1_title": false, 597 | "toc_cell": false, 598 | "toc_position": {}, 599 | "toc_section_display": "block", 600 | "toc_window_display": false 601 | }, 602 | "varInspector": { 603 | "cols": { 604 | "lenName": "40", 605 | "lenType": 16, 606 | "lenVar": 40 607 | }, 608 | "kernels_config": { 609 | "python": { 610 | "delete_cmd_postfix": "", 611 | "delete_cmd_prefix": "del ", 612 | "library": "var_list.py", 613 | "varRefreshCmd": "print(var_dic_list())" 614 | }, 615 | "r": { 616 | "delete_cmd_postfix": ") ", 617 | "delete_cmd_prefix": "rm(", 618 | "library": "var_list.r", 619 | "varRefreshCmd": "cat(var_dic_list()) " 620 | } 621 | }, 622 | "position": { 623 | "height": "441px", 624 | "left": "934px", 625 | "right": "20px", 626 | "top": "120px", 627 | "width": "333px" 628 | }, 629 | "types_to_exclude": [ 630 | "module", 631 | "function", 632 | "builtin_function_or_method", 633 | "instance", 634 | "_Feature" 635 | ], 636 | "window_display": true 637 | } 638 | }, 639 | "nbformat": 4, 640 | "nbformat_minor": 2 641 | } 642 | -------------------------------------------------------------------------------- /code/Situation3.py: -------------------------------------------------------------------------------- 1 | 2 | # coding: utf-8 3 | 4 | # # 导入必要的库 5 | 6 | # In[4]: 7 | 8 | 9 | import os 10 | 11 | # os.environ['CUDA_VISIBLE_DEVICES'] = '4' 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.optim as optim 17 | 18 | import random 19 | from PIL import Image 20 | import matplotlib.pyplot as plt 21 | 22 | import torchvision 23 | import torchvision.transforms as transforms 24 | import torchvision.models as models 25 | import shutil 26 | from glob import glob 27 | 28 | from tensorboardX import SummaryWriter 29 | 30 | import numpy as np 31 | import multiprocessing 32 | 33 | import copy 34 | from tqdm import tqdm 35 | from collections import defaultdict 36 | 37 | import horovod.torch as hvd 38 | import torch.utils.data.distributed 39 | 40 | from utils import * 41 | from models import * 42 | import time 43 | 44 | from pprint import pprint 45 | display = pprint 46 | 47 | hvd.init() 48 | torch.cuda.set_device(hvd.local_rank()) 49 | 50 | device = torch.device("cuda:%s" %hvd.local_rank() if torch.cuda.is_available() else "cpu") 51 | 52 | 53 | # In[5]: 54 | 55 | 56 | is_hvd = False 57 | tag = 'nohvd' 58 | base = 32 59 | style_weight = 50 60 | content_weight = 1 61 | tv_weight = 1e-6 62 | epochs = 22 63 | 64 | batch_size = 8 65 | width = 256 66 | 67 | verbose_hist_batch = 100 68 | verbose_image_batch = 800 69 | 70 | model_name = f'metanet_base{base}_style{style_weight}_tv{tv_weight}_tag{tag}' 71 | print(f'model_name: {model_name}, rank: {hvd.rank()}') 72 | 73 | 74 | # In[ ]: 75 | 76 | 77 | def rmrf(path): 78 | try: 79 | shutil.rmtree(path) 80 | except: 81 | pass 82 | 83 | for f in 
glob('runs/*/.AppleDouble'): 84 | rmrf(f) 85 | 86 | rmrf('runs/' + model_name) 87 | 88 | 89 | # # 搭建模型 90 | 91 | # In[3]: 92 | 93 | 94 | vgg16 = models.vgg16(pretrained=True) 95 | vgg16 = VGG(vgg16.features[:23]).to(device).eval() 96 | 97 | 98 | # In[4]: 99 | 100 | 101 | transform_net = TransformNet(base).to(device) 102 | transform_net.get_param_dict() 103 | 104 | 105 | # In[7]: 106 | 107 | 108 | metanet = MetaNet(transform_net.get_param_dict()).to(device) 109 | 110 | 111 | # # 载入数据集 112 | # 113 | # > During training, each content image or style image is resized to keep the smallest dimension in the range [256, 480], and randomly cropped regions of size 256 × 256. 114 | # 115 | # ## 载入 COCO 数据集和 WikiArt 数据集 116 | # 117 | # > The batch size of content images is 8 and the meta network is trained for 20 iterations before changing the style image. 118 | 119 | # In[6]: 120 | 121 | 122 | data_transform = transforms.Compose([ 123 | transforms.RandomResizedCrop(width, scale=(256/480, 1), ratio=(1, 1)), 124 | transforms.ToTensor(), 125 | tensor_normalizer 126 | ]) 127 | 128 | style_dataset = torchvision.datasets.ImageFolder('/home/ypw/WikiArt/', transform=data_transform) 129 | content_dataset = torchvision.datasets.ImageFolder('/home/ypw/COCO/', transform=data_transform) 130 | 131 | if is_hvd: 132 | train_sampler = torch.utils.data.distributed.DistributedSampler( 133 | content_dataset, num_replicas=hvd.size(), rank=hvd.rank()) 134 | content_data_loader = torch.utils.data.DataLoader(content_dataset, batch_size=batch_size, 135 | num_workers=multiprocessing.cpu_count(),sampler=train_sampler) 136 | else: 137 | content_data_loader = torch.utils.data.DataLoader(content_dataset, batch_size=batch_size, 138 | shuffle=True, num_workers=multiprocessing.cpu_count()) 139 | 140 | if not is_hvd or hvd.rank() == 0: 141 | print(style_dataset) 142 | print('-'*20) 143 | print(content_dataset) 144 | 145 | 146 | # # 测试 infer 147 | 148 | # In[8]: 149 | 150 | 151 | metanet.eval() 152 | transform_net.eval() 153 | 154 | rands = torch.rand(4, 3, 256, 256).to(device) 155 | features = vgg16(rands); 156 | weights = metanet(mean_std(features)); 157 | transform_net.set_weights(weights) 158 | transformed_images = transform_net(torch.rand(4, 3, 256, 256).to(device)); 159 | 160 | if not is_hvd or hvd.rank() == 0: 161 | print('features:') 162 | display([x.shape for x in features]) 163 | 164 | print('weights:') 165 | display([x.shape for x in weights.values()]) 166 | 167 | print('transformed_images:') 168 | display(transformed_images.shape) 169 | 170 | 171 | # # 初始化一些变量 172 | 173 | # In[ ]: 174 | 175 | 176 | visualization_style_image = random.choice(style_dataset)[0].unsqueeze(0).to(device) 177 | visualization_content_images = torch.stack([random.choice(content_dataset)[0] for i in range(4)]).to(device) 178 | 179 | 180 | # In[ ]: 181 | 182 | 183 | if not is_hvd or hvd.rank() == 0: 184 | for f in glob('runs/*/.AppleDouble'): 185 | rmrf(f) 186 | 187 | rmrf('runs/' + model_name) 188 | writer = SummaryWriter('runs/'+model_name) 189 | else: 190 | writer = SummaryWriter('/tmp/'+model_name) 191 | 192 | visualization_style_image = random.choice(style_dataset)[0].unsqueeze(0).to(device) 193 | visualization_content_images = torch.stack([random.choice(content_dataset)[0] for i in range(4)]).to(device) 194 | 195 | writer.add_image('content_image', recover_tensor(visualization_content_images), 0) 196 | writer.add_graph(transform_net, (rands, )) 197 | 198 | del rands, features, weights, transformed_images 199 | 200 | 201 | # In[ ]: 202 | 203 | 204 
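# 注:VGG16 的参数在构造时已设为 requires_grad=False,
# 而 MyConv2D 的 weight/bias 是由 metanet 生成的普通 tensor(不在 named_parameters 中),
# 因此下面收集到的可训练参数只包含 transform_net 中 trainable=True 的卷积层以及 metanet 本身。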
| trainable_params = {} 205 | trainable_param_shapes = {} 206 | for model in [vgg16, transform_net, metanet]: 207 | for name, param in model.named_parameters(): 208 | if param.requires_grad: 209 | trainable_params[name] = param 210 | trainable_param_shapes[name] = param.shape 211 | 212 | 213 | # # 开始训练 214 | 215 | # In[ ]: 216 | 217 | 218 | optimizer = optim.Adam(trainable_params.values(), 1e-3) 219 | 220 | if is_hvd: 221 | optimizer = hvd.DistributedOptimizer(optimizer, 222 | named_parameters=trainable_params.items()) 223 | params = transform_net.state_dict() 224 | params.update(metanet.state_dict()) 225 | hvd.broadcast_parameters(params, root_rank=0) 226 | 227 | 228 | # In[ ]: 229 | 230 | 231 | n_batch = len(content_data_loader) 232 | metanet.train() 233 | transform_net.train() 234 | 235 | for epoch in range(epochs): 236 | smoother = defaultdict(Smooth) 237 | with tqdm(enumerate(content_data_loader), total=n_batch) as pbar: 238 | for batch, (content_images, _) in pbar: 239 | n_iter = epoch*n_batch + batch 240 | 241 | # 每 20 个 batch 随机挑选一张新的风格图像,计算其特征 242 | if batch % 20 == 0: 243 | style_image = random.choice(style_dataset)[0].unsqueeze(0).to(device) 244 | style_features = vgg16(style_image) 245 | style_mean_std = mean_std(style_features) 246 | 247 | # 检查纯色 248 | x = content_images.cpu().numpy() 249 | if (x.min(-1).min(-1) == x.max(-1).max(-1)).any(): 250 | continue 251 | 252 | optimizer.zero_grad() 253 | 254 | # 使用风格图像生成风格模型 255 | weights = metanet(mean_std(style_features)) 256 | transform_net.set_weights(weights, 0) 257 | 258 | # 使用风格模型预测风格迁移图像 259 | content_images = content_images.to(device) 260 | transformed_images = transform_net(content_images) 261 | 262 | # 使用 vgg16 计算特征 263 | content_features = vgg16(content_images) 264 | transformed_features = vgg16(transformed_images) 265 | transformed_mean_std = mean_std(transformed_features) 266 | 267 | # content loss 268 | content_loss = content_weight * F.mse_loss(transformed_features[2], content_features[2]) 269 | 270 | # style loss 271 | style_loss = style_weight * F.mse_loss(transformed_mean_std, 272 | style_mean_std.expand_as(transformed_mean_std)) 273 | 274 | # total variation loss 275 | y = transformed_images 276 | tv_loss = tv_weight * (torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + 277 | torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :]))) 278 | 279 | # 求和 280 | loss = content_loss + style_loss + tv_loss 281 | 282 | loss.backward() 283 | optimizer.step() 284 | 285 | smoother['content_loss'] += content_loss.item() 286 | smoother['style_loss'] += style_loss.item() 287 | smoother['tv_loss'] += tv_loss.item() 288 | smoother['loss'] += loss.item() 289 | 290 | max_value = max([x.max().item() for x in weights.values()]) 291 | 292 | writer.add_scalar('loss/loss', loss, n_iter) 293 | writer.add_scalar('loss/content_loss', content_loss, n_iter) 294 | writer.add_scalar('loss/style_loss', style_loss, n_iter) 295 | writer.add_scalar('loss/total_variation', tv_loss, n_iter) 296 | writer.add_scalar('loss/max', max_value, n_iter) 297 | 298 | s = 'Epoch: {} '.format(epoch+1) 299 | s += 'Content: {:.2f} '.format(smoother['content_loss']) 300 | s += 'Style: {:.1f} '.format(smoother['style_loss']) 301 | s += 'Loss: {:.2f} '.format(smoother['loss']) 302 | s += 'Max: {:.2f}'.format(max_value) 303 | 304 | if (batch + 1) % verbose_image_batch == 0: 305 | transform_net.eval() 306 | visualization_transformed_images = transform_net(visualization_content_images) 307 | transform_net.train() 308 | visualization_transformed_images = 
torch.cat([style_image, visualization_transformed_images]) 309 | writer.add_image('debug', recover_tensor(visualization_transformed_images), n_iter) 310 | del visualization_transformed_images 311 | 312 | if (batch + 1) % verbose_hist_batch == 0: 313 | for name, param in weights.items(): 314 | writer.add_histogram('transform_net.'+name, param.clone().cpu().data.numpy(), 315 | n_iter, bins='auto') 316 | 317 | for name, param in transform_net.named_parameters(): 318 | writer.add_histogram('transform_net.'+name, param.clone().cpu().data.numpy(), 319 | n_iter, bins='auto') 320 | 321 | for name, param in metanet.named_parameters(): 322 | l = name.split('.') 323 | l.remove(l[-1]) 324 | writer.add_histogram('metanet.'+'.'.join(l), param.clone().cpu().data.numpy(), 325 | n_iter, bins='auto') 326 | 327 | pbar.set_description(s) 328 | 329 | del transformed_images, weights 330 | 331 | if not is_hvd or hvd.rank() == 0: 332 | torch.save(metanet.state_dict(), 'checkpoints/{}_{}.pth'.format(model_name, epoch+1)) 333 | torch.save(transform_net.state_dict(), 334 | 'checkpoints/{}_transform_net_{}.pth'.format(model_name, epoch+1)) 335 | 336 | torch.save(metanet.state_dict(), 'models/{}.pth'.format(model_name)) 337 | torch.save(transform_net.state_dict(), 'models/{}_transform_net.pth'.format(model_name)) 338 | 339 | -------------------------------------------------------------------------------- /code/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torchvision 6 | import torchvision.transforms as transforms 7 | import torchvision.models as models 8 | 9 | import numpy as np 10 | from collections import defaultdict 11 | 12 | from utils import * 13 | 14 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 15 | 16 | class VGG(nn.Module): 17 | 18 | def __init__(self, features): 19 | super(VGG, self).__init__() 20 | self.features = features 21 | self.layer_name_mapping = { 22 | '3': "relu1_2", 23 | '8': "relu2_2", 24 | '15': "relu3_3", 25 | '22': "relu4_3" 26 | } 27 | for p in self.parameters(): 28 | p.requires_grad = False 29 | 30 | def forward(self, x): 31 | outs = [] 32 | for name, module in self.features._modules.items(): 33 | x = module(x) 34 | if name in self.layer_name_mapping: 35 | outs.append(x) 36 | return outs 37 | 38 | 39 | class MyConv2D(nn.Module): 40 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1): 41 | super(MyConv2D, self).__init__() 42 | self.weight = torch.zeros((out_channels, in_channels, kernel_size, kernel_size)).to(device) 43 | self.bias = torch.zeros(out_channels).to(device) 44 | 45 | self.in_channels = in_channels 46 | self.out_channels = out_channels 47 | self.kernel_size = (kernel_size, kernel_size) 48 | self.stride = (stride, stride) 49 | 50 | def forward(self, x): 51 | return F.conv2d(x, self.weight, self.bias, self.stride) 52 | 53 | def extra_repr(self): 54 | s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}' 55 | ', stride={stride}') 56 | return s.format(**self.__dict__) 57 | 58 | 59 | class ResidualBlock(nn.Module): 60 | def __init__(self, channels): 61 | super(ResidualBlock, self).__init__() 62 | self.conv = nn.Sequential( 63 | *ConvLayer(channels, channels, kernel_size=3, stride=1), 64 | *ConvLayer(channels, channels, kernel_size=3, stride=1, relu=False) 65 | ) 66 | 67 | def forward(self, x): 68 | return self.conv(x) + x 69 | 70 | 71 | def ConvLayer(in_channels, out_channels, 
kernel_size=3, stride=1, 72 | upsample=None, instance_norm=True, relu=True, trainable=False): 73 | layers = [] 74 | if upsample: 75 | layers.append(nn.Upsample(mode='nearest', scale_factor=upsample)) 76 | layers.append(nn.ReflectionPad2d(kernel_size // 2)) 77 | if trainable: 78 | layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride)) 79 | else: 80 | layers.append(MyConv2D(in_channels, out_channels, kernel_size, stride)) 81 | if instance_norm: 82 | layers.append(nn.InstanceNorm2d(out_channels)) 83 | if relu: 84 | layers.append(nn.ReLU()) 85 | return layers 86 | 87 | 88 | class TransformNet(nn.Module): 89 | def __init__(self, base=8): 90 | super(TransformNet, self).__init__() 91 | self.base = base 92 | self.weights = [] 93 | self.downsampling = nn.Sequential( 94 | *ConvLayer(3, base, kernel_size=9, trainable=True), 95 | *ConvLayer(base, base*2, kernel_size=3, stride=2), 96 | *ConvLayer(base*2, base*4, kernel_size=3, stride=2), 97 | ) 98 | self.residuals = nn.Sequential(*[ResidualBlock(base*4) for i in range(5)]) 99 | self.upsampling = nn.Sequential( 100 | *ConvLayer(base*4, base*2, kernel_size=3, upsample=2), 101 | *ConvLayer(base*2, base, kernel_size=3, upsample=2), 102 | *ConvLayer(base, 3, kernel_size=9, instance_norm=False, relu=False, trainable=True), 103 | ) 104 | self.get_param_dict() 105 | 106 | def forward(self, X): 107 | y = self.downsampling(X) 108 | y = self.residuals(y) 109 | y = self.upsampling(y) 110 | return y 111 | 112 | def get_param_dict(self): 113 | """找出该网络所有 MyConv2D 层,计算它们需要的权值数量""" 114 | param_dict = defaultdict(int) 115 | def dfs(module, name): 116 | for name2, layer in module.named_children(): 117 | dfs(layer, '%s.%s' % (name, name2) if name != '' else name2) 118 | if module.__class__ == MyConv2D: 119 | param_dict[name] += int(np.prod(module.weight.shape)) 120 | param_dict[name] += int(np.prod(module.bias.shape)) 121 | dfs(self, '') 122 | return param_dict 123 | 124 | def set_my_attr(self, name, value): 125 | # 下面这个循环是一步步遍历类似 residuals.0.conv.1 的字符串,找到相应的权值 126 | target = self 127 | for x in name.split('.'): 128 | if x.isnumeric(): 129 | target = target.__getitem__(int(x)) 130 | else: 131 | target = getattr(target, x) 132 | 133 | # 设置对应的权值 134 | n_weight = np.prod(target.weight.shape) 135 | target.weight = value[:n_weight].view(target.weight.shape) 136 | target.bias = value[n_weight:].view(target.bias.shape) 137 | 138 | def set_weights(self, weights, i=0): 139 | """输入权值字典,对该网络所有的 MyConv2D 层设置权值""" 140 | for name, param in weights.items(): 141 | self.set_my_attr(name, weights[name][i]) 142 | 143 | 144 | class MetaNet(nn.Module): 145 | def __init__(self, param_dict): 146 | super(MetaNet, self).__init__() 147 | self.param_num = len(param_dict) 148 | self.hidden = nn.Linear(1920, 128*self.param_num) 149 | self.fc_dict = {} 150 | for i, (name, params) in enumerate(param_dict.items()): 151 | self.fc_dict[name] = i 152 | setattr(self, 'fc{}'.format(i+1), nn.Linear(128, params)) 153 | 154 | def forward(self, mean_std_features): 155 | hidden = F.relu(self.hidden(mean_std_features)) 156 | filters = {} 157 | for name, i in self.fc_dict.items(): 158 | fc = getattr(self, 'fc{}'.format(i+1)) 159 | filters[name] = fc(hidden[:,i*128:(i+1)*128]) 160 | return filters 161 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import torchvision 6 | import 
torchvision.transforms as transforms 7 | import torchvision.models as models 8 | 9 | import cv2 10 | import numpy as np 11 | from PIL import Image 12 | import matplotlib.pyplot as plt 13 | 14 | cnn_normalization_mean = [0.485, 0.456, 0.406] 15 | cnn_normalization_std = [0.229, 0.224, 0.225] 16 | tensor_normalizer = transforms.Normalize(mean=cnn_normalization_mean, std=cnn_normalization_std) 17 | epsilon = 1e-5 18 | 19 | 20 | def preprocess_image(image, target_width=None): 21 | """输入 PIL.Image 对象,输出标准化后的四维 tensor""" 22 | if target_width: 23 | t = transforms.Compose([ 24 | transforms.Resize(target_width), 25 | transforms.CenterCrop(target_width), 26 | transforms.ToTensor(), 27 | tensor_normalizer, 28 | ]) 29 | else: 30 | t = transforms.Compose([ 31 | transforms.ToTensor(), 32 | tensor_normalizer, 33 | ]) 34 | return t(image).unsqueeze(0) 35 | 36 | 37 | def image_to_tensor(image, target_width=None): 38 | """输入 OpenCV 图像,范围 0~255,BGR 顺序,输出标准化后的四维 tensor""" 39 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 40 | image = Image.fromarray(image) 41 | return preprocess_image(image, target_width) 42 | 43 | 44 | def read_image(path, target_width=None): 45 | """输入图像路径,输出标准化后的四维 tensor""" 46 | image = Image.open(path) 47 | return preprocess_image(image, target_width) 48 | 49 | 50 | def recover_image(tensor): 51 | """输入 GPU 上的四维 tensor,输出 0~255 范围的三维 numpy 矩阵,RGB 顺序""" 52 | image = tensor.detach().cpu().numpy() 53 | image = image * np.array(cnn_normalization_std).reshape((1, 3, 1, 1)) + \ 54 | np.array(cnn_normalization_mean).reshape((1, 3, 1, 1)) 55 | return (image.transpose(0, 2, 3, 1) * 255.).clip(0, 255).astype(np.uint8)[0] 56 | 57 | 58 | def recover_tensor(tensor): 59 | m = torch.tensor(cnn_normalization_mean).view(1, 3, 1, 1).to(tensor.device) 60 | s = torch.tensor(cnn_normalization_std).view(1, 3, 1, 1).to(tensor.device) 61 | tensor = tensor * s + m 62 | return tensor.clamp(0, 1) 63 | 64 | 65 | def imshow(tensor, title=None): 66 | """输入 GPU 上的四维 tensor,然后绘制该图像""" 67 | image = recover_image(tensor) 68 | print(image.shape) 69 | plt.imshow(image) 70 | if title is not None: 71 | plt.title(title) 72 | 73 | 74 | def mean_std(features): 75 | """输入 VGG16 计算的四个特征,输出每张特征图的均值和标准差,长度为1920""" 76 | mean_std_features = [] 77 | for x in features: 78 | x = x.view(*x.shape[:2], -1) 79 | x = torch.cat([x.mean(-1), torch.sqrt(x.var(-1) + epsilon)], dim=-1) 80 | n = x.shape[0] 81 | x2 = x.view(n, 2, -1).transpose(2, 1).contiguous().view(n, -1) # 【mean, ..., std, ...] to [mean, std, ...] 
82 | mean_std_features.append(x2) 83 | mean_std_features = torch.cat(mean_std_features, dim=-1) 84 | return mean_std_features 85 | 86 | 87 | class Smooth: 88 | # 对输入的数据进行滑动平均 89 | def __init__(self, windowsize=100): 90 | self.window_size = windowsize 91 | self.data = np.zeros((self.window_size, 1), dtype=np.float32) 92 | self.index = 0 93 | 94 | def __iadd__(self, x): 95 | if self.index == 0: 96 | self.data[:] = x 97 | self.data[self.index % self.window_size] = x 98 | self.index += 1 99 | return self 100 | 101 | def __float__(self): 102 | return float(self.data.mean()) 103 | 104 | def __format__(self, f): 105 | return self.__float__().__format__(f) -------------------------------------------------------------------------------- /imgs/Conv2d_MyConv2D.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/Conv2d_MyConv2D.png -------------------------------------------------------------------------------- /imgs/S1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/S1.jpg -------------------------------------------------------------------------------- /imgs/S2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/S2.jpg -------------------------------------------------------------------------------- /imgs/S2_speed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/S2_speed.png -------------------------------------------------------------------------------- /imgs/S3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/S3.jpg -------------------------------------------------------------------------------- /imgs/S3_speed.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/S3_speed.png -------------------------------------------------------------------------------- /imgs/VGG16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/VGG16.png -------------------------------------------------------------------------------- /imgs/metanet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/metanet.png -------------------------------------------------------------------------------- /imgs/monitor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/monitor.png -------------------------------------------------------------------------------- /imgs/transform_net.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/transform_net.png -------------------------------------------------------------------------------- /imgs/transform_net2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/transform_net2.png -------------------------------------------------------------------------------- /imgs/transform_net3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/transform_net3.png -------------------------------------------------------------------------------- /imgs/weights_diverge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/imgs/weights_diverge.png -------------------------------------------------------------------------------- /keras_version/3_situation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | # References 5 | - (http://arxiv.org/abs/1709.04111) 6 | ''' 7 | 8 | from PIL import Image 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from keras.preprocessing.image import load_img, img_to_array, DirectoryIterator, ImageDataGenerator 12 | from keras.applications import vgg19 13 | from keras import backend as K 14 | from keras.models import Sequential, Model 15 | from keras.layers import * 16 | from keras.optimizers import * 17 | from keras.callbacks import Callback, ModelCheckpoint, ReduceLROnPlateau 18 | from util import * 19 | import matplotlib 20 | from scipy.misc import imsave 21 | from joblib import dump, load 22 | # matplotlib.use('agg') 23 | 24 | class Adagrad2(Adagrad): 25 | def __init__(self, norm_val=[], clipvalue2 = 10.0, **kwargs): 26 | super(Adagrad2, self).__init__(**kwargs) 27 | self.norm_val = norm_val 28 | self.clipvalue2 = clipvalue2 29 | 30 | def get_gradients(self, loss, params): 31 | grads = K.gradients(loss, params) 32 | # grads = [grad for x,grad in zip(params,grads)] 33 | # print(self.norm_val) 34 | # print(params) 35 | if hasattr(self, 'clipnorm') and self.clipnorm > 0: 36 | norm = K.sqrt(sum([K.sum(K.square(g)) for g in grads])) 37 | grads = [clip_norm(g, self.clipnorm, norm) if x in self.norm_val else g for x,g in zip(params,grads)] 38 | if hasattr(self, 'clipvalue') and self.clipvalue > 0: 39 | grads = [K.clip(g, -self.clipvalue, self.clipvalue) if x in self.norm_val else K.clip(g, -self.clipvalue2, self.clipvalue2) for x,g in zip(params,grads)] 40 | return grads 41 | 42 | content_image_path = 'train_content/1/COCO_train2014_000000000034.jpg' 43 | style_image_path = 'train_style/1/122.jpg' 44 | 45 | # dimensions of the generated picture. 46 | img_nrows = 400 47 | img_ncols = 400 48 | scale = 1/1. 
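# 注:本脚本为 Python 2 语法(print 为语句形式);scale 在 preprocess_image 中乘上、在 deprocess_image 中除回。
# 下面三个权重分别对应风格损失、内容损失和全变差(total variation)损失。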
49 | 50 | style_weight = 0.75e-2 51 | content_weight = 4e-4 52 | tv_weight = 0.6e-8 # total variance 53 | print style_weight,content_weight,tv_weight,scale 54 | 55 | def preprocess_image_vgg19(image_path): 56 | img = load_img(image_path, target_size=(img_nrows, img_ncols)) 57 | img = img_to_array(img) 58 | img = np.expand_dims(img, axis=0) 59 | img = img[:, :, :, ::-1] * scale 60 | img = np.squeeze(img, axis=0) 61 | return img 62 | 63 | def preprocess_image(img): 64 | img = img[ :, :, ::-1] * scale 65 | return img*1.0 66 | 67 | def deprocess_image(x): 68 | if K.image_data_format() == 'channels_first': 69 | x = x.reshape((3, img_nrows, img_ncols)) 70 | x = x.transpose((1, 2, 0)) 71 | else: 72 | x = x.reshape((img_nrows, img_ncols, 3)) 73 | x = x[:, :, ::-1] / scale 74 | x = np.clip(x, 0, 255).astype('uint8') 75 | return x 76 | 77 | def imshow(image, title=None): 78 | image = np.array(image).astype('uint8') 79 | plt.imshow(image) 80 | if title is not None: 81 | plt.title(title) 82 | 83 | channels = 3 84 | 85 | 86 | 87 | def get_Net(input_shape = (img_nrows,img_ncols,channels),style_weight = style_weight,content_weight = content_weight,tv_weight =tv_weight): 88 | import collections 89 | residual_size = 48 90 | 91 | class Parm(object): 92 | params_pos = 0 93 | params_pos2 = 0 94 | params_dict = [ 95 | # ('cin', ((5, 5, 3, 16), 16)), 96 | # ('c1', ((3, 3, 16, 16), 16)), 97 | ('c2', ((3, 3, 32, residual_size), residual_size)), 98 | ('rc1-1', ((3, 3, residual_size, residual_size), residual_size)), 99 | ('rc1-2', ((3, 3, residual_size, residual_size), residual_size)), 100 | ('rc2-1', ((3, 3, residual_size, residual_size), residual_size)), 101 | ('rc2-2', ((3, 3, residual_size, residual_size), residual_size)), 102 | ('rc3-1', ((3, 3, residual_size, residual_size), residual_size)), 103 | ('rc3-2', ((3, 3, residual_size, residual_size), residual_size)), 104 | ('rc4-1', ((3, 3, residual_size, residual_size), residual_size)), 105 | ('rc4-2', ((3, 3, residual_size, residual_size), residual_size)), 106 | ('rc5-1', ((3, 3, residual_size, residual_size), residual_size)), 107 | ('rc5-2', ((3, 3, residual_size, residual_size), residual_size)), 108 | ('rc6-1', ((3, 3, residual_size, residual_size), residual_size)), 109 | ('rc6-2', ((3, 3, residual_size, residual_size), residual_size)), 110 | ('c3', ((3, 3, residual_size, 48), 48))] 111 | # ('c4', ((3, 3, 48, 16), 16))] 112 | # ('cout', ((7, 7, 16, 3), 3))] 113 | parms = Parm() 114 | parms2 = Parm() 115 | 116 | class ConvLayer(object): 117 | def __init__(self, filters, kernel_size=3, stride=1, 118 | upsample=None, instance_norm=False, activation='relu', trainable = False, parms = None): 119 | super(ConvLayer, self).__init__() 120 | self.upsample = upsample 121 | self.filters = filters 122 | self.kernel_size = kernel_size 123 | self.stride = stride 124 | self.activation = activation 125 | self.instance_norm = instance_norm 126 | self.trainable = trainable 127 | self.parms = parms 128 | 129 | def __call__(self, inputs): 130 | if self.trainable: 131 | x = inputs 132 | else: 133 | x = inputs[0] 134 | name = None 135 | if self.parms: 136 | params_pos = self.parms.params_pos 137 | params_pos2 = self.parms.params_pos2 138 | params_dict = self.parms.params_dict 139 | weight_start = params_pos2 140 | weight_shape = params_dict[params_pos][1] 141 | name = params_dict[params_pos][0] 142 | self.parms.params_pos += 1 143 | self.parms.params_pos2 += np.prod(weight_shape[0]) + weight_shape[1] 144 | if self.upsample: 145 | x = UpSampling2D(size=(self.upsample, 
self.upsample))(x) 146 | x = ReflectPadding2D(self.kernel_size//2)(x) 147 | if self.activation == 'prelu': 148 | if self.trainable: 149 | x = Conv2D(self.filters, self.kernel_size, strides=(self.stride, self.stride),name = name)(x) 150 | else: 151 | x = MyConv2D(self.filters, self.kernel_size, strides=(self.stride, self.stride),shape = (weight_start,weight_shape),name = name)([x,inputs[1]]) 152 | x = PReLU()(x) 153 | else: 154 | if self.trainable: 155 | x = Conv2D(self.filters, self.kernel_size, strides=(self.stride, self.stride), 156 | activation=self.activation,name = name)(x) 157 | else: 158 | x = MyConv2D(self.filters, self.kernel_size, strides=(self.stride, self.stride), 159 | activation=self.activation,shape = (weight_start,weight_shape),name = name)([x,inputs[1]]) 160 | if self.instance_norm: 161 | x = InstanceNormalization2D2()(x) 162 | return x 163 | 164 | class ResidualBlock(object): 165 | def __call__(self, inputs): 166 | if self.trainable: 167 | x = inputs 168 | out = self.conv1(x) 169 | out = self.conv2(out) 170 | else: 171 | x = inputs[0] 172 | out = self.conv1([x, inputs[1]]) 173 | out = self.conv2([out, inputs[1]]) 174 | out = add([out, x]) 175 | return out 176 | 177 | def __init__(self, filters, parms, trainable = False): 178 | super(ResidualBlock, self).__init__() 179 | self.conv1 = ConvLayer(filters, kernel_size=3, stride=1, parms = parms, trainable = trainable) 180 | self.conv2 = ConvLayer(filters, kernel_size=3, stride=1, parms = parms, trainable = trainable, activation=None) 181 | self.parms = parms 182 | self.trainable = trainable 183 | 184 | def vgg_pre(x): 185 | mul = np.array([1.0]*3).reshape(1,1,1,3) 186 | meanoffset = (np.array([-103.939,-116.779,-123.68]) * scale).reshape(1,1,1,3) 187 | assert K.ndim(x) == 4 188 | return (x + meanoffset)/mul 189 | 190 | def gram_matrix(features,shape): 191 | batch, ch, h, w = shape 192 | if K.image_data_format() != 'channels_first': 193 | features = K.permute_dimensions(features, (0, 3, 1, 2)) 194 | batch, h, w,ch = shape 195 | features = K.reshape(features,(-1, ch, h * w)) 196 | gram = K.batch_dot(features, K.permute_dimensions(features, (0, 2, 1))) / ((int(ch) * int(h) * int(w) * 10)) 197 | return gram 198 | 199 | input = Input(shape=(input_shape), dtype="float32") # content 200 | input2 = Input(shape=(input_shape), dtype="float32") # style 201 | 202 | # vgg 203 | vgg = vgg19.VGG19(weights='imagenet', include_top=False) 204 | vgg.trainable = False 205 | outputs_dict = dict([(layer.name, layer.output) for layer in vgg.layers]) 206 | content_layers = ['block5_conv2', 207 | 'block4_conv3','block1_conv2'] 208 | content_weights = [10.0, 1.0, 0.8, 1.0, 1.0] 209 | layer_features = map(lambda x: outputs_dict[x], content_layers) 210 | content_feature_model = Model(vgg.input,layer_features) 211 | content_feature_model.trainable = False 212 | 213 | style_layers = ['block3_conv2', 214 | 'block4_conv3','block1_conv2', 215 | 'block2_conv2','block5_conv2'] 216 | style_weights = [1.0, 0.5, 1.0, 1.0, 10.0] 217 | layer_features = map(lambda x: outputs_dict[x], style_layers) 218 | style_feature_model = Model(vgg.input, layer_features) 219 | style_feature_model.trainable = False 220 | 221 | 222 | # meta 223 | meta_input1 = concatenate([Mean_std()(l) for l in style_feature_model(Lambda(lambda x: vgg_pre(x))(input2))]) 224 | x = Lambda(lambda x: 1/15 * x)(input2) 225 | x = ConvLayer(64, kernel_size=9, stride=1, trainable=True)(x) 226 | x = ConvLayer(64, kernel_size=3, stride=2, trainable=True)(x) 227 | x = ConvLayer(256, kernel_size=3, stride=2, 
trainable=True)(x) 228 | meta_input2 = Mean_std()(x) 229 | meta_output_list = [] 230 | hidden = concatenate([Dense(128, activation='relu')(meta_input1), Dense(16, activation='sigmoid')(meta_input2), 231 | Dense(128, activation='relu')(meta_input2)]) 232 | hidden = Dropout(0.05)(hidden) 233 | for name,shape in parms.params_dict: 234 | hidden2 = concatenate([Dense(64,activation = 'relu')(meta_input1),Dense(16,activation = 'sigmoid')(meta_input2),Dense(6,activation = 'relu')(meta_input2)]) 235 | hidden2 = concatenate([Dropout(0.05)(hidden2), hidden]) 236 | kernel = Dense(np.prod(shape[0]),activation = 'linear')(hidden2) 237 | bias = Dense(shape[1],activation = 'linear')(hidden2) 238 | meta_output_list.append(Lambda(lambda x:0.0001*x)(kernel)) 239 | meta_output_list.append(Lambda(lambda x:0.0001*x)(bias)) 240 | meta_output = (concatenate(meta_output_list)) 241 | meta_model = Model(input2,meta_output) 242 | 243 | # transform 244 | in_encode = ConvLayer(16, kernel_size=9, stride=1, trainable=True) 245 | x = Lambda(lambda x: 1/15. * x)(input) 246 | x = in_encode(x) 247 | x = ConvLayer(32, kernel_size=3, stride=2, trainable=True)(x) 248 | x = ConvLayer(residual_size, kernel_size=3, stride=2, parms = parms)([x,meta_output]) 249 | for i in range(6): 250 | x = ResidualBlock(residual_size, parms = parms)([x,meta_output]) 251 | x = ConvLayer(48, kernel_size=3, stride=1, upsample=2, parms = parms)([x,meta_output]) 252 | x = ConvLayer(16, kernel_size=3, stride=1,upsample=2, trainable=True)(x) 253 | out_decode = ConvLayer(3, kernel_size=9, stride=1, activation=None, trainable=True) 254 | x = out_decode(x) 255 | output = x 256 | output = Lambda(lambda x: x * 15.)(output) 257 | g_model = Model(inputs=[input,input2], outputs=output) 258 | 259 | 260 | def total_variation_loss(x): 261 | assert K.ndim(x) == 4 262 | if K.image_data_format() == 'channels_first': 263 | a = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, 1:, :img_ncols - 1]) 264 | b = K.square(x[:, :, :img_nrows - 1, :img_ncols - 1] - x[:, :, :img_nrows - 1, 1:]) 265 | else: 266 | a = K.square(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, 1:, :img_ncols - 1, :]) 267 | b = K.square(x[:, :img_nrows - 1, :img_ncols - 1, :] - x[:, :img_nrows - 1, 1:, :]) 268 | c = K.square(K.relu(0. 
- x[:, :, :, :])) 269 | d = K.square(K.relu(x[:, :, :, :] - 255.)) 270 | 271 | return K.sum(K.batch_flatten(K.pow(a + b, 1.25)),axis = -1,keepdims = True) + K.sum(K.batch_flatten(c + d),axis = -1,keepdims = True) * 1 272 | 273 | base_image_content_features = [layer for layer in (content_feature_model(Lambda(lambda x: vgg_pre(x))(input)))] 274 | combination_content_features = [layer for layer in (content_feature_model(Lambda(lambda x: vgg_pre(x))(output)))] 275 | content_loss_list = [] 276 | for i in range(len(content_layers)): 277 | content_loss_lambda = Lambda( 278 | lambda x: content_weights[i] * K.mean(K.batch_flatten(K.square(x[0] - x[1])), axis=-1, 279 | keepdims=True)) 280 | content_loss_list.append(content_loss_lambda( 281 | [base_image_content_features[i], combination_content_features[i]])) 282 | content_loss = add(content_loss_list) 283 | 284 | 285 | base_image_style_features = [InstanceNormalization2D()(layer) for layer in (style_feature_model(Lambda(lambda x:vgg_pre(x))(input2)))] 286 | combination_style_features = [InstanceNormalization2D()(layer) for layer in (style_feature_model(Lambda(lambda x:vgg_pre(x))(output)))] 287 | style_loss_list = [] 288 | for i in range(len(style_layers)): 289 | style_loss_lambda = Lambda(lambda x: style_weights[i] * K.mean( 290 | K.batch_flatten(K.square(Mean_std()(x[0]) - Mean_std()(x[1]))), axis=-1, keepdims=True)) 291 | style_loss_list.append(style_loss_lambda([base_image_style_features[i], combination_style_features[i]])) 292 | style_loss = add(style_loss_list) 293 | 294 | tv_loss_lambda = Lambda(lambda x:total_variation_loss(x)) 295 | tv_loss = tv_loss_lambda(output) 296 | 297 | total_loss = Lambda(lambda x: x[0] * style_weight + x[1] * content_weight + x[2] * tv_weight, name='totalloss')([style_loss,content_loss,tv_loss]) 298 | loss_model = Model(inputs=[input,input2], outputs=[content_loss,style_loss,tv_loss]) 299 | loss_model_debug = Model(inputs=[input, input2], 300 | outputs=style_loss_list + content_loss_list) 301 | train_model = Model(inputs=[input,input2], outputs=total_loss) 302 | 303 | return g_model,loss_model,train_model,meta_model,loss_model_debug 304 | 305 | 306 | 307 | batch_size = 4 308 | train_datagen = ImageDataGenerator( 309 | rescale=1.0) 310 | 311 | content_generator = train_datagen.flow_from_directory( 312 | 'train_content', 313 | target_size=(img_nrows, img_ncols), 314 | batch_size=batch_size * 30 * 4, 315 | class_mode='binary') 316 | 317 | style_generator = train_datagen.flow_from_directory( 318 | 'train_style', 319 | target_size=(img_nrows, img_ncols), 320 | batch_size=batch_size, 321 | class_mode='binary') 322 | 323 | 324 | 325 | g_model,loss_model,train_model,meta_model,loss_model_debug = get_Net(style_weight = style_weight,content_weight = content_weight,tv_weight =tv_weight) 326 | g_model.summary() 327 | 328 | 329 | 330 | class LossEvaluation(Callback): 331 | def __init__(self, validation_img = None, interval=1, loss_model = None): 332 | super(Callback, self).__init__() 333 | self.interval = interval 334 | self.loss_model = loss_model 335 | self.validation_img = validation_img 336 | 337 | def on_epoch_end(self, epoch, logs={}): 338 | if epoch % self.interval == 0: 339 | y_pred = self.loss_model.predict(self.validation_img, verbose=0, batch_size=2 ** 9) 340 | content_loss, style_loss, tv_loss = np.mean(y_pred, axis=1) 341 | print("epoch: %d - content_loss: %.3f, style_loss: %.3f, tv_loss: %.3f, total_loss: %.3f " % 342 | (epoch + 1, content_loss * content_weight, style_loss * style_weight, tv_loss * tv_weight, 343 
| content_loss * content_weight+ style_loss * style_weight + tv_loss * tv_weight)) 344 | 345 | 346 | 347 | print('train') 348 | count = 0 349 | end = 1200 350 | lr = 0.0005 351 | val = preprocess_image_vgg19('./c.jpg') 352 | style_val2 = preprocess_image_vgg19('./picasso.jpg') 353 | style_val = preprocess_image_vgg19('./a.jpg') 354 | weights_norm = meta_model.trainable_weights 355 | opt = Adam(lr = 0.00004, clipnorm=2, clipvalue=5e0) 356 | train_model.compile(loss=lambda y_true, y_pred: y_pred,optimizer=opt) 357 | data_x = [] 358 | data_x2 = [] 359 | style_img = preprocess_image(style_generator.next()[0][1]) 360 | beta = 30 * 4 361 | ep = 1 362 | for content_images,_ in content_generator: 363 | count += 1 364 | if count % 50 == 0 and count < 60: 365 | print('-----') 366 | opt = Adagrad2(lr = 0.0001, clipnorm=2, clipvalue=5e0,norm_val = weights_norm, clipvalue2=2e0) 367 | train_model.compile(loss=lambda y_true, y_pred: y_pred, optimizer=opt) 368 | beta = 15 * 4 369 | ep = 1 370 | lossEvaluation = LossEvaluation([np.array([val]), np.array([style_val]).repeat(1, axis=0)], 2, loss_model) 371 | lossEvaluation2 = LossEvaluation([np.array([val]), np.array([style_val2]).repeat(1, axis=0)], 2, loss_model) 372 | 373 | for i in range(len(content_images)): 374 | content_images[i] = preprocess_image(content_images[i]) 375 | data_x.append(content_images[i]) 376 | data_x2.append(style_img) 377 | if i % (batch_size*beta)== 0: 378 | style_img = preprocess_image(style_generator.next()[0][1]) 379 | pass 380 | 381 | train_model.fit([np.array(data_x), np.array(data_x2)], np.array([0] * np.array(data_x).shape[0]), 382 | batch_size=batch_size/2, epochs=ep, 383 | shuffle=True, verbose=0, callbacks=[lossEvaluation,lossEvaluation2]) 384 | data_x = [] 385 | data_x2 = [] 386 | 387 | if count % 3 == 0: 388 | print count 389 | output_image = g_model.predict([np.array([val]), np.array([style_val])]) 390 | imsave('debug/output48_%s.jpg'%(count%100), deprocess_image(output_image[0].copy())) 391 | output_image = g_model.predict([np.array([val]), np.array([style_val2])]) 392 | imsave('debug/output482_%s.jpg' %(count%100), deprocess_image(output_image[0].copy())) 393 | if count == end: 394 | break 395 | 396 | train_model.save_weights('./s7.h5',overwrite=True) 397 | g_model.save_weights('./g7.h5',overwrite=True) 398 | 399 | -------------------------------------------------------------------------------- /keras_version/README.md: -------------------------------------------------------------------------------- 1 | # 风格迁移keras版本 2 | 3 | 4 | 和前面的torch版思路一样,我这里为了优化速度以及模型,做了一些细节改动。可以对比参考。 5 | keras版本的模型权重大约在200MB左右,在单1080TI卡训练时间12-24h。 6 | 7 | 已经训练好的权重文件:链接: https://pan.baidu.com/s/15B77PssGS4hBRsj7i3NRKw 密码: kmh8 8 | 9 | 最终效果如图所示: 10 | 11 | ![](res/all.jpg) 12 | 13 | ## Metanet 14 | 15 | 除了vgg外自己做了一个简易的cnn捕捉特征。另外将style的输出提取了一部分公共hidden,增强鲁棒性的同时还减小了模型大小。 16 | 输出时使用了0.0001的因子,为了控制kernel和bias参数不至于震荡。 17 | 18 | 19 | ```py 20 | meta_input1 = concatenate([Mean_std()(l) for l in style_feature_model(Lambda(lambda x: vgg_pre(x))(input2))]) 21 | x = Lambda(lambda x: 1/15 * x)(input2) 22 | x = ConvLayer(64, kernel_size=9, stride=1, trainable=True)(x) 23 | x = ConvLayer(64, kernel_size=3, stride=2, trainable=True)(x) 24 | x = ConvLayer(256, kernel_size=3, stride=2, trainable=True)(x) 25 | meta_input2 = Mean_std()(x) 26 | meta_output_list = [] 27 | hidden = concatenate([Dense(128, activation='relu')(meta_input1), Dense(16, activation='sigmoid')(meta_input2), 28 | Dense(128, activation='relu')(meta_input2)]) 29 | hidden = Dropout(0.05)(hidden) 
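# 对 params_dict 中的每一个卷积层:把该层各自的 hidden2 与上面的公共 hidden 拼接,
# 再用两个 Dense 分别预测该层的 kernel 和 bias,并乘以 0.0001 抑制参数震荡(见上文说明)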
30 | for name,shape in parms.params_dict: 31 | hidden2 = concatenate([Dense(64,activation = 'relu')(meta_input1),Dense(32,activation = 'sigmoid')(meta_input2),Dense(16,activation = 'relu')(meta_input2)]) 32 | hidden2 = concatenate([Dropout(0.1)(hidden2), hidden]) 33 | kernel = Dense(np.prod(shape[0]),activation = 'linear')(hidden2) 34 | bias = Dense(shape[1],activation = 'linear')(hidden2) 35 | meta_output_list.append(Lambda(lambda x:0.0001*x)(kernel)) 36 | meta_output_list.append(Lambda(lambda x:0.0001*x)(bias)) 37 | meta_output = (concatenate(meta_output_list)) 38 | meta_model = Model(input2,meta_output) 39 | ``` 40 | 41 | ## Transformnet 42 | 43 | 因为keras的vgg19并没有将input归一化除255,所以我这里权衡了归一及不归一的方案,折衷使用了15的scaler。 44 | 在upsample和downsample过程并没有大量使用metanet控制conv权值,因为这部分更多是进行sample,个人认为风格化的主要过程发生在residual部分。 45 | 46 | 47 | ## 端到端的训练 48 | 49 | 端到端的网络将Conv2D中的权重改为了metanet的输入,这里使用了我自定义的Conv: 50 | 51 | ```py 52 | def call(self, inputs): 53 | weight_start,weight_shape = self.shape 54 | kernel = K.reshape(inputs[1][:, weight_start:weight_start + np.prod(weight_shape[0])], (-1,) + weight_shape[0]) 55 | bias = K.reshape(inputs[1][:, 56 | weight_start + np.prod(weight_shape[0]):weight_start + np.prod(weight_shape[0]) + weight_shape[ 57 | 1]] 58 | , (-1,weight_shape[1])) 59 | kernel = kernel[0,:,:,:,:] 60 | bias = bias[0,:] 61 | outputs = K.conv2d( 62 | inputs[0], 63 | kernel, 64 | strides=self.strides, 65 | padding=self.padding, 66 | data_format=self.data_format, 67 | dilation_rate=self.dilation_rate) 68 | 69 | if self.use_bias: 70 | outputs = K.bias_add( 71 | outputs, 72 | bias, 73 | data_format=self.data_format) 74 | 75 | if self.activation is not None: 76 | return self.activation(outputs) 77 | return outputs 78 | ``` 79 | 80 | 放弃了论文中的单风格图片训练以及切换风格图片的方式,而是使用端到端的方式。在训练初期保留单风格的方式训练,而到训练了30种风格以后切换成4-8张风格图片shuffle+batch输入,加快收敛速度。 81 | 82 | ## 梯度控制optimizers 83 | 84 | 升级版的adagrad,可以对某些特定layer(metanet)进行梯度控制,否则无法收敛。 85 | 86 | ```py 87 | class Adagrad2(Adagrad): 88 | def __init__(self, norm_val=[], clipvalue2 = 10.0, **kwargs): 89 | super(Adagrad2, self).__init__(**kwargs) 90 | self.norm_val = norm_val 91 | self.clipvalue2 = clipvalue2 92 | 93 | def get_gradients(self, loss, params): 94 | grads = K.gradients(loss, params) 95 | if hasattr(self, 'clipnorm') and self.clipnorm > 0: 96 | norm = K.sqrt(sum([K.sum(K.square(g)) for g in grads])) 97 | grads = [clip_norm(g, self.clipnorm, norm) if x in self.norm_val else g for x,g in zip(params,grads)] 98 | if hasattr(self, 'clipvalue') and self.clipvalue > 0: 99 | grads = [K.clip(g, -self.clipvalue, self.clipvalue) if x in self.norm_val else K.clip(g, -self.clipvalue2, self.clipvalue2) for x,g in zip(params,grads)] 100 | return grads 101 | ``` 102 | 103 | 104 | 105 | 106 | -------------------------------------------------------------------------------- /keras_version/a.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/a.jpg -------------------------------------------------------------------------------- /keras_version/c.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/c.jpg -------------------------------------------------------------------------------- /keras_version/demo.sh: 
-------------------------------------------------------------------------------- 1 | python demo3.py res/c2.jpg res/style1.jpg res/c2_1.jpg 2 | python demo3.py res/c2.jpg res/style2.jpg res/c2_2.jpg 3 | python demo3.py res/c2.jpg res/style3.jpg res/c2_3.jpg 4 | python demo3.py res/c2.jpg res/style5.jpg res/c2_5.jpg 5 | python demo3.py res/c3.jpg res/style1.jpg res/c3_1.jpg 6 | python demo3.py res/c3.jpg res/style2.jpg res/c3_2.jpg 7 | python demo3.py res/c3.jpg res/style3.jpg res/c3_3.jpg 8 | python demo3.py res/c3.jpg res/style5.jpg res/c3_5.jpg 9 | python demo3.py res/c4.jpg res/style1.jpg res/c4_1.jpg 10 | python demo3.py res/c4.jpg res/style2.jpg res/c4_2.jpg 11 | python demo3.py res/c4.jpg res/style3.jpg res/c4_3.jpg 12 | python demo3.py res/c4.jpg res/style5.jpg res/c4_5.jpg -------------------------------------------------------------------------------- /keras_version/demo3.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | # References 5 | - (http://arxiv.org/abs/1709.04111) 6 | ''' 7 | 8 | from PIL import Image 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from keras.preprocessing.image import load_img, img_to_array, DirectoryIterator, ImageDataGenerator 12 | from keras.applications import vgg19 13 | from keras import backend as K 14 | from keras.models import Sequential, Model 15 | from keras.layers import * 16 | from keras.optimizers import * 17 | from keras.callbacks import Callback, ModelCheckpoint, ReduceLROnPlateau 18 | from util import * 19 | import matplotlib 20 | from scipy.misc import imsave 21 | from joblib import dump, load 22 | import sys 23 | 24 | 25 | # dimensions of the generated picture. 26 | img_nrows = 400 27 | img_ncols = 400 28 | scale = 1/1. 
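# 注:本脚本与 3_situation.py 共用同一套预处理和网络结构,但只构建生成器 g_model,
# 加载训练好的权重后按 "python demo3.py 内容图 风格图 输出图"(见 demo.sh)做一次前向推理。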
29 | 30 | style_weight = 0.75e-2 31 | content_weight = 3e-4 32 | tv_weight = 0.8e-8 # total variance 33 | print style_weight,content_weight,tv_weight,scale 34 | 35 | def preprocess_image_vgg19(image_path): 36 | img = load_img(image_path, target_size=(img_nrows, img_ncols)) 37 | img = img_to_array(img) 38 | img = np.expand_dims(img, axis=0) 39 | img = img[:, :, :, ::-1] * scale 40 | img = np.squeeze(img, axis=0) 41 | return img 42 | 43 | def preprocess_image(img): 44 | img = img[ :, :, ::-1] * scale 45 | return img*1.0 46 | 47 | def deprocess_image(x): 48 | if K.image_data_format() == 'channels_first': 49 | x = x.reshape((3, img_nrows, img_ncols)) 50 | x = x.transpose((1, 2, 0)) 51 | else: 52 | x = x.reshape((img_nrows, img_ncols, 3)) 53 | x = x[:, :, ::-1] / scale 54 | x = np.clip(x, 0, 255).astype('uint8') 55 | return x 56 | 57 | def imshow(image, title=None): 58 | image = np.array(image).astype('uint8') 59 | plt.imshow(image) 60 | if title is not None: 61 | plt.title(title) 62 | 63 | channels = 3 64 | 65 | 66 | 67 | def get_Net(input_shape = (img_nrows,img_ncols,channels),style_weight = style_weight,content_weight = content_weight,tv_weight =tv_weight): 68 | import collections 69 | residual_size = 48 70 | 71 | class Parm(object): 72 | params_pos = 0 73 | params_pos2 = 0 74 | params_dict = [ 75 | ('c2', ((3, 3, 32, residual_size), residual_size)), 76 | ('rc1-1', ((3, 3, residual_size, residual_size), residual_size)), 77 | ('rc1-2', ((3, 3, residual_size, residual_size), residual_size)), 78 | ('rc2-1', ((3, 3, residual_size, residual_size), residual_size)), 79 | ('rc2-2', ((3, 3, residual_size, residual_size), residual_size)), 80 | ('rc3-1', ((3, 3, residual_size, residual_size), residual_size)), 81 | ('rc3-2', ((3, 3, residual_size, residual_size), residual_size)), 82 | ('rc4-1', ((3, 3, residual_size, residual_size), residual_size)), 83 | ('rc4-2', ((3, 3, residual_size, residual_size), residual_size)), 84 | ('rc5-1', ((3, 3, residual_size, residual_size), residual_size)), 85 | ('rc5-2', ((3, 3, residual_size, residual_size), residual_size)), 86 | ('rc6-1', ((3, 3, residual_size, residual_size), residual_size)), 87 | ('rc6-2', ((3, 3, residual_size, residual_size), residual_size)), 88 | ('c3', ((3, 3, residual_size, 48), 48))] 89 | parms = Parm() 90 | 91 | class ConvLayer(object): 92 | def __init__(self, filters, kernel_size=3, stride=1, 93 | upsample=None, instance_norm=False, activation='relu', trainable = False, parms = None): 94 | super(ConvLayer, self).__init__() 95 | self.upsample = upsample 96 | self.filters = filters 97 | self.kernel_size = kernel_size 98 | self.stride = stride 99 | self.activation = activation 100 | self.instance_norm = instance_norm 101 | self.trainable = trainable 102 | self.parms = parms 103 | 104 | def __call__(self, inputs): 105 | if self.trainable: 106 | x = inputs 107 | else: 108 | x = inputs[0] 109 | name = None 110 | if self.parms: 111 | params_pos = self.parms.params_pos 112 | params_pos2 = self.parms.params_pos2 113 | params_dict = self.parms.params_dict 114 | weight_start = params_pos2 115 | weight_shape = params_dict[params_pos][1] 116 | name = params_dict[params_pos][0] 117 | self.parms.params_pos += 1 118 | self.parms.params_pos2 += np.prod(weight_shape[0]) + weight_shape[1] 119 | if self.upsample: 120 | x = UpSampling2D(size=(self.upsample, self.upsample))(x) 121 | x = ReflectPadding2D(self.kernel_size//2)(x) 122 | if self.activation == 'prelu': 123 | if self.trainable: 124 | x = Conv2D(self.filters, self.kernel_size, strides=(self.stride, 
self.stride),name = name)(x) 125 | else: 126 | x = MyConv2D(self.filters, self.kernel_size, strides=(self.stride, self.stride),shape = (weight_start,weight_shape),name = name)([x,inputs[1]]) 127 | x = PReLU()(x) 128 | else: 129 | if self.trainable: 130 | x = Conv2D(self.filters, self.kernel_size, strides=(self.stride, self.stride), 131 | activation=self.activation,name = name)(x) 132 | else: 133 | x = MyConv2D(self.filters, self.kernel_size, strides=(self.stride, self.stride), 134 | activation=self.activation,shape = (weight_start,weight_shape),name = name)([x,inputs[1]]) 135 | if self.instance_norm: 136 | x = InstanceNormalization2D()(x) 137 | return x 138 | 139 | class ResidualBlock(object): 140 | def __call__(self, inputs): 141 | if self.trainable: 142 | x = inputs 143 | out = self.conv1(x) 144 | out = self.conv2(out) 145 | else: 146 | x = inputs[0] 147 | out = self.conv1([x, inputs[1]]) 148 | out = self.conv2([out, inputs[1]]) 149 | out = add([out, x]) 150 | return out 151 | 152 | def __init__(self, filters, parms, trainable = False): 153 | super(ResidualBlock, self).__init__() 154 | self.conv1 = ConvLayer(filters, kernel_size=3, stride=1, parms = parms, trainable = trainable) 155 | self.conv2 = ConvLayer(filters, kernel_size=3, stride=1, parms = parms, trainable = trainable, activation=None) 156 | self.parms = parms 157 | self.trainable = trainable 158 | 159 | def vgg_pre(x): 160 | mul = np.array([1.0]*3).reshape(1,1,1,3) 161 | meanoffset = (np.array([-103.939,-116.779,-123.68]) * scale).reshape(1,1,1,3) 162 | assert K.ndim(x) == 4 163 | return (x + meanoffset)/mul 164 | 165 | def gram_matrix(features,shape): 166 | batch, ch, h, w = shape 167 | if K.image_data_format() != 'channels_first': 168 | features = K.permute_dimensions(features, (0, 3, 1, 2)) 169 | batch, h, w,ch = shape 170 | features = K.reshape(features,(-1, ch, h * w)) 171 | gram = K.batch_dot(features, K.permute_dimensions(features, (0, 2, 1))) / ((int(ch) * int(h) * int(w) * 10)) 172 | return gram 173 | 174 | input = Input(shape=(input_shape), dtype="float32") # content 175 | input2 = Input(shape=(input_shape), dtype="float32") # style 176 | 177 | # vgg 178 | vgg = vgg19.VGG19(weights='imagenet', include_top=False) 179 | vgg.trainable = False 180 | outputs_dict = dict([(layer.name, layer.output) for layer in vgg.layers]) 181 | content_layers = ['block5_conv2', 182 | 'block4_conv3','block1_conv2'] 183 | layer_features = map(lambda x: outputs_dict[x], content_layers) 184 | content_feature_model = Model(vgg.input,layer_features) 185 | content_feature_model.trainable = False 186 | 187 | style_layers = ['block3_conv2', 188 | 'block4_conv3','block1_conv2', 189 | 'block2_conv2','block5_conv2'] 190 | layer_features = map(lambda x: outputs_dict[x], style_layers) 191 | style_feature_model = Model(vgg.input, layer_features) 192 | style_feature_model.trainable = False 193 | 194 | 195 | # meta 196 | meta_input1 = concatenate([Mean_std()(l) for l in style_feature_model(Lambda(lambda x: vgg_pre(x))(input2))]) 197 | x = Lambda(lambda x: 1/15 * x)(input2) 198 | x = ConvLayer(64, kernel_size=9, stride=1, trainable=True)(x) 199 | x = ConvLayer(64, kernel_size=3, stride=2, trainable=True)(x) 200 | x = ConvLayer(256, kernel_size=3, stride=2, trainable=True)(x) 201 | meta_input2 = Mean_std()(x) 202 | meta_output_list = [] 203 | hidden = concatenate([Dense(128, activation='relu')(meta_input1), Dense(16, activation='sigmoid')(meta_input2), 204 | Dense(128, activation='relu')(meta_input2)]) 205 | hidden = Dropout(0.05)(hidden) 206 | for 
name,shape in parms.params_dict: 207 | hidden2 = concatenate([Dense(64,activation = 'relu')(meta_input1),Dense(16,activation = 'sigmoid')(meta_input2),Dense(6,activation = 'relu')(meta_input2)]) 208 | hidden2 = concatenate([Dropout(0.05)(hidden2), hidden]) 209 | kernel = Dense(np.prod(shape[0]),activation = 'linear')(hidden2) 210 | bias = Dense(shape[1],activation = 'linear')(hidden2) 211 | meta_output_list.append(Lambda(lambda x:0.0001*x)(kernel)) 212 | meta_output_list.append(Lambda(lambda x:0.0001*x)(bias)) 213 | meta_output = (concatenate(meta_output_list)) 214 | 215 | meta_model = Model(input2,meta_output) 216 | 217 | # transform 218 | in_encode = ConvLayer(16, kernel_size=9, stride=1, trainable=True) 219 | x = Lambda(lambda x: 1/15. * x)(input) 220 | x = in_encode(x) 221 | x = ConvLayer(32, kernel_size=3, stride=2, trainable=True)(x) 222 | x = ConvLayer(residual_size, kernel_size=3, stride=2, parms = parms)([x,meta_output]) 223 | for i in range(6): 224 | x = ResidualBlock(residual_size, parms = parms)([x,meta_output]) 225 | x = ConvLayer(48, kernel_size=3, stride=1, upsample=2, parms = parms)([x,meta_output]) 226 | x = ConvLayer(16, kernel_size=3, stride=1,upsample=2, trainable=True)(x) 227 | out_decode = ConvLayer(3, kernel_size=9, stride=1, activation=None, trainable=True) 228 | x = out_decode(x) 229 | output = x 230 | output = Lambda(lambda x: x * 15.)(output) 231 | g_model = Model(inputs=[input,input2], outputs=output) 232 | 233 | 234 | return g_model 235 | 236 | 237 | g_model = get_Net(style_weight = style_weight,content_weight = content_weight,tv_weight =tv_weight) 238 | g_model.summary() 239 | # g_model.load_weights('./g5_48.h5') 240 | g_model.load_weights('./g7.h5') 241 | 242 | val = preprocess_image_vgg19(sys.argv[1]) 243 | style_val = preprocess_image_vgg19(sys.argv[2]) 244 | opt = Adam(lr = 0.00004, clipnorm=2, clipvalue=5e0) 245 | g_model.compile(loss='mse',optimizer=opt) 246 | output_image = g_model.predict([np.array([val]), np.array([style_val])]) 247 | imsave(sys.argv[3], deprocess_image(output_image[0].copy())) 248 | -------------------------------------------------------------------------------- /keras_version/picasso.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/picasso.jpg -------------------------------------------------------------------------------- /keras_version/res/all.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/all.jpg -------------------------------------------------------------------------------- /keras_version/res/c2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/c2.jpg -------------------------------------------------------------------------------- /keras_version/res/c2_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/c2_1.jpg -------------------------------------------------------------------------------- /keras_version/res/c2_2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/c2_2.jpg -------------------------------------------------------------------------------- /keras_version/res/c2_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/c2_3.jpg -------------------------------------------------------------------------------- /keras_version/res/c2_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/c2_5.jpg -------------------------------------------------------------------------------- /keras_version/res/c4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/c4.jpg -------------------------------------------------------------------------------- /keras_version/res/c4_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/c4_1.jpg -------------------------------------------------------------------------------- /keras_version/res/c4_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/c4_2.jpg -------------------------------------------------------------------------------- /keras_version/res/c4_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/c4_3.jpg -------------------------------------------------------------------------------- /keras_version/res/c4_5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/c4_5.jpg -------------------------------------------------------------------------------- /keras_version/res/style1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/style1.jpg -------------------------------------------------------------------------------- /keras_version/res/style2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/style2.jpg -------------------------------------------------------------------------------- /keras_version/res/style3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/style3.jpg 
-------------------------------------------------------------------------------- /keras_version/res/style5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/res/style5.jpg -------------------------------------------------------------------------------- /keras_version/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/result.jpg -------------------------------------------------------------------------------- /keras_version/train_content/gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/train_content/gitkeep -------------------------------------------------------------------------------- /keras_version/train_style/gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/train_style/gitkeep -------------------------------------------------------------------------------- /keras_version/util.py: -------------------------------------------------------------------------------- 1 | 2 | from keras import backend as K 3 | from keras.layers import * 4 | from keras.engine.topology import Layer,InputSpec 5 | import tensorflow as tf 6 | import math 7 | from keras import initializers 8 | 9 | class InstanceNormalization2D(Layer): 10 | def __init__(self, 11 | **kwargs): 12 | super(InstanceNormalization2D, self).__init__(**kwargs) 13 | self.axis = -1 14 | 15 | def build(self, input_shape): 16 | dim = input_shape[self.axis] 17 | if dim is None: 18 | raise ValueError('Axis ' + str(self.axis) + ' of ' 19 | 'input tensor should have a defined dimension ' 20 | 'but the layer received an input with shape ' + 21 | str(input_shape) + '.') 22 | self.input_spec = InputSpec(ndim=len(input_shape), 23 | axes={self.axis: dim}) 24 | self.built = True 25 | 26 | 27 | def call(self, inputs): 28 | epsilon = 1e-4 29 | reduction_axes = [0,1] 30 | shape = inputs.shape 31 | if K.image_data_format() == 'channels_first': 32 | inputs = K.reshape(inputs,(-1,int(shape[1]),int(shape[2])*int(shape[3]))) 33 | m, v = tf.nn.moments(inputs, reduction_axes, keep_dims=True) 34 | return K.reshape((inputs - m) + 0.6 * m,(-1,int(shape[1]),int(shape[2]),int(shape[3]))) 35 | else: 36 | inputs = (K.permute_dimensions(inputs, (0, 3, 1, 2))) 37 | inputs = K.reshape(inputs, (-1, int(shape[3]), int(shape[1]) * int(shape[2]))) 38 | m,v = tf.nn.moments(inputs,reduction_axes,keep_dims = True) 39 | return K.permute_dimensions(K.reshape((inputs - m) + 0.8 * m,(-1,int(shape[3]),int(shape[1]),int(shape[2]))), (0, 2, 3, 1)) 40 | 41 | def get_config(self): 42 | config = { 43 | 'axis': self.axis, 44 | } 45 | base_config = super(InstanceNormalization2D, self).get_config() 46 | return dict(list(base_config.items()) + list(config.items())) 47 | 48 | class InstanceNormalization2D2(Layer): 49 | def __init__(self,alpha = 0.3, 50 | **kwargs): 51 | super(InstanceNormalization2D2, self).__init__(**kwargs) 52 | self.alpha = alpha 53 | if K.image_data_format() == 'channels_first': 54 | self.axis = 1 55 | else: 56 | self.axis = 
-1 57 | 58 | def build(self, input_shape): 59 | dim = input_shape[self.axis] 60 | if dim is None: 61 | raise ValueError('Axis ' + str(self.axis) + ' of ' 62 | 'input tensor should have a defined dimension ' 63 | 'but the layer received an input with shape ' + 64 | str(input_shape) + '.') 65 | self.input_spec = InputSpec(ndim=len(input_shape), 66 | axes={self.axis: dim}) 67 | shape = (dim,) 68 | self.gamma = self.add_weight(shape=shape, 69 | name='gamma', 70 | initializer=initializers.get('ones')) 71 | self.beta = self.add_weight(shape=shape, 72 | name='beta', 73 | initializer=initializers.get('zeros')) 74 | self.built = True 75 | 76 | def call(self, inputs): 77 | epsilon = 1e-4 78 | reduction_axes = [0,1] 79 | shape = inputs.shape 80 | 81 | if K.image_data_format() == 'channels_first': 82 | inputs = K.reshape(inputs,(-1,int(shape[1]),int(shape[2])*int(shape[3]))) 83 | m, v = tf.nn.moments(inputs, reduction_axes, keep_dims=True) 84 | v = (v - 1.0) * self.alpha + 1. 85 | output = K.reshape((inputs - m)/(K.sqrt(v)+epsilon) + self.alpha * m,(-1,int(shape[1]),int(shape[2]),int(shape[3]))) 86 | gamma = K.repeat_elements(K.repeat_elements(K.reshape(self.gamma, (-1, int(shape[1]), 1, 1)), int(shape[2]), 2), int(shape[3]), 3) 87 | beta = K.repeat_elements( 88 | K.repeat_elements(K.reshape(self.beta, (-1, int(shape[1]), 1, 1)), int(shape[2]), 2), int(shape[3]), 3) 89 | else: 90 | inputs = (K.permute_dimensions(inputs, (0, 3, 1, 2))) 91 | inputs = K.reshape(inputs, (-1, int(shape[3]), int(shape[1]) * int(shape[2]))) 92 | m,v = tf.nn.moments(inputs,reduction_axes,keep_dims = True) 93 | v = (v - 1.0) * self.alpha + 1. 94 | output = K.permute_dimensions(K.reshape((inputs - m)/(K.sqrt(v)+epsilon) + self.alpha * m,(-1,int(shape[3]),int(shape[1]),int(shape[2]))), (0, 2, 3, 1)) 95 | gamma = K.repeat_elements(K.repeat_elements(K.reshape(self.gamma, (-1, 1, 1, int(shape[3]))), int(shape[2]), 2), int(shape[1]), 1) 96 | beta = K.repeat_elements(K.repeat_elements(K.reshape(self.beta, (-1, 1, 1, int(shape[3]))), int(shape[2]), 2), 97 | int(shape[1]), 1) 98 | 99 | return output * gamma + beta 100 | 101 | def get_config(self): 102 | config = { 103 | 'axis': self.axis, 104 | } 105 | base_config = super(InstanceNormalization2D2, self).get_config() 106 | return dict(list(base_config.items()) + list(config.items())) 107 | 108 | class ReflectPadding2D(Layer): 109 | def __init__(self,padding_length, 110 | **kwargs): 111 | super(ReflectPadding2D, self).__init__(**kwargs) 112 | self.axis = 2 113 | self.padding_length = padding_length 114 | 115 | def build(self, input_shape): 116 | dim = input_shape[self.axis] 117 | if dim is None: 118 | raise ValueError('Axis ' + str(self.axis) + ' of ' 119 | 'input tensor should have a defined dimension ' 120 | 'but the layer received an input with shape ' + 121 | str(input_shape) + '.') 122 | self.input_spec = InputSpec(ndim=len(input_shape), 123 | axes={self.axis: dim}) 124 | self.built = True 125 | 126 | def call(self, inputs): 127 | if K.image_data_format() != 'channels_first': 128 | inputs = (K.permute_dimensions(inputs, (0, 3, 1, 2))) 129 | reverse1 = K.reverse(inputs,-1) 130 | inputs = K.concatenate([reverse1[:,:,:,-self.padding_length:],inputs,reverse1[:,:,:,:self.padding_length]],axis = -1) 131 | reverse2 = K.reverse(inputs,-2) 132 | inputs = K.concatenate([reverse2[:,:,-self.padding_length:,:],inputs,reverse2[:,:,:self.padding_length,:]],axis = -2) 133 | if K.image_data_format() != 'channels_first': 134 | inputs = (K.permute_dimensions(inputs, (0, 2, 3, 1))) 135 | return inputs 
136 | 137 | def get_config(self): 138 | config = { 139 | 'axis': self.axis, 140 | 'padding_length': self.padding_length, 141 | } 142 | base_config = super(ReflectPadding2D, self).get_config() 143 | return dict(list(base_config.items()) + list(config.items())) 144 | 145 | def compute_output_shape(self,input_shape): 146 | output_shape = (input_shape[0],input_shape[1] + 2 * self.padding_length,input_shape[2] + 2 * self.padding_length,input_shape[3]) 147 | return output_shape 148 | 149 | class Mean_std(Layer): 150 | def __init__(self, 151 | **kwargs): 152 | super(Mean_std, self).__init__(**kwargs) 153 | self.axis = -1 154 | 155 | def build(self, input_shape): 156 | dim = input_shape[self.axis] 157 | if dim is None: 158 | raise ValueError('Axis ' + str(self.axis) + ' of ' 159 | 'input tensor should have a defined dimension ' 160 | 'but the layer received an input with shape ' + 161 | str(input_shape) + '.') 162 | self.input_spec = InputSpec(ndim=len(input_shape), 163 | axes={self.axis: dim}) 164 | self.built = True 165 | 166 | 167 | def call(self, inputs): 168 | shape = inputs.shape 169 | if K.image_data_format() == 'channels_first': 170 | inputs = K.reshape(inputs,(-1,int(shape[1]),int(shape[2])*int(shape[3]))) 171 | m = K.mean(inputs, axis=-1, keepdims=False) 172 | v = K.sqrt(K.update_add(K.var(inputs, axis=-1, keepdims=False),1.0e-5)) 173 | return K.concatenate([m,v],axis = -1) 174 | else: 175 | inputs = (K.permute_dimensions(inputs, (0, 3, 1, 2))) 176 | inputs = K.reshape(inputs, (-1, int(shape[3]), int(shape[1]) * int(shape[2]))) 177 | m = K.mean(inputs, axis=-1, keepdims=False) 178 | v = K.sqrt(K.var(inputs, axis=-1, keepdims=False)+K.constant(1.0e-5, dtype=inputs.dtype.base_dtype)) 179 | lmax = K.max(inputs, axis=-1, keepdims=False) 180 | return K.concatenate([m,v],axis = -1) 181 | 182 | def get_config(self): 183 | config = { 184 | 'axis': self.axis, 185 | } 186 | base_config = super(Mean_std, self).get_config() 187 | return dict(list(base_config.items()) + list(config.items())) 188 | 189 | def compute_output_shape(self,input_shape): 190 | if K.image_data_format() == 'channels_first': 191 | output_shape = (input_shape[0],input_shape[1] * 2) 192 | else: 193 | output_shape = (input_shape[0], input_shape[3] * 2) 194 | return output_shape 195 | 196 | 197 | 198 | class MyConv2D(Layer): 199 | def __init__(self, 200 | filters, 201 | kernel_size, 202 | strides=1, 203 | padding='valid', 204 | data_format=None, 205 | dilation_rate=1, 206 | activation=None, 207 | use_bias=True, 208 | kernel_initializer='glorot_uniform', 209 | bias_initializer='zeros', 210 | kernel_regularizer=None, 211 | bias_regularizer=None, 212 | activity_regularizer=None, 213 | kernel_constraint=None, 214 | bias_constraint=None, 215 | shape=(0, ((3, 3, 64, 64), 64)), 216 | **kwargs): 217 | super(MyConv2D, self).__init__(**kwargs) 218 | self.rank = 2 219 | self.filters = filters 220 | self.kernel_size = conv_utils.normalize_tuple(kernel_size, 2, 'kernel_size') 221 | self.strides = conv_utils.normalize_tuple(strides, 2, 'strides') 222 | self.padding = conv_utils.normalize_padding(padding) 223 | self.data_format = conv_utils.normalize_data_format(data_format) 224 | self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, 2, 'dilation_rate') 225 | self.activation = activations.get(activation) 226 | self.use_bias = use_bias 227 | self.kernel_initializer = initializers.get(kernel_initializer) 228 | self.bias_initializer = initializers.get(bias_initializer) 229 | self.kernel_regularizer = regularizers.get(kernel_regularizer) 230 
| self.bias_regularizer = regularizers.get(bias_regularizer) 231 | self.activity_regularizer = regularizers.get(activity_regularizer) 232 | self.kernel_constraint = constraints.get(kernel_constraint) 233 | self.bias_constraint = constraints.get(bias_constraint) 234 | self.shape = shape 235 | # self.input_spec = InputSpec(ndim=self.rank + 2) 236 | 237 | def build(self, input_shape): 238 | input_shape = input_shape[0] 239 | if self.data_format == 'channels_first': 240 | channel_axis = 1 241 | else: 242 | channel_axis = -1 243 | if input_shape[channel_axis] is None: 244 | raise ValueError('The channel dimension of the inputs ' 245 | 'should be defined. Found `None`.') 246 | input_dim = input_shape[channel_axis] 247 | # self.input_spec = InputSpec(ndim=self.rank + 2, 248 | # axes={channel_axis: input_dim}) 249 | self.built = True 250 | 251 | def call(self, inputs): 252 | weight_start,weight_shape = self.shape 253 | kernel = K.reshape(inputs[1][:, weight_start:weight_start + np.prod(weight_shape[0])], (-1,) + weight_shape[0]) 254 | bias = K.reshape(inputs[1][:, 255 | weight_start + np.prod(weight_shape[0]):weight_start + np.prod(weight_shape[0]) + weight_shape[ 256 | 1]] 257 | , (-1,weight_shape[1])) 258 | kernel = kernel[0,:,:,:,:] 259 | bias = bias[0,:] 260 | outputs = K.conv2d( 261 | inputs[0], 262 | kernel, 263 | strides=self.strides, 264 | padding=self.padding, 265 | data_format=self.data_format, 266 | dilation_rate=self.dilation_rate) 267 | 268 | if self.use_bias: 269 | outputs = K.bias_add( 270 | outputs, 271 | bias, 272 | data_format=self.data_format) 273 | 274 | if self.activation is not None: 275 | return self.activation(outputs) 276 | return outputs 277 | 278 | def compute_output_shape(self, input_shape): 279 | input_shape = input_shape[0] 280 | if self.data_format == 'channels_last': 281 | space = input_shape[1:-1] 282 | new_space = [] 283 | for i in range(len(space)): 284 | new_dim = conv_utils.conv_output_length( 285 | space[i], 286 | self.kernel_size[i], 287 | padding=self.padding, 288 | stride=self.strides[i], 289 | dilation=self.dilation_rate[i]) 290 | new_space.append(new_dim) 291 | return (input_shape[0],) + tuple(new_space) + (self.filters,) 292 | if self.data_format == 'channels_first': 293 | space = input_shape[2:] 294 | new_space = [] 295 | for i in range(len(space)): 296 | new_dim = conv_utils.conv_output_length( 297 | space[i], 298 | self.kernel_size[i], 299 | padding=self.padding, 300 | stride=self.strides[i], 301 | dilation=self.dilation_rate[i]) 302 | new_space.append(new_dim) 303 | return (input_shape[0], self.filters) + tuple(new_space) 304 | 305 | def get_config(self): 306 | config = { 307 | 'filters': self.filters, 308 | 'kernel_size': self.kernel_size, 309 | 'strides': self.strides, 310 | 'padding': self.padding, 311 | 'data_format': self.data_format, 312 | 'dilation_rate': self.dilation_rate, 313 | 'activation': activations.serialize(self.activation), 314 | 'use_bias': self.use_bias, 315 | 'kernel_initializer': initializers.serialize(self.kernel_initializer), 316 | 'bias_initializer': initializers.serialize(self.bias_initializer), 317 | 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 318 | 'bias_regularizer': regularizers.serialize(self.bias_regularizer), 319 | 'activity_regularizer': regularizers.serialize(self.activity_regularizer), 320 | 'kernel_constraint': constraints.serialize(self.kernel_constraint), 321 | 'bias_constraint': constraints.serialize(self.bias_constraint) 322 | } 323 | base_config = super(MyConv2D, self).get_config() 
324 | return dict(list(base_config.items()) + list(config.items())) 325 | 326 | 327 | class Smooth: 328 | def __init__(self, windowsize=100): 329 | self.window_size = windowsize 330 | self.data = np.zeros((self.window_size, 1), dtype=np.float32) 331 | self.index = 0 332 | 333 | def __iadd__(self, x): 334 | if self.index == 0: 335 | self.data[:] = x 336 | self.data[self.index % self.window_size] = x 337 | self.index += 1 338 | return self 339 | 340 | def __float__(self): 341 | return float(self.data.mean()) 342 | 343 | def __format__(self, f): 344 | return self.__float__().__format__(f) -------------------------------------------------------------------------------- /keras_version/util.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CortexFoundation/StyleTransferTrilogy/467046b6271fc4f0e4f01967d184cfcd9078b534/keras_version/util.pyc --------------------------------------------------------------------------------
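The custom layers in `util.py` are consumed by the transform net in `demo3.py` / `3_situation.py`, where the meta network's flattened output (`meta_output`) is fed alongside the feature maps into the layers built with `parms`, so their kernels and biases come from the meta network rather than from trainable weights. As a quick illustration of that mechanism, below is a minimal, hypothetical sketch (not part of the repository) that wires a single `MyConv2D` to an externally supplied weight vector. The channel sizes, spatial size and `n_weights` are made-up values chosen only for this example, and it assumes the same Keras 2.x / TensorFlow 1.x environment as the rest of `keras_version`, run from inside that directory so `util.py` is importable:

```py
import numpy as np
from keras.layers import Input
from keras.models import Model

from util import MyConv2D  # the layer defined in keras_version/util.py

# Example only: one 3x3 convolution mapping 8 -> 16 channels, so the flat
# weight vector must hold prod(3*3*8*16) kernel values followed by 16 biases.
kernel_shape = (3, 3, 8, 16)
n_weights = int(np.prod(kernel_shape)) + 16

content = Input(shape=(64, 64, 8))    # feature map to be convolved
weights = Input(shape=(n_weights,))   # flattened kernel + bias, e.g. a slice of meta_output

# shape=(offset, (kernel_shape, bias_len)) tells MyConv2D where to slice its
# kernel and bias out of the second input; offset 0 means the start of the vector.
styled = MyConv2D(16, kernel_size=(3, 3), padding='same',
                  shape=(0, (kernel_shape, 16)))([content, weights])

model = Model(inputs=[content, weights], outputs=styled)
model.summary()
```

In the full model, the `ConvLayer(..., parms=parms)` wrapper appears to compute these offsets from `parms.params_dict`, so each meta-controlled convolution reads its own slice of `meta_output`; this is why `MyConv2D` declares no trainable kernel of its own, and only the encoder/decoder convolutions marked `trainable=True` keep their own weights.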