├── .gitignore
├── images
│   ├── pwr.jpg
│   ├── candy.jpg
│   ├── dancing.jpg
│   ├── mosaic.jpg
│   ├── picasso.jpg
│   └── results
│       ├── pwr_candy.jpg
│       ├── pwr_mosaic.jpg
│       ├── pwr_picasso.jpg
│       ├── dancing_candy.jpg
│       ├── dancing_mosaic.jpg
│       └── dancing_picasso.jpg
├── models
│   ├── candy.pth
│   ├── mosaic.pth
│   └── picasso.pth
├── feature_ext.py
├── utils.py
├── webcam.py
├── fnst_modules.py
├── README.md
└── fnst.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 |
--------------------------------------------------------------------------------
/images/pwr.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/pwr.jpg
--------------------------------------------------------------------------------
/images/candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/candy.jpg
--------------------------------------------------------------------------------
/images/dancing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/dancing.jpg
--------------------------------------------------------------------------------
/images/mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/mosaic.jpg
--------------------------------------------------------------------------------
/images/picasso.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/picasso.jpg
--------------------------------------------------------------------------------
/models/candy.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/models/candy.pth
--------------------------------------------------------------------------------
/models/mosaic.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/models/mosaic.pth
--------------------------------------------------------------------------------
/models/picasso.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/models/picasso.pth
--------------------------------------------------------------------------------
/images/results/pwr_candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/pwr_candy.jpg
--------------------------------------------------------------------------------
/images/results/pwr_mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/pwr_mosaic.jpg
--------------------------------------------------------------------------------
/images/results/pwr_picasso.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/pwr_picasso.jpg
--------------------------------------------------------------------------------
/images/results/dancing_candy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/dancing_candy.jpg
--------------------------------------------------------------------------------
/images/results/dancing_mosaic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/dancing_mosaic.jpg
--------------------------------------------------------------------------------
/images/results/dancing_picasso.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mmalotin/pytorch-fast-neural-style-mobilenetV2/HEAD/images/results/dancing_picasso.jpg
--------------------------------------------------------------------------------
/feature_ext.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 |
4 | class FeatureExtractor():
5 | def __init__(self, model, idxs):
6 | self.__list = [0] * len(idxs)
7 | self.__hooks = []
8 | self.__modules = []
9 | self.__model_to_list(model)
10 | self.__create_hooks(idxs)
11 |
12 | def __create_hooks(self, idxs):
13 | help_idxs = list(range(len(idxs)))
14 | for i, idx in zip(idxs, help_idxs):
15 | fun = partial(self.__hook_fn, idx=idx)
16 | hook = self.__modules[i].register_forward_hook(fun)
17 | self.__hooks.append(hook)
18 |
19 | def __model_to_list(self, model):
20 | if list(model.children()) == []:
21 | self.__modules.append(model)
22 | for ch in model.children():
23 | self.__model_to_list(ch)
24 |
25 | def __hook_fn(self, module, input, output, idx):
26 | self.__list[idx] = output
27 |
28 | @property
29 | def features(self):
30 | return self.__list
31 |
32 | def remove_hooks(self):
33 | [x.remove() for x in self.__hooks]
34 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from PIL import Image
4 |
5 |
6 | def load_im(f, size=None):
7 | img = Image.open(f)
8 | if size is not None:
9 | img = img.resize((size, size), Image.ANTIALIAS)
10 | return img
11 |
12 |
13 | def save_im(f, tens):
14 | img = tens.detach().clamp(0, 255).cpu().numpy()
15 | img = img.transpose(1, 2, 0).astype("uint8")
16 | img = Image.fromarray(img)
17 | img.save(f)
18 |
19 |
20 | def gram_matrix(x):
21 | b, c, h, w = x.size()
22 | features = x.view(b, c, w*h)
23 | features_t = features.transpose(1, 2)
24 | gram = features.bmm(features_t) / (c*h*w)
25 | return gram
26 |
27 |
28 | def norm_batch(b):
29 | mean = b.new_tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
30 | std = b.new_tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
31 | b = b.div_(255.0)
32 | return (b - mean) / std
33 |
34 |
35 | def regularization_loss(x):
36 | loss = (torch.sum(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:])) +
37 | torch.sum(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :])))
38 | return loss
39 |
--------------------------------------------------------------------------------
/webcam.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | from fnst_modules import TransformerMobileNet
4 | from torchvision import transforms
5 |
6 | MODEL = 'models/mosaic.pth'
7 | IMAGE_SIZE = 300
8 |
9 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10 |
11 | transform = transforms.Compose([
12 | transforms.ToTensor(),
13 | transforms.Lambda(lambda x: x.mul(255).unsqueeze(0).to(device))])
14 |
15 |
16 | def postprocess(tens):
17 | img = tens.permute(1, 2, 0).clamp(0, 255)
18 | img = img.cpu().numpy()
19 | img = img.astype("uint8")
20 | return img
21 |
22 |
23 | def prepare_net(f, net):
24 |     state_dict = torch.load(f, map_location=device)  # also works on CPU-only machines
25 | net.load_state_dict(state_dict)
26 | net.to(device)
27 | net.eval()
28 |
29 |
30 | def main():
31 | with torch.no_grad():
32 | net = TransformerMobileNet()
33 | prepare_net(MODEL, net)
34 |
35 | capture = cv2.VideoCapture(0)
36 |
37 | while True:
38 | _, im = capture.read()
39 | im = cv2.resize(im, (IMAGE_SIZE, IMAGE_SIZE))
40 | t = transform(im)
41 | res = net(t)[0]
42 | im = postprocess(res)
43 | cv2.imshow('webcam', im)
44 | if cv2.waitKey(1) == 27: # press Esc to end
45 | break
46 |
47 | capture.release()
48 | cv2.destroyAllWindows()
49 |
50 |
51 | if __name__ == '__main__':
52 | main()
53 |
--------------------------------------------------------------------------------
/fnst_modules.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class Bottleneck(nn.Module):
6 | def __init__(self, inp_c, out_c, kernel_size, stride, t=1):
7 | assert stride in [1, 2], 'stride must be either 1 or 2'
8 | super().__init__()
9 | self.residual = stride == 1 and inp_c == out_c
10 | pad = kernel_size // 2
11 | self.reflection_pad = nn.ReflectionPad2d(pad)
12 | self.conv1 = nn.Conv2d(inp_c, t*inp_c, 1, 1, bias=False)
13 | self.in1 = nn.InstanceNorm2d(t*inp_c, affine=True)
14 | self.conv2 = nn.Conv2d(t*inp_c, t*inp_c, kernel_size, stride,
15 | groups=t*inp_c, bias=False)
16 | self.in2 = nn.InstanceNorm2d(t*inp_c, affine=True)
17 | self.conv3 = nn.Conv2d(t*inp_c, out_c, 1, 1, bias=False)
18 | self.in3 = nn.InstanceNorm2d(out_c, affine=True)
19 |
20 | def forward(self, x):
21 | out = F.relu6(self.in1(self.conv1(x)))
22 | out = self.reflection_pad(out)
23 | out = F.relu6(self.in2(self.conv2(out)))
24 | out = self.in3(self.conv3(out))
25 | if self.residual:
26 | out = x + out
27 | return out
28 |
29 |
30 | class UpsampleConv(nn.Module):
31 | def __init__(self, inp_c, out_c, kernel_size, stride, upsample=2):
32 | super().__init__()
33 | if upsample:
34 | self.upsample = nn.Upsample(mode='nearest', scale_factor=upsample)
35 | else:
36 | self.upsample = None
37 | self.conv1 = Bottleneck(inp_c, out_c, kernel_size, stride)
38 |
39 | def forward(self, x):
40 | x_in = x
41 | if self.upsample is not None:
42 | x_in = self.upsample(x_in)
43 | out = F.relu(self.conv1(x_in))
44 | return out
45 |
46 |
47 | class TransformerMobileNet(nn.Module):
48 | def __init__(self):
49 | super().__init__()
50 | # Conv Layers
51 | self.reflection_pad = nn.ReflectionPad2d(9//2)
52 | self.conv1 = nn.Conv2d(3, 32, kernel_size=9, stride=1, bias=False)
53 | self.in1 = nn.InstanceNorm2d(32, affine=True)
54 | self.conv2 = Bottleneck(32, 64, kernel_size=3, stride=2)
55 | self.conv3 = Bottleneck(64, 128, kernel_size=3, stride=2)
56 | # Residual Layers
57 | self.res1 = Bottleneck(128, 128, 3, 1)
58 | self.res2 = Bottleneck(128, 128, 3, 1)
59 | self.res3 = Bottleneck(128, 128, 3, 1)
60 | self.res4 = Bottleneck(128, 128, 3, 1)
61 | self.res5 = Bottleneck(128, 128, 3, 1)
62 | # Upsampling Layers
63 | self.upconv1 = UpsampleConv(128, 64, kernel_size=3, stride=1)
64 | self.upconv2 = UpsampleConv(64, 32, kernel_size=3, stride=1)
65 | self.conv4 = nn.Conv2d(32, 3, kernel_size=9, stride=1, bias=False)
66 |
67 | def forward(self, x):
68 | out = self.reflection_pad(x)
69 | out = F.relu(self.in1(self.conv1(out)))
70 | out = self.conv2(out)
71 | out = self.conv3(out)
72 | out = self.res1(out)
73 | out = self.res2(out)
74 | out = self.res3(out)
75 | out = self.res4(out)
76 | out = self.res5(out)
77 | out = self.upconv1(out)
78 | out = self.upconv2(out)
79 | out = self.conv4(self.reflection_pad(out))
80 | return out
81 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Fast neural style with MobileNetV2 bottleneck blocks
2 | This repository contains a PyTorch implementation of an algorithm for artistic style transfer. The implementation is based on the following papers and repositories:
3 |
4 | - [Perceptual Losses for Real-Time Style Transfer and Super-Resolution](https://arxiv.org/abs/1603.08155)
5 | - [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381)
6 | - [A Neural Algorithm of Artistic Style](https://arxiv.org/abs/1508.06576)
7 | - [PyTorch Fast Neural Style Example](https://github.com/pytorch/examples/tree/master/fast_neural_style)
8 |
9 | ## Main Differences from Other Implementations
10 |
11 | * Residual blocks and convolutions are replaced with MobileNetV2 bottleneck blocks, which make use of inverted residuals and depthwise separable convolutions.
12 |
13 | 
14 |
15 | The picture above shows the two types of MobileNetV2 bottleneck blocks: the left one is used instead of a residual block and the right one is used instead of a convolution layer. The purposes of this change are:
16 |
17 | - Decrease the number of trainable parameters of the transformer network from __~1.67m__ to __~0.23m__, and therefore the amount of memory used by the transformer network.
18 |
19 | - In theory this should give a good speedup during training and, more importantly, during inference (fast neural style should be as fast as possible). In practice the gain is smaller: this transformer architecture is only a bit faster than the original transformer network, mainly because depthwise convolutions are not implemented as efficiently on GPU as regular convolutions (on CPU the speedup is bigger, but still not drastic).
20 |
21 | * This implementation uses a feature extractor wrapper around a PyTorch module, which relies on PyTorch forward hooks to retrieve layer activations. With this extractor:
22 |
23 | - You don't need to write a new module wrapper to extract the desired features every time you want to use a new loss network. You just pass the model and the layer indexes to the feature extractor wrapper and it handles the extraction for you; see the sketch after this list. (__Note__: the wrapper flattens the input module/model, so you need to pass indexes into the flattened module, i.e. if the module/model is a composition of smaller modules, it is represented as a flat list of layers inside the wrapper.)
24 |
25 | - It makes the training process slightly faster.
26 |
27 | * The implementation allows you to use different weights for different style features, which leads to better visual results.
28 |
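Below is a minimal usage sketch of the wrapper from `feature_ext.py` (not one of the repo's scripts, just an illustration). It assumes torchvision's VGG16 as the loss network; `[3, 8, 15, 22]` are the positions of the relu1_2/relu2_2/relu3_3/relu4_3 activations in the flattened layer list (the same defaults `fnst.py` uses), and the random tensor only stands in for a real image batch:

```
import torch
from torchvision import models
from feature_ext import FeatureExtractor

# Loss network: the VGG16 convolutional part, frozen for inference
vgg = models.vgg16(pretrained=True).features.eval()

# Register forward hooks on the flattened layer indexes we care about
fx = FeatureExtractor(vgg, [3, 8, 15, 22])

x = torch.rand(1, 3, 128, 128)   # stand-in for a normalized image batch
vgg(x)                           # the forward pass fills the hooks
activations = fx.features        # one tensor per requested layer
print([a.shape for a in activations])

fx.remove_hooks()                # detach the hooks when done
```
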
29 | ## Requirements
30 | - [pytorch](https://pytorch.org) (>= 0.4.0)
31 | - [torchvision](https://pytorch.org)
32 | - [PIL](https://pillow.readthedocs.io/en/5.1.x/)
33 | - [OpenCV](https://opencv.org/) (for the webcam demo) and [tqdm](https://github.com/tqdm/tqdm) (for the training progress bar)
34 | - A GPU is not necessary
35 |
36 | ## Usage
37 | To train the transformer network:
38 | ```
39 | python fnst.py -train
40 | ```
41 | To stylize an image with a pretrained model:
42 | ```
43 | python fnst.py
44 | ```
45 |
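The same can be done from Python; here is a minimal sketch using the `Stylizer` class from `fnst.py` (the paths are only examples, any checkpoint from `models/` and any input image will do):

```
from fnst import Stylizer

# Pick a pretrained transformer checkpoint and an output directory
stl = Stylizer('models/candy.pth', 'images/results')

# Writes images/results/pwr_candy.jpg (input name + '_' + model name)
stl.stylize('images/pwr.jpg')
```
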
46 | All configurable parameters are stored as globals at the top of `fnst.py`, so to change them just edit `fnst.py` (I thought this was more convenient than adding a dozen command-line arguments).
47 |
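For example, to train the candy style you might edit globals like these (the values below are only an illustration; `TRAIN_PATH` should point to any `ImageFolder`-style dataset directory):

```
EPOCHS = 2
TRAIN_PATH = 'path/to/dataset'        # ImageFolder-style directory with training images
STYLE_IMAGE = 'images/candy.jpg'      # style to learn
IMAGE_SIZE = 128                      # training images are resized/cropped to this size
BATCH_SIZE = 4
STYLE_WEIGHT = 3 * 1e5                # overall style loss weight
STYLE_PROPORTIONS = [.35, .35, .15, .15]  # how STYLE_WEIGHT is split across the style layers
```
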
48 | There is also a webcam demo in the repo; to run it:
49 | ```
50 | python webcam.py
51 | ```
52 | `webcam.py` also has some globals at the top of the file which you can change.
53 |
54 |
55 | ## Examples
56 | All models were trained for 3 epochs on 128x128 [COCO Dataset](http://cocodataset.org) images (the small image size is due to the GTX 960m in my laptop).
57 |
58 | - Styles
59 |
60 |
61 | - Results
62 |
63 |
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
/fnst.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.optim import Adam
5 | from torch.utils.data import DataLoader
6 | from torchvision import datasets, transforms, models
7 |
8 | from PIL import Image
9 | from tqdm import tqdm
10 | import os
11 | import argparse
12 |
13 | from utils import load_im, save_im
14 | from utils import gram_matrix, norm_batch, regularization_loss
15 | from fnst_modules import TransformerMobileNet
16 | from feature_ext import FeatureExtractor
17 |
18 | EPOCHS = 2
19 | LOSS_NETWORK = models.vgg16
20 | TRAIN_PATH = '101_ObjectCategories'
21 | STYLE_IMAGE = 'images/mosaic.jpg'
22 | CHECK_IMAGE = 'images/dancing.jpg'
23 | IMAGE_SIZE = 128
24 | LAYER_IDXS = [3, 8, 15, 22]
25 | BATCH_SIZE = 4
26 | CONTENT_WEIGHT = 1
27 | STYLE_WEIGHT = 3 * 1e5
28 | REG_WEIGHT = 3 * 1e-5
29 | STYLE_PROPORTIONS = [.35, .35, .15, .15]
30 | CONTENT_INDEX = 1
31 | LOG_INTERVAL = 1000
32 | CHECKPOINT = 4000
33 |
34 | OUTPUT_PATH = 'images/results'
35 | INPUT_IMAGE = 'images/pwr.jpg'
36 | MODEL_PATH = 'models/mosaic.pth'
37 |
38 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39 |
40 |
41 | class Learner():
42 | def __init__(self, loss_network, train_path, style_image, check_image, im_size,
43 | layer_idxs, batch_size, c_weight, s_weight, r_weight,
44 | style_proportions, content_index, log_interval, checkpoint):
45 | assert len(style_proportions) == len(layer_idxs)
46 |
47 | # prepare dataset and loader
48 | self.dataset, self.loader = (
49 | self.__prepare_dataset(train_path, im_size, batch_size))
50 |
51 | # prepare transformer and classifier nets, optimizer
52 | self.tfm_net = TransformerMobileNet().to(device)
53 | self.loss_net = loss_network
54 | self.__prepare_loss_net()
55 | self.optimizer = Adam(self.tfm_net.parameters(), 1e-3)
56 |
57 | # prepare feature extractor, style image
58 | # and image for checking intermediate results
59 | self.content_index = content_index
60 | self.fx = FeatureExtractor(self.loss_net, layer_idxs)
61 | self.style_batch, self.style_target = (
62 | self.__prepare_style_target(style_image, batch_size))
63 | self.check_tensor = self.__prepare_check_tensor(check_image)
64 |
65 | # set weights for different losses
66 | self.content_weight = c_weight
67 | self.style_weights = [s_weight*x for x in style_proportions]
68 | self.reg_weight = r_weight
69 |
70 | # intervals
71 | self.log_intl = log_interval
72 | self.checkpoint = checkpoint
73 |
74 | def __prepare_paths(self):
75 | for p in ['models', 'images']:
76 | _path = os.path.join('tmp', p)
77 | os.makedirs(_path, exist_ok=True)
78 |
79 | def __prepare_dataset(self, train_path, im_size, batch_size):
80 | transform = transforms.Compose([
81 | transforms.Resize(im_size),
82 | transforms.CenterCrop(im_size),
83 | transforms.ToTensor(),
84 | transforms.Lambda(lambda x: x.mul(255))])
85 |
86 | ds = datasets.ImageFolder(train_path, transform)
87 | ld = DataLoader(ds, batch_size=batch_size)
88 | return ds, ld
89 |
90 | def __prepare_loss_net(self):
91 | for p in self.loss_net.parameters():
92 | p.requires_grad_(False)
93 | self.loss_net.to(device).eval()
94 |
95 | def __prepare_style_target(self, im_path, batch_size):
96 | transform = transforms.Compose([
97 | transforms.ToTensor(),
98 | transforms.Lambda(lambda x: x.mul(255))])
99 |
100 | style_im = load_im(im_path)
101 | style_tensor = transform(style_im)
102 | style_batch = style_tensor.repeat(batch_size, 1, 1, 1).to(device)
103 |
104 | self.loss_net(norm_batch(style_batch))
105 | style_target = tuple(gram_matrix(x)
106 | for x in self.fx.features)
107 | return style_batch, style_target
108 |
109 | def __prepare_check_tensor(self, im_path):
110 | transform = transforms.Compose([
111 | transforms.ToTensor(),
112 | transforms.Lambda(lambda x: x.mul(255))])
113 |
114 | if im_path is None:
115 |             res = self.dataset[0][0].unsqueeze(0).to(device)  # ImageFolder yields (image, label) pairs
116 | else:
117 | res = transform(Image.open(im_path)).unsqueeze(0).to(device)
118 | return res
119 |
120 | def train(self, epochs):
121 | self.__prepare_paths()
122 | for e in range(epochs):
123 | self.tfm_net.train()
124 | agg_content_loss = 0.
125 | agg_style_loss = 0.
126 | agg_reg_loss = 0.
127 | for i, (x, _) in enumerate(tqdm(self.loader, desc=f'Epoch {e}')):
128 | len_batch = len(x)
129 | self.optimizer.zero_grad()
130 |
131 | x = x.to(device)
132 | y = self.tfm_net(x)
133 | x = norm_batch(x)
134 | y = norm_batch(y)
135 |
136 | self.loss_net(y)
137 | style_y = tuple(gram_matrix(x) for x in self.fx.features)
138 | content_y = self.fx.features[self.content_index]
139 |
140 | self.loss_net(x)
141 | content_x = self.fx.features[self.content_index]
142 |
143 | content_loss = F.mse_loss(content_y, content_x)
144 | content_loss *= self.content_weight
145 |
146 | style_loss = 0.
147 | for gm_y, gm_t, w in zip(style_y, self.style_target,
148 | self.style_weights):
149 | style_loss += (w*F.mse_loss(gm_y, gm_t[:len_batch, :, :]))
150 |
151 | reg_loss = self.reg_weight * regularization_loss(y)
152 |
153 | total_loss = content_loss + style_loss + reg_loss
154 |
155 | total_loss.backward()
156 | self.optimizer.step()
157 |
158 | agg_content_loss += content_loss.item()
159 | agg_style_loss += style_loss.item()
160 | agg_reg_loss += reg_loss.item()
161 |
162 | if (i+1) % self.log_intl == 0:
163 | self.intermediate_res(agg_content_loss, agg_style_loss,
164 | agg_reg_loss, i+1)
165 |
166 | if (i+1) % self.checkpoint == 0:
167 | self.save_tfm_net(e+1, i+1)
168 |
169 | self.save_tfm_net(e+1, i+1)
170 |
171 | def intermediate_res(self, c_loss, s_loss, r_loss, n):
172 | self.tfm_net.eval()
173 | check = self.tfm_net(self.check_tensor)
174 | _path = os.path.join('tmp', 'images', f'check{n}.jpg')
175 | save_im(_path, check[0])
176 | self.tfm_net.train()
177 |
178 | msg = (f'\nbatch: {n}\t'
179 | f'content: {c_loss/n}\t'
180 | f'style: {s_loss/n}\t'
181 | f'reg: {r_loss/n}\t'
182 | f'total: {(c_loss + s_loss + r_loss)/n} \n')
183 |
184 | print(msg)
185 |
186 | def save_tfm_net(self, e, i):
187 | name = f'epoch{e}_batch{i}.pth'
188 | _path = os.path.join('tmp', 'models', name)
189 | torch.save(self.tfm_net.state_dict(), _path)
190 |
191 |
192 | class Stylizer():
193 | def __init__(self, model_path, output_path):
194 | self.output_path = output_path
195 | self.model_name = os.path.basename(model_path).split('.')[0]
196 | self.__load_net(model_path)
197 | self.net.to(device)
198 | self.transform = transforms.Compose([
199 | transforms.ToTensor(),
200 | transforms.Lambda(lambda x: x.mul(255).unsqueeze(0).to(device))])
201 |
202 | def __load_net(self, model_path):
203 | with torch.no_grad():
204 | self.net = TransformerMobileNet()
205 |             state_dict = torch.load(model_path, map_location=device)  # load on CPU or GPU
206 | self.net.load_state_dict(state_dict)
207 |
208 | def stylize(self, im_path):
209 | with torch.no_grad():
210 | self.net.eval()
211 | im = load_im(im_path)
212 | x = self.transform(im)
213 | out = self.net(x)
214 | _name = (os.path.basename(im_path).split('.')[0] + '_'
215 | + self.model_name + '.jpg')
216 | _path = os.path.join(self.output_path, _name)
217 | save_im(_path, out[0])
218 |
219 |
220 | parser = argparse.ArgumentParser()
221 | parser.add_argument('-train', action='store_true')
222 |
223 |
224 | def main():
225 | args = parser.parse_args()
226 | if args.train:
227 | loss_network = LOSS_NETWORK(True)
228 | loss_network = nn.Sequential(*list(loss_network.features)[:23]) # ToDo
229 | lrn = Learner(loss_network=loss_network,
230 | train_path=TRAIN_PATH, style_image=STYLE_IMAGE,
231 | check_image=CHECK_IMAGE, im_size=IMAGE_SIZE,
232 | layer_idxs=LAYER_IDXS, batch_size=BATCH_SIZE,
233 | c_weight=CONTENT_WEIGHT, s_weight=STYLE_WEIGHT,
234 | r_weight=REG_WEIGHT, style_proportions=STYLE_PROPORTIONS,
235 | content_index=CONTENT_INDEX, log_interval=LOG_INTERVAL,
236 | checkpoint=CHECKPOINT)
237 | lrn.train(EPOCHS)
238 | else:
239 | stl = Stylizer(MODEL_PATH, OUTPUT_PATH)
240 | stl.stylize(INPUT_IMAGE)
241 |
242 |
243 | if __name__ == '__main__':
244 | main()
245 |
--------------------------------------------------------------------------------