├── .gitignore
├── images
│   ├── pwr.jpg
│   ├── candy.jpg
│   ├── dancing.jpg
│   ├── mosaic.jpg
│   ├── picasso.jpg
│   └── results
│       ├── pwr_candy.jpg
│       ├── pwr_mosaic.jpg
│       ├── pwr_picasso.jpg
│       ├── dancing_candy.jpg
│       ├── dancing_mosaic.jpg
│       └── dancing_picasso.jpg
├── models
│   ├── candy.pth
│   ├── mosaic.pth
│   └── picasso.pth
├── feature_ext.py
├── utils.py
├── webcam.py
├── fnst_modules.py
├── README.md
└── fnst.py

/.gitignore:
--------------------------------------------------------------------------------
__pycache__/
--------------------------------------------------------------------------------
/images/pwr.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/pwr.jpg
--------------------------------------------------------------------------------
/images/candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/candy.jpg
--------------------------------------------------------------------------------
/images/dancing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/dancing.jpg
--------------------------------------------------------------------------------
/images/mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/mosaic.jpg
--------------------------------------------------------------------------------
/images/picasso.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/picasso.jpg
--------------------------------------------------------------------------------
/models/candy.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/models/candy.pth
--------------------------------------------------------------------------------
/models/mosaic.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/models/mosaic.pth
--------------------------------------------------------------------------------
/models/picasso.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/models/picasso.pth
--------------------------------------------------------------------------------
/images/results/pwr_candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/pwr_candy.jpg
--------------------------------------------------------------------------------
/images/results/pwr_mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/pwr_mosaic.jpg
--------------------------------------------------------------------------------
/images/results/pwr_picasso.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/pwr_picasso.jpg
--------------------------------------------------------------------------------
/images/results/dancing_candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/dancing_candy.jpg
--------------------------------------------------------------------------------
/images/results/dancing_mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/dancing_mosaic.jpg
--------------------------------------------------------------------------------
/images/results/dancing_picasso.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/dancing_picasso.jpg
--------------------------------------------------------------------------------
/feature_ext.py:
--------------------------------------------------------------------------------
from functools import partial


class FeatureExtractor:
    """Collects activations of selected layers via forward hooks."""

    def __init__(self, model, idxs):
        self.__list = [0] * len(idxs)
        self.__hooks = []
        self.__modules = []
        self.__model_to_list(model)
        self.__create_hooks(idxs)

    def __create_hooks(self, idxs):
        # register one forward hook per requested layer index
        for slot, i in enumerate(idxs):
            fun = partial(self.__hook_fn, idx=slot)
            hook = self.__modules[i].register_forward_hook(fun)
            self.__hooks.append(hook)

    def __model_to_list(self, model):
        # recursively flatten the model into a list of leaf modules
        if list(model.children()) == []:
            self.__modules.append(model)
        for ch in model.children():
            self.__model_to_list(ch)

    def __hook_fn(self, module, input, output, idx):
        self.__list[idx] = output

    @property
    def features(self):
        return self.__list

    def remove_hooks(self):
        for h in self.__hooks:
            h.remove()
--------------------------------------------------------------------------------
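For illustration, here is a minimal usage sketch (not a file in the repo) showing how `FeatureExtractor` is meant to be driven: it hooks the flattened layer list of a torchvision VGG-16, runs a forward pass, and reads the activations back. The indices match the `LAYER_IDXS` used in `fnst.py`:

```
import torch
from torchvision import models
from feature_ext import FeatureExtractor

# indices into the flattened layer list of vgg16().features
# (3, 8, 15, 22 correspond to relu1_2, relu2_2, relu3_3, relu4_3)
vgg = models.vgg16(pretrained=True).features.eval()
fx = FeatureExtractor(vgg, [3, 8, 15, 22])

with torch.no_grad():
    vgg(torch.rand(1, 3, 128, 128))  # hooks fire during the forward pass

for act in fx.features:
    print(act.shape)  # one activation tensor per requested layer
fx.remove_hooks()     # detach the hooks when done
```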
/utils.py:
--------------------------------------------------------------------------------
import torch
from PIL import Image


def load_im(f, size=None):
    img = Image.open(f)
    if size is not None:
        # Image.LANCZOS replaces the Image.ANTIALIAS alias removed in Pillow 10
        img = img.resize((size, size), Image.LANCZOS)
    return img


def save_im(f, tens):
    img = tens.detach().clamp(0, 255).cpu().numpy()
    img = img.transpose(1, 2, 0).astype("uint8")
    img = Image.fromarray(img)
    img.save(f)


def gram_matrix(x):
    # batched Gram matrix: channel correlations normalized by c*h*w
    b, c, h, w = x.size()
    features = x.view(b, c, w*h)
    features_t = features.transpose(1, 2)
    gram = features.bmm(features_t) / (c*h*w)
    return gram


def norm_batch(b):
    # scale to [0, 1] and normalize with ImageNet statistics
    mean = b.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    std = b.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    b = b.div_(255.0)
    return (b - mean) / std


def regularization_loss(x):
    # total variation: absolute differences between neighboring pixels
    loss = (torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])) +
            torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])))
    return loss
--------------------------------------------------------------------------------
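As a quick illustration (not part of the repo) of what `gram_matrix` above returns: one c x c channel-correlation matrix per sample, with products summed over spatial positions and scaled by 1/(c*h*w) so style losses are comparable across layers of different sizes:

```
import torch
from utils import gram_matrix

feats = torch.rand(4, 64, 32, 32)  # a batch of 4 feature maps, 64 channels each
g = gram_matrix(feats)
print(g.shape)  # torch.Size([4, 64, 64]) -- one 64x64 Gram matrix per sample

# g[b, i, j] correlates channels i and j over all spatial positions;
# these statistics capture texture/style while discarding spatial layout.
```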
/webcam.py:
--------------------------------------------------------------------------------
import cv2
import torch
from fnst_modules import TransformerMobileNet
from torchvision import transforms

MODEL = 'models/mosaic.pth'
IMAGE_SIZE = 300

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.mul(255).unsqueeze(0).to(device))])


def postprocess(tens):
    img = tens.permute(1, 2, 0).clamp(0, 255)
    img = img.cpu().numpy()
    img = img.astype("uint8")
    return img


def prepare_net(f, net):
    # map_location lets checkpoints saved on GPU load on CPU-only machines
    state_dict = torch.load(f, map_location=device)
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()


def main():
    with torch.no_grad():
        net = TransformerMobileNet()
        prepare_net(MODEL, net)

        capture = cv2.VideoCapture(0)

        while True:
            _, im = capture.read()
            # note: OpenCV frames are BGR, while training images are RGB
            im = cv2.resize(im, (IMAGE_SIZE, IMAGE_SIZE))
            t = transform(im)
            res = net(t)[0]
            im = postprocess(res)
            cv2.imshow('webcam', im)
            if cv2.waitKey(1) == 27:  # press Esc to end
                break

        capture.release()
        cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/fnst_modules.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    """MobileNetV2 inverted residual block with instance norm."""

    def __init__(self, inp_c, out_c, kernel_size, stride, t=1):
        assert stride in [1, 2], 'stride must be either 1 or 2'
        super().__init__()
        self.residual = stride == 1 and inp_c == out_c
        pad = kernel_size // 2
        self.reflection_pad = nn.ReflectionPad2d(pad)
        # 1x1 expansion, depthwise conv, 1x1 linear projection
        self.conv1 = nn.Conv2d(inp_c, t*inp_c, 1, 1, bias=False)
        self.in1 = nn.InstanceNorm2d(t*inp_c, affine=True)
        self.conv2 = nn.Conv2d(t*inp_c, t*inp_c, kernel_size, stride,
                               groups=t*inp_c, bias=False)
        self.in2 = nn.InstanceNorm2d(t*inp_c, affine=True)
        self.conv3 = nn.Conv2d(t*inp_c, out_c, 1, 1, bias=False)
        self.in3 = nn.InstanceNorm2d(out_c, affine=True)

    def forward(self, x):
        out = F.relu6(self.in1(self.conv1(x)))
        out = self.reflection_pad(out)
        out = F.relu6(self.in2(self.conv2(out)))
        out = self.in3(self.conv3(out))
        if self.residual:
            out = x + out
        return out


class UpsampleConv(nn.Module):
    """Nearest-neighbor upsampling followed by a bottleneck block."""

    def __init__(self, inp_c, out_c, kernel_size, stride, upsample=2):
        super().__init__()
        if upsample:
            self.upsample = nn.Upsample(mode='nearest', scale_factor=upsample)
        else:
            self.upsample = None
        self.conv1 = Bottleneck(inp_c, out_c, kernel_size, stride)

    def forward(self, x):
        x_in = x
        if self.upsample is not None:
            x_in = self.upsample(x_in)
        out = F.relu(self.conv1(x_in))
        return out


class TransformerMobileNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Conv Layers
        self.reflection_pad = nn.ReflectionPad2d(9//2)
        self.conv1 = nn.Conv2d(3, 32, kernel_size=9, stride=1, bias=False)
        self.in1 = nn.InstanceNorm2d(32, affine=True)
        self.conv2 = Bottleneck(32, 64, kernel_size=3, stride=2)
        self.conv3 = Bottleneck(64, 128, kernel_size=3, stride=2)
        # Residual Layers
        self.res1 = Bottleneck(128, 128, 3, 1)
        self.res2 = Bottleneck(128, 128, 3, 1)
        self.res3 = Bottleneck(128, 128, 3, 1)
        self.res4 = Bottleneck(128, 128, 3, 1)
        self.res5 = Bottleneck(128, 128, 3, 1)
        # Upsampling Layers
        self.upconv1 = UpsampleConv(128, 64, kernel_size=3, stride=1)
        self.upconv2 = UpsampleConv(64, 32, kernel_size=3, stride=1)
        self.conv4 = nn.Conv2d(32, 3, kernel_size=9, stride=1, bias=False)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = F.relu(self.in1(self.conv1(out)))
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.res1(out)
        out = self.res2(out)
        out = self.res3(out)
        out = self.res4(out)
        out = self.res5(out)
        out = self.upconv1(out)
        out = self.upconv2(out)
        out = self.conv4(self.reflection_pad(out))
        return out
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Fast neural style with MobileNetV2 bottleneck blocks
This repository contains a PyTorch implementation of an algorithm for artistic style transfer. The implementation is based on the following papers and repositories:

- [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)
- [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381)
- [A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576)
- [PyTorch Fast Neural Style Example](https://github.com/pytorch/examples/tree/master/fast_neural_style)

## Main Differences from Other Implementations

* Residual blocks and convolutions are replaced with MobileNetV2 bottleneck blocks, which make use of inverted residuals and depthwise separable convolutions.

![Bottleneck](https://hsto.org/webt/wl/yo/sz/wlyoszqnws58itd4ojt1cqt7sng.png)

The picture shows the two types of MobileNetV2 bottleneck blocks: the left one replaces a residual block and the right one replaces a convolution layer. This change serves two purposes:

- It decreases the number of trainable parameters of the transformer network from __~1.67m__ to __~0.23m__, and therefore the amount of memory the network uses (see the sketch after this list).

- In theory it should also give a good speedup at training time and, more importantly, at inference time (fast neural style should be as fast as possible). In practice things are not so good: this architecture is only slightly faster than the original transformer network, mainly because depthwise convolutions are not implemented as efficiently on GPU as regular convolutions are (on CPU the speedup is bigger, but still not drastic).

* This implementation uses a feature extractor wrapper around a PyTorch module, which relies on PyTorch's forward hooks to retrieve layer activations. With this extractor:

- You don't need to write a new module wrapper to extract the desired features every time you want to use a new loss network. You just pass the model and the layer indices to the wrapper and it handles the extraction for you. (__Note__: the wrapper flattens the input module, so you need to supply indices into the flattened module, i.e. if the model is a composition of smaller modules, it is represented as a flat list of layers inside the wrapper.)

- It makes the training process slightly faster.

* The implementation allows you to use different weights for different style features, which leads to better visual results.
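A quick way to check the parameter count quoted above (an illustrative snippet, assuming it is run from the repo root):

```
from fnst_modules import TransformerMobileNet

net = TransformerMobileNet()
n_params = sum(p.numel() for p in net.parameters())
print(f'{n_params / 1e6:.2f}M trainable parameters')  # roughly the ~0.23m above
```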
## Requirements
- [pytorch](https://pytorch.org) (>= 0.4.0)
- [torchvision](https://pytorch.org)
- [PIL](https://pillow.readthedocs.io/en/5.1.x/)
- [OpenCV](https://opencv.org/) (for webcam demo)
- GPU is not necessary

## Usage
To train the transformer network:
```
python fnst.py -train
```
To stylize an image with a pretrained model:
```
python fnst.py
```

All configurable parameters are stored as globals at the top of `fnst.py`, so to change them just edit `fnst.py` (I find this more convenient than adding a dozen command-line arguments).

There is also a webcam demo in the repo; to run it:
```
python webcam.py
```
`webcam.py` also has some globals at the top of the file which you can change.


## Examples
All models were trained for 3 epochs on 128x128 [COCO Dataset](http://cocodataset.org) images (the small size is due to the GTX 960m in my laptop).

- Styles

- Results
--------------------------------------------------------------------------------
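For reference, a hedged sketch (not part of the repo) of driving the `Stylizer` class from `fnst.py` (listed next) programmatically instead of through the CLI; the model and image paths are examples shipped with the repo:

```
from fnst import Stylizer

stylizer = Stylizer('models/candy.pth', 'images/results')
for im in ('images/pwr.jpg', 'images/dancing.jpg'):
    stylizer.stylize(im)  # writes e.g. images/results/pwr_candy.jpg
```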
/fnst.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

from PIL import Image
from tqdm import tqdm
import os
import argparse

from utils import load_im, save_im
from utils import gram_matrix, norm_batch, regularization_loss
from fnst_modules import TransformerMobileNet
from feature_ext import FeatureExtractor

EPOCHS = 2
LOSS_NETWORK = models.vgg16
TRAIN_PATH = '101_ObjectCategories'
STYLE_IMAGE = 'images/mosaic.jpg'
CHECK_IMAGE = 'images/dancing.jpg'
IMAGE_SIZE = 128
LAYER_IDXS = [3, 8, 15, 22]
BATCH_SIZE = 4
CONTENT_WEIGHT = 1
STYLE_WEIGHT = 3 * 1e5
REG_WEIGHT = 3 * 1e-5
STYLE_PROPORTIONS = [.35, .35, .15, .15]
CONTENT_INDEX = 1
LOG_INTERVAL = 1000
CHECKPOINT = 4000

OUTPUT_PATH = 'images/results'
INPUT_IMAGE = 'images/pwr.jpg'
MODEL_PATH = 'models/mosaic.pth'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class Learner:
    def __init__(self, loss_network, train_path, style_image, check_image, im_size,
                 layer_idxs, batch_size, c_weight, s_weight, r_weight,
                 style_proportions, content_index, log_interval, checkpoint):
        assert len(style_proportions) == len(layer_idxs)

        # prepare dataset and loader
        self.dataset, self.loader = (
            self.__prepare_dataset(train_path, im_size, batch_size))

        # prepare transformer and classifier nets, optimizer
        self.tfm_net = TransformerMobileNet().to(device)
        self.loss_net = loss_network
        self.__prepare_loss_net()
        self.optimizer = Adam(self.tfm_net.parameters(), 1e-3)

        # prepare feature extractor, style image
        # and image for checking intermediate results
        self.content_index = content_index
        self.fx = FeatureExtractor(self.loss_net, layer_idxs)
        self.style_batch, self.style_target = (
            self.__prepare_style_target(style_image, batch_size))
        self.check_tensor = self.__prepare_check_tensor(check_image)

        # set weights for different losses
        self.content_weight = c_weight
        self.style_weights = [s_weight*x for x in style_proportions]
        self.reg_weight = r_weight

        # intervals
        self.log_intl = log_interval
        self.checkpoint = checkpoint

    def __prepare_paths(self):
        for p in ['models', 'images']:
            _path = os.path.join('tmp', p)
            os.makedirs(_path, exist_ok=True)

    def __prepare_dataset(self, train_path, im_size, batch_size):
        transform = transforms.Compose([
            transforms.Resize(im_size),
            transforms.CenterCrop(im_size),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))])

        ds = datasets.ImageFolder(train_path, transform)
        ld = DataLoader(ds, batch_size=batch_size)
        return ds, ld

    def __prepare_loss_net(self):
        # the loss network is frozen; only the transformer is trained
        for p in self.loss_net.parameters():
            p.requires_grad_(False)
        self.loss_net.to(device).eval()

    def __prepare_style_target(self, im_path, batch_size):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))])

        style_im = load_im(im_path)
        style_tensor = transform(style_im)
        style_batch = style_tensor.repeat(batch_size, 1, 1, 1).to(device)

        self.loss_net(norm_batch(style_batch))
        style_target = tuple(gram_matrix(f)
                             for f in self.fx.features)
        return style_batch, style_target

    def __prepare_check_tensor(self, im_path):
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255))])

        if im_path is None:
            # dataset items are (image, label) pairs, so take the image only
            res = self.dataset[0][0].unsqueeze(0).to(device)
        else:
            res = transform(Image.open(im_path)).unsqueeze(0).to(device)
        return res

    def train(self, epochs):
        self.__prepare_paths()
        for e in range(epochs):
            self.tfm_net.train()
            agg_content_loss = 0.
            agg_style_loss = 0.
            agg_reg_loss = 0.
            for i, (x, _) in enumerate(tqdm(self.loader, desc=f'Epoch {e}')):
                len_batch = len(x)
                self.optimizer.zero_grad()

                x = x.to(device)
                y = self.tfm_net(x)
                x = norm_batch(x)
                y = norm_batch(y)

                self.loss_net(y)
                style_y = tuple(gram_matrix(f) for f in self.fx.features)
                content_y = self.fx.features[self.content_index]

                self.loss_net(x)
                content_x = self.fx.features[self.content_index]

                content_loss = F.mse_loss(content_y, content_x)
                content_loss *= self.content_weight

                style_loss = 0.
                for gm_y, gm_t, w in zip(style_y, self.style_target,
                                         self.style_weights):
                    style_loss += (w*F.mse_loss(gm_y, gm_t[:len_batch, :, :]))

                reg_loss = self.reg_weight * regularization_loss(y)

                total_loss = content_loss + style_loss + reg_loss

                total_loss.backward()
                self.optimizer.step()

                agg_content_loss += content_loss.item()
                agg_style_loss += style_loss.item()
                agg_reg_loss += reg_loss.item()

                if (i+1) % self.log_intl == 0:
                    self.intermediate_res(agg_content_loss, agg_style_loss,
                                          agg_reg_loss, i+1)

                if (i+1) % self.checkpoint == 0:
                    self.save_tfm_net(e+1, i+1)

            self.save_tfm_net(e+1, i+1)

    def intermediate_res(self, c_loss, s_loss, r_loss, n):
        self.tfm_net.eval()
        check = self.tfm_net(self.check_tensor)
        _path = os.path.join('tmp', 'images', f'check{n}.jpg')
        save_im(_path, check[0])
        self.tfm_net.train()

        msg = (f'\nbatch: {n}\t'
               f'content: {c_loss/n}\t'
               f'style: {s_loss/n}\t'
               f'reg: {r_loss/n}\t'
               f'total: {(c_loss + s_loss + r_loss)/n} \n')

        print(msg)

    def save_tfm_net(self, e, i):
        name = f'epoch{e}_batch{i}.pth'
        _path = os.path.join('tmp', 'models', name)
        torch.save(self.tfm_net.state_dict(), _path)


class Stylizer:
    def __init__(self, model_path, output_path):
        self.output_path = output_path
        self.model_name = os.path.basename(model_path).split('.')[0]
        self.__load_net(model_path)
        self.net.to(device)
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.mul(255).unsqueeze(0).to(device))])

    def __load_net(self, model_path):
        with torch.no_grad():
            self.net = TransformerMobileNet()
            # map_location lets GPU-saved checkpoints load on CPU-only machines
            state_dict = torch.load(model_path, map_location=device)
            self.net.load_state_dict(state_dict)

    def stylize(self, im_path):
        with torch.no_grad():
            self.net.eval()
            im = load_im(im_path)
            x = self.transform(im)
            out = self.net(x)
            _name = (os.path.basename(im_path).split('.')[0] + '_'
                     + self.model_name + '.jpg')
            _path = os.path.join(self.output_path, _name)
            save_im(_path, out[0])


parser = argparse.ArgumentParser()
parser.add_argument('-train', action='store_true')


def main():
    args = parser.parse_args()
    if args.train:
        loss_network = LOSS_NETWORK(pretrained=True)
        loss_network = nn.Sequential(*list(loss_network.features)[:23])  # ToDo
        lrn = Learner(loss_network=loss_network,
                      train_path=TRAIN_PATH, style_image=STYLE_IMAGE,
                      check_image=CHECK_IMAGE, im_size=IMAGE_SIZE,
                      layer_idxs=LAYER_IDXS, batch_size=BATCH_SIZE,
                      c_weight=CONTENT_WEIGHT, s_weight=STYLE_WEIGHT,
                      r_weight=REG_WEIGHT, style_proportions=STYLE_PROPORTIONS,
                      content_index=CONTENT_INDEX, log_interval=LOG_INTERVAL,
                      checkpoint=CHECKPOINT)
        lrn.train(EPOCHS)
    else:
        stl = Stylizer(MODEL_PATH, OUTPUT_PATH)
        stl.stylize(INPUT_IMAGE)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------