├── .github └── workflows │ └── publish.yml ├── FST_preview.PNG ├── FastStyleTransferNode.py ├── README.md ├── Style_Transfer_Workflow.json ├── __init__.py ├── dataset └── MC COCO train dataset goes here ├── models ├── bayanihan.pth ├── lazy.pth ├── mosaic.pth ├── starry.pth ├── tokyo_ghoul.pth ├── udnie.pth └── wave.pth ├── neural_style_transfer.py ├── output └── Intermediate pictures from training saved here ├── pyproject.toml ├── temp └── temp_files_go_here ├── train.py └── vgg └── vgg16-00b39a1b model goes here /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | permissions: 12 | issues: write 13 | 14 | jobs: 15 | publish-node: 16 | name: Publish Custom Node to registry 17 | runs-on: ubuntu-latest 18 | if: ${{ github.repository_owner == 'zeroxoxo' }} 19 | steps: 20 | - name: Check out code 21 | uses: actions/checkout@v4 22 | - name: Publish Custom Node 23 | uses: Comfy-Org/publish-node-action@v1 24 | with: 25 | ## Add your own personal access token to your Github Repository secrets and reference it here. 26 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 27 | -------------------------------------------------------------------------------- /FST_preview.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/FST_preview.PNG -------------------------------------------------------------------------------- /FastStyleTransferNode.py: -------------------------------------------------------------------------------- 1 | """ 2 | These nodes are a simple conversion of these repositories into ComfyUI ecosystem: 3 | https://github.com/rrmina/fast-neural-style-pytorch.git 4 | https://github.com/gordicaleksa/pytorch-neural-style-transfer.git 5 | 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import time 11 | import os 12 | import folder_paths 13 | import subprocess as sp 14 | import sys 15 | 16 | # ML classes 17 | class ConvolutionalLayer(nn.Module): 18 | def __init__(self, in_channels, out_channels, kernel_size, stride, norm="instance"): 19 | super(ConvolutionalLayer, self).__init__() 20 | # Padding Layers 21 | self.padding_size = kernel_size // 2 22 | self.reflection_pad = nn.ReflectionPad2d(self.padding_size) 23 | 24 | # Convolution Layer 25 | self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride) 26 | 27 | # Normalization Layers 28 | self.norm_type = norm 29 | if (norm=="instance"): 30 | self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True) 31 | elif (norm=="batch"): 32 | self.norm_layer = nn.BatchNorm2d(out_channels, affine=True) 33 | 34 | def forward(self, x): 35 | x = self.reflection_pad(x) 36 | x = self.conv_layer(x) 37 | if self.norm_type == "None": 38 | out = x 39 | else: 40 | out = self.norm_layer(x) 41 | return out 42 | 43 | 44 | class ResidualLayer(nn.Module): 45 | """ 46 | Deep Residual Learning for Image Recognition 47 | 48 | https://arxiv.org/abs/1512.03385 49 | """ 50 | def __init__(self, channels=128, kernel_size=3): 51 | super(ResidualLayer, self).__init__() 52 | self.conv1 = ConvolutionalLayer(channels, channels, kernel_size, stride=1) 53 | self.relu = nn.ReLU() 54 | self.conv2 = ConvolutionalLayer(channels, channels, kernel_size, stride=1) 55 | 56 | def forward(self, 
x): 57 | identity = x # preserve residual 58 | out = self.relu(self.conv1(x)) # 1st conv layer + activation 59 | out = self.conv2(out) # 2nd conv layer 60 | out = out + identity # add residual 61 | return out 62 | 63 | 64 | class DeconvolutionalLayer(nn.Module): 65 | def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding, norm="instance"): 66 | super(DeconvolutionalLayer, self).__init__() 67 | 68 | # Transposed Convolution 69 | padding_size = kernel_size // 2 70 | self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding_size, output_padding) 71 | 72 | # Normalization Layers 73 | self.norm_type = norm 74 | if (norm=="instance"): 75 | self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True) 76 | elif (norm=="batch"): 77 | self.norm_layer = nn.BatchNorm2d(out_channels, affine=True) 78 | 79 | def forward(self, x): 80 | x = self.conv_transpose(x) 81 | if (self.norm_type=="None"): 82 | out = x 83 | else: 84 | out = self.norm_layer(x) 85 | return out 86 | 87 | 88 | class TransformerNetwork(nn.Module): 89 | """Feedforward Transformation Network without Tanh 90 | reference: https://arxiv.org/abs/1603.08155 91 | exact architecture: https://cs.stanford.edu/people/jcjohns/papers/fast-style/fast-style-supp.pdf 92 | """ 93 | def __init__(self): 94 | super(TransformerNetwork, self).__init__() 95 | self.ConvBlock = nn.Sequential( 96 | ConvolutionalLayer(3, 32, 9, 1), 97 | nn.ReLU(), 98 | ConvolutionalLayer(32, 64, 3, 2), 99 | nn.ReLU(), 100 | ConvolutionalLayer(64, 128, 3, 2), 101 | nn.ReLU() 102 | ) 103 | self.ResidualBlock = nn.Sequential( 104 | ResidualLayer(128, 3), 105 | ResidualLayer(128, 3), 106 | ResidualLayer(128, 3), 107 | ResidualLayer(128, 3), 108 | ResidualLayer(128, 3) 109 | ) 110 | self.DeconvBlock = nn.Sequential( 111 | DeconvolutionalLayer(128, 64, 3, 2, 1), 112 | nn.ReLU(), 113 | DeconvolutionalLayer(64, 32, 3, 2, 1), 114 | nn.ReLU(), 115 | ConvolutionalLayer(32, 3, 9, 1, norm="None") 116 | ) 117 | 118 | def forward(self, x): 119 | x = self.ConvBlock(x) 120 | x = self.ResidualBlock(x) 121 | out = self.DeconvBlock(x) 122 | return out 123 | 124 | 125 | # Node classes 126 | class TrainFastStyleTransfer: 127 | def __init__(self): 128 | pass 129 | 130 | @classmethod 131 | def INPUT_TYPES(s): 132 | return { 133 | "required": { 134 | "style_img": ("IMAGE",), 135 | "seed": ("INT", {"default": 30, "min": 0, "max": 999999, "step": 1,}), 136 | "content_weight": ("INT", {"default": 14, "min": 1, "max": 128, "step": 1,}), 137 | "style_weight": ("INT", {"default": 50, "min": 1, "max": 128, "step": 1,}), 138 | "tv_weight": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step": 0.0000001}), 139 | "batch_size": ("INT", {"default": 4, "min": 1, "max": 32, "step": 1,}), 140 | "train_img_size": ("INT", {"default": 256, "min": 128, "max": 2048, "step": 1,}), 141 | "learning_rate": ("FLOAT", {"default": 0.001, "min": 0.0001, "max": 100.0, "step": 0.0001}), 142 | "num_epochs": ("INT", {"default": 1, "min": 1, "max": 20, "step": 1,}), 143 | "save_model_every": ("INT", {"default": 500, "min": 10, "max": 10000, "step": 10,}), 144 | "from_pretrained": ("INT", {"default": 0, "min": 0, "max": 1, "step": 1,}), 145 | "model": ([file for file in os.listdir(os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/models/")) if file.endswith('.pth')], ), 146 | }, 147 | } 148 | 149 | RETURN_TYPES = () 150 | OUTPUT_NODE = True 151 | 152 | FUNCTION = "train" 153 | 154 | CATEGORY = "Style Transfer" 155 | 156 | def 
encode_tensor(self, tensor): 157 | tensor = tensor.permute(0, 3, 1, 2).contiguous() # Convert to [batch_size, channels, height, width] 158 | return tensor[:, [2, 1, 0], :, :] * 255 159 | 160 | 161 | def train(self, style_img, seed, batch_size, train_img_size, learning_rate, num_epochs, content_weight, style_weight, tv_weight, save_model_every, from_pretrained, model): 162 | temp_save_style_img = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/temp/") + "temp_save_content_img.pt" 163 | save_model_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/models/") 164 | dataset_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/dataset/") 165 | vgg_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/vgg/vgg16-00b39a1b.pth") 166 | save_image_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/output/") 167 | train_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/train.py") 168 | 169 | 170 | 171 | command = [ 172 | sys.executable, train_path, 173 | '--train_image_size', str(train_img_size), 174 | '--dataset_path', dataset_path, 175 | '--vgg_path', vgg_path, 176 | '--num_epochs', str(num_epochs), 177 | '--temp_save_style_img', temp_save_style_img, 178 | '--batch_size', str(batch_size), 179 | '--content_weight', str(content_weight), 180 | '--style_weight', str(style_weight), 181 | '--tv_weight', str(tv_weight), 182 | '--adam_lr', str(learning_rate), 183 | '--save_model_path', save_model_path, 184 | '--save_image_path', save_image_path, 185 | '--save_model_every', str(save_model_every), 186 | '--seed', str(seed), 187 | '--pretrained_model' 188 | ] 189 | 190 | if from_pretrained: 191 | command.append(model) 192 | else: 193 | command.append('none') 194 | 195 | 196 | torch.save(self.encode_tensor(style_img), temp_save_style_img) 197 | 198 | sp.run(command) 199 | return () 200 | 201 | 202 | class FastStyleTransfer: 203 | def __init__(self): 204 | pass 205 | 206 | @classmethod 207 | def INPUT_TYPES(s): 208 | return { 209 | "required": { 210 | "content_img": ("IMAGE",), 211 | "model": ([file for file in os.listdir(os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/models/")) if file.endswith('.pth')], ), 212 | }, 213 | } 214 | 215 | RETURN_TYPES = ("IMAGE",) 216 | 217 | FUNCTION = "styleTransfer" 218 | 219 | CATEGORY = "Style Transfer" 220 | 221 | 222 | def encode_tensor(self, tensor): 223 | tensor = tensor.permute(0, 3, 1, 2).contiguous() # Convert to [batch_size, channels, height, width] 224 | return tensor[:, [2, 1, 0], :, :] * 255 225 | 226 | def decode_tensor(self, tensor): 227 | tensor = tensor[:, [2, 1, 0], :, :] 228 | tensor = tensor.permute(0, 2, 3, 1).contiguous() # Convert to [batch_size, height, width, channels] 229 | return tensor / 255 230 | 231 | def styleTransfer(self, content_img, model): 232 | # Device 233 | device = ("cuda" if torch.cuda.is_available() else "cpu") 234 | 235 | # Load Transformer Network 236 | net = TransformerNetwork().to(device) 237 | model_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/models/") + model 238 | net.load_state_dict(torch.load(model_path, map_location=device)) 239 | net = net.to(device) 240 | 241 | with torch.no_grad(): 242 | torch.cuda.empty_cache() 243 | starttime = time.time() 244 | content_tensor = self.encode_tensor(content_img) 245 | generated_tensor = net(content_tensor.to(device)) 246 
| print("Transfer Time: {}".format(time.time() - starttime)) 247 | image = self.decode_tensor(generated_tensor) 248 | return (image,) 249 | 250 | 251 | class NeuralStyleTransfer: 252 | 253 | 254 | 255 | @classmethod 256 | def INPUT_TYPES(s): 257 | return { 258 | "required": { 259 | "content_img": ("IMAGE",), 260 | "style_img": ("IMAGE",), 261 | "content_weight": ("FLOAT", {"default": 1e5, "min": 1e3, "max": 1e6, "step": 1e3}), 262 | "style_weight": ("FLOAT", {"default": 3e4, "min": 1e1, "max": 1e5, "step": 1e1}), 263 | "tv_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1e1, "step": 0.1}), 264 | "num_steps": ("INT", {"default": 100, "min": 10, "max": 10000, "step": 10}), 265 | "learning_rate": ("FLOAT", {"default": 1.0, "min": 1e-4, "max": 1e3, "step": 0.1}), 266 | }, 267 | } 268 | 269 | RETURN_TYPES = ("IMAGE",) 270 | FUNCTION = "neural_style_transfer" 271 | CATEGORY = "Style Transfer" 272 | 273 | def encode_tensor(self, tensor): 274 | tensor = tensor.permute(0, 3, 1, 2).contiguous() # Convert to [batch_size, channels, height, width] 275 | return tensor * 255 276 | 277 | def decode_tensor(self, tensor): 278 | tensor = tensor.permute(0, 2, 3, 1).contiguous() # Convert to [batch_size, height, width, channels] 279 | return tensor / 255 280 | 281 | 282 | def neural_style_transfer(self, content_img, style_img, content_weight, style_weight, tv_weight, num_steps, learning_rate): 283 | 284 | neural_style_transfer_path = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/neural_style_transfer.py") 285 | 286 | temp_save_content_img = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/temp/") + "temp_save_content_img.pt" 287 | temp_save_style_img = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/temp/") + "temp_save_style_img.pt" 288 | 289 | temp_load_final_img = os.path.join(folder_paths.base_path, "custom_nodes/ComfyUI-Fast-Style-Transfer/temp/") + "temp_load_final_img.pt" 290 | 291 | torch.save(self.encode_tensor(content_img), temp_save_content_img) 292 | torch.save(self.encode_tensor(style_img), temp_save_style_img) 293 | 294 | 295 | command = [ 296 | sys.executable, neural_style_transfer_path, 297 | '--content_weight', str(content_weight), 298 | '--style_weight', str(style_weight), 299 | '--tv_weight', str(tv_weight), 300 | '--temp_save_style_img', temp_save_style_img, 301 | '--temp_save_content_img', temp_save_content_img, 302 | '--temp_load_final_img', temp_load_final_img, 303 | '--num_steps', str(num_steps), 304 | '--learning_rate', str(learning_rate) 305 | ] 306 | 307 | sp.run(command) 308 | 309 | image = self.decode_tensor(torch.load(temp_load_final_img)) 310 | os.remove(temp_save_style_img) 311 | os.remove(temp_save_content_img) 312 | os.remove(temp_load_final_img) 313 | return (image,) 314 | 315 | # A dictionary that contains all nodes you want to export with their names 316 | # NOTE: names should be globally unique 317 | NODE_CLASS_MAPPINGS = { 318 | "FastStyleTransfer": FastStyleTransfer, 319 | "TrainFastStyleTransfer": TrainFastStyleTransfer, 320 | "NeuralStyleTransfer": NeuralStyleTransfer, 321 | } 322 | 323 | # A dictionary that contains the friendly/humanly readable titles for the nodes 324 | NODE_DISPLAY_NAME_MAPPINGS = { 325 | "FastStyleTransfer": "Fast Style Transfer", 326 | "TrainFastStyleTransfer": "Train Fast Style Transfer", 327 | "NeuralStyleTransfer": "Neural Style Transfer", 328 | } 329 | -------------------------------------------------------------------------------- /README.md: 
--------------------------------------------------------------------------------
1 | # ComfyUI-Fast-Style-Transfer
2 | ComfyUI node for fast neural style transfer.
3 |
4 | This is a simple conversion of:
5 | https://github.com/rrmina/fast-neural-style-pytorch
6 |
7 | [Experimental]
8 | Regular (optimization-based) neural style transfer is also ported from here:
9 | https://github.com/gordicaleksa/pytorch-neural-style-transfer
10 | It's much slower and not as useful, but you can play with it if you want.
11 |
12 | ![alt text](https://github.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/blob/main/FST_preview.PNG?raw=true)
13 |
14 | # Installation
15 |
16 | The usual: run "git clone https://github.com/zeroxoxo/ComfyUI-Fast-Style-Transfer.git" inside your custom_nodes folder. That should be it.
17 |
18 | If it doesn't work, I can't really know what's wrong with your particular setup; ask on Stack Exchange or in the ComfyUI community.
19 | I use the portable setup of ComfyUI, so if something breaks, try it with the portable version.
20 |
21 | # Training
22 |
23 | First you'll need to download some files:
24 |
25 | VGG-16: https://github.com/jcjohnson/pytorch-vgg
26 |
27 | Put it into the vgg folder.
28 |
29 |
30 | MS COCO train dataset:
31 |
32 | The original repo suggests the train-2014 dataset from here: https://cocodataset.org/#download
33 |
34 | But be warned that it's 13 GB.
35 |
36 | I used the MS COCO train-2017 dataset downscaled to 256x256 from here: https://academictorrents.com/details/eea5a532dd69de7ff93d5d9c579eac55a41cb700
37 |
38 | It's only 1.64 GB, and the original repo trained on 256x256 images anyway; it just downscaled them itself from the 13 GB dataset.
39 |
40 | Put the train-2017 (or train-2014) folder into the dataset folder.
41 |
42 |
43 | That's it for downloads.
44 |
45 | Now load the TrainFastStyleTransfer node in ComfyUI.
46 |
47 | To select the style picture, add a "LoadImage" node and connect it to the TrainFastStyleTransfer node.
48 |
49 | The default content_weight, style_weight and tv_weight are good starting points. Increase style_weight if you need more style. tv_weight affects the sharpness of style features; it needs experimenting with, but it is very useful for controlling how the style is applied to the image.
50 |
51 | Raising batch_size as high as your VRAM allows doesn't seem to do much, so just use the default of 4 with a train_img_size of 256.
52 |
53 | You probably won't need to wait for a whole epoch either: train until the total loss stops getting reliably lower and just fluctuates around the same ballpark.
54 |
55 | Use one of the pretrained models as a starting point; it reduces training time drastically.
56 |
57 | save_model_every saves the model and produces a test picture every n-th training step.
58 |
59 | After setting all the parameters, just queue the prompt; you don't need to wait for training to finish. Set save_model_every to a low value like 100 or 200 and look at the pictures it produces (intermediate pictures are saved in the output folder). Starting from a pretrained model should produce a good enough model in less than 2000 training steps. As soon as you're happy with the result, just close the training script.
60 |
61 | All intermediate models are saved in the models folder: test them, delete the redundant ones and rename the one you like.
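
# Using a trained model outside ComfyUI

The snippet below is a minimal sketch, not part of the node, showing the same forward pass the FastStyleTransfer node runs, in case you want to apply a trained .pth checkpoint from plain PyTorch. The paths ("/path/to/ComfyUI", "input.png", "models/mosaic.pth") are placeholders; it assumes you run it from inside this repo's folder with your ComfyUI root on sys.path, because FastStyleTransferNode.py imports ComfyUI's folder_paths module.

```python
import sys
sys.path.append("/path/to/ComfyUI")  # placeholder: lets the node module import ComfyUI's folder_paths

import cv2
import torch
from FastStyleTransferNode import TransformerNetwork

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the feedforward transformer with a trained checkpoint from the models folder.
net = TransformerNetwork().to(device)
net.load_state_dict(torch.load("models/mosaic.pth", map_location=device))
net.eval()

# cv2 loads images as BGR uint8; the transformer expects BGR floats in the 0-255 range,
# shaped [batch, channels, height, width] (the same layout encode_tensor produces).
img = cv2.imread("input.png").astype("float32")
content = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0).to(device)

with torch.no_grad():
    styled = net(content).squeeze(0).clamp(0, 255)

# Back to [height, width, channels]; cv2.imwrite also expects BGR.
cv2.imwrite("styled.png", styled.permute(1, 2, 0).cpu().numpy().astype("uint8"))
```

Inside ComfyUI you don't need any of this: the Fast Style Transfer node does this conversion and forward pass for you (see Style_Transfer_Workflow.json for a ready-made LoadImage -> FastStyleTransfer -> SaveImage graph).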
62 | -------------------------------------------------------------------------------- /Style_Transfer_Workflow.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 23, 3 | "last_link_id": 60, 4 | "nodes": [ 5 | { 6 | "id": 20, 7 | "type": "LoadImage", 8 | "pos": [ 9 | -102, 10 | 237 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 314 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "IMAGE", 22 | "type": "IMAGE", 23 | "links": [ 24 | 59 25 | ], 26 | "shape": 3, 27 | "slot_index": 0 28 | }, 29 | { 30 | "name": "MASK", 31 | "type": "MASK", 32 | "links": null, 33 | "shape": 3 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "LoadImage" 38 | }, 39 | "widgets_values": [ 40 | "ComfyUI_00131_ (1).png", 41 | "image" 42 | ] 43 | }, 44 | { 45 | "id": 5, 46 | "type": "SaveImage", 47 | "pos": [ 48 | 620, 49 | 273 50 | ], 51 | "size": { 52 | "0": 315, 53 | "1": 270 54 | }, 55 | "flags": {}, 56 | "order": 2, 57 | "mode": 0, 58 | "inputs": [ 59 | { 60 | "name": "images", 61 | "type": "IMAGE", 62 | "link": 60 63 | } 64 | ], 65 | "properties": {}, 66 | "widgets_values": [ 67 | "ComfyUI" 68 | ] 69 | }, 70 | { 71 | "id": 23, 72 | "type": "FastStyleTransfer", 73 | "pos": [ 74 | 263, 75 | 184 76 | ], 77 | "size": { 78 | "0": 315, 79 | "1": 58 80 | }, 81 | "flags": {}, 82 | "order": 1, 83 | "mode": 0, 84 | "inputs": [ 85 | { 86 | "name": "content_img", 87 | "type": "IMAGE", 88 | "link": 59 89 | } 90 | ], 91 | "outputs": [ 92 | { 93 | "name": "IMAGE", 94 | "type": "IMAGE", 95 | "links": [ 96 | 60 97 | ], 98 | "shape": 3, 99 | "slot_index": 0 100 | } 101 | ], 102 | "properties": { 103 | "Node name for S&R": "FastStyleTransfer" 104 | }, 105 | "widgets_values": [ 106 | "mosaic.pth" 107 | ] 108 | } 109 | ], 110 | "links": [ 111 | [ 112 | 59, 113 | 20, 114 | 0, 115 | 23, 116 | 0, 117 | "IMAGE" 118 | ], 119 | [ 120 | 60, 121 | 23, 122 | 0, 123 | 5, 124 | 0, 125 | "IMAGE" 126 | ] 127 | ], 128 | "groups": [], 129 | "config": {}, 130 | "extra": { 131 | "ds": { 132 | "scale": 1.0152559799477097, 133 | "offset": [ 134 | 255.66797552675814, 135 | -34.13997175527328 136 | ] 137 | }, 138 | "groupNodes": {} 139 | }, 140 | "version": 0.4 141 | } -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .FastStyleTransferNode import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS 2 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS'] -------------------------------------------------------------------------------- /dataset/MC COCO train dataset goes here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/dataset/MC COCO train dataset goes here -------------------------------------------------------------------------------- /models/bayanihan.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/models/bayanihan.pth -------------------------------------------------------------------------------- /models/lazy.pth: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/models/lazy.pth -------------------------------------------------------------------------------- /models/mosaic.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/models/mosaic.pth -------------------------------------------------------------------------------- /models/starry.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/models/starry.pth -------------------------------------------------------------------------------- /models/tokyo_ghoul.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/models/tokyo_ghoul.pth -------------------------------------------------------------------------------- /models/udnie.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/models/udnie.pth -------------------------------------------------------------------------------- /models/wave.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/models/wave.pth -------------------------------------------------------------------------------- /neural_style_transfer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.optim import AdamW, LBFGS 3 | from torchvision import models 4 | import torch.nn as nn 5 | import argparse 6 | from collections import namedtuple 7 | 8 | class Vgg16(torch.nn.Module): 9 | """Only those layers are exposed which have already proven to work nicely.""" 10 | def __init__(self, requires_grad=False, show_progress=False): 11 | super().__init__() 12 | vgg_pretrained_features = models.vgg16(pretrained=True, progress=show_progress).features 13 | self.layer_names = ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3'] 14 | self.content_feature_maps_index = 1 # relu2_2 15 | self.style_feature_maps_indices = list(range(len(self.layer_names))) # all layers used for style representation 16 | 17 | self.slice1 = torch.nn.Sequential() 18 | self.slice2 = torch.nn.Sequential() 19 | self.slice3 = torch.nn.Sequential() 20 | self.slice4 = torch.nn.Sequential() 21 | for x in range(4): 22 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(4, 9): 24 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(9, 16): 26 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 27 | for x in range(16, 23): 28 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 29 | if not requires_grad: 30 | for param in self.parameters(): 31 | param.requires_grad = False 32 | 33 | def forward(self, x): 34 | x = self.slice1(x) 35 | relu1_2 = x 36 | x = self.slice2(x) 37 | relu2_2 = x 38 | x = self.slice3(x) 39 | relu3_3 = x 40 | x = self.slice4(x) 41 | relu4_3 = x 42 | vgg_outputs = namedtuple("VggOutputs", self.layer_names) 43 | out = vgg_outputs(relu1_2, relu2_2, relu3_3, relu4_3) 44 | 
return out 45 | 46 | 47 | def prepare_model(device): 48 | model = Vgg16() 49 | 50 | content_feature_maps_index = model.content_feature_maps_index 51 | style_feature_maps_indices = model.style_feature_maps_indices 52 | layer_names = model.layer_names 53 | 54 | content_fms_index_name = (content_feature_maps_index, layer_names[content_feature_maps_index]) 55 | style_fms_indices_names = (style_feature_maps_indices, layer_names) 56 | return model.to(device).eval(), content_fms_index_name, style_fms_indices_names 57 | 58 | 59 | def gram_matrix(x, should_normalize=True): 60 | (b, ch, h, w) = x.size() 61 | features = x.view(b, ch, w * h) 62 | features_t = features.transpose(1, 2) 63 | gram = features.bmm(features_t) 64 | if should_normalize: 65 | gram /= ch * h * w 66 | return gram 67 | 68 | 69 | def total_variation(y): 70 | return torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + \ 71 | torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])) 72 | 73 | 74 | def build_loss(neural_net, optimizing_img, target_representations, content_feature_maps_index, style_feature_maps_indices, config): 75 | target_content_representation = target_representations[0] 76 | target_style_representation = target_representations[1] 77 | 78 | current_set_of_feature_maps = neural_net(optimizing_img) 79 | 80 | current_content_representation = current_set_of_feature_maps[content_feature_maps_index].squeeze(axis=0) 81 | content_loss = torch.nn.MSELoss(reduction='mean')(target_content_representation, current_content_representation) 82 | 83 | style_loss = 0.0 84 | current_style_representation = [gram_matrix(x) for cnt, x in enumerate(current_set_of_feature_maps) if cnt in style_feature_maps_indices] 85 | for gram_gt, gram_hat in zip(target_style_representation, current_style_representation): 86 | style_loss += torch.nn.MSELoss(reduction='sum')(gram_gt[0], gram_hat[0]) 87 | style_loss /= len(target_style_representation) 88 | 89 | tv_loss = total_variation(optimizing_img) 90 | 91 | total_loss = config['content_weight'] * content_loss + config['style_weight'] * style_loss + config['tv_weight'] * tv_loss 92 | 93 | return total_loss, content_loss, style_loss, tv_loss 94 | 95 | def make_tuning_step(neural_net, optimizer, target_representations, content_feature_maps_index, style_feature_maps_indices, config): 96 | def tuning_step(optimizing_img): 97 | optimizer.zero_grad() 98 | total_loss, content_loss, style_loss, tv_loss = build_loss(neural_net, optimizing_img, target_representations, content_feature_maps_index, style_feature_maps_indices, config) 99 | total_loss.backward() 100 | optimizer.step() 101 | return total_loss, content_loss, style_loss, tv_loss 102 | return tuning_step 103 | 104 | def neural_style_transfer_from_tensors(content_img_tensor, style_img_tensor, config, num_steps, learning_rate): 105 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 106 | 107 | content_img = content_img_tensor.to(device) 108 | style_img = style_img_tensor.to(device) 109 | 110 | optimizing_img = content_img.clone().requires_grad_(True) 111 | 112 | neural_net, content_feature_maps_index_name, style_feature_maps_indices_names = prepare_model(device) 113 | 114 | content_img_set_of_feature_maps = neural_net(content_img) 115 | style_img_set_of_feature_maps = neural_net(style_img) 116 | 117 | target_content_representation = content_img_set_of_feature_maps[content_feature_maps_index_name[0]].squeeze(axis=0) 118 | target_style_representation = [gram_matrix(x) for cnt, x in enumerate(style_img_set_of_feature_maps) if cnt in 
style_feature_maps_indices_names[0]] 119 | target_representations = [target_content_representation, target_style_representation] 120 | 121 | optimizer = LBFGS((optimizing_img,), max_iter=num_steps, lr=learning_rate, line_search_fn='strong_wolfe') 122 | cnt = 0 123 | 124 | def closure(): 125 | nonlocal cnt 126 | if torch.is_grad_enabled(): 127 | optimizer.zero_grad() 128 | total_loss, content_loss, style_loss, tv_loss = build_loss(neural_net, optimizing_img, target_representations, content_feature_maps_index_name[0], style_feature_maps_indices_names[0], config) 129 | if total_loss.requires_grad: 130 | total_loss.backward() 131 | if cnt%100==0: 132 | with torch.no_grad(): 133 | print(f'L-BFGS | iteration: {cnt:03}, total loss={total_loss.item():12.4f}, content_loss={config["content_weight"] * content_loss.item():12.4f}, style loss={config["style_weight"] * style_loss.item():12.4f}, tv loss={config["tv_weight"] * tv_loss.item():12.4f}') 134 | 135 | cnt += 1 136 | return total_loss 137 | 138 | optimizer.step(closure) 139 | 140 | return optimizing_img.detach() 141 | 142 | 143 | if __name__ == "__main__": 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--temp_save_style_img", type=str) 146 | parser.add_argument("--temp_save_content_img", type=str) 147 | parser.add_argument("--temp_load_final_img", type=str) 148 | 149 | parser.add_argument("--content_weight", type=float, help="weight factor for content loss", default=1e5) 150 | parser.add_argument("--style_weight", type=float, help="weight factor for style loss", default=3e4) 151 | parser.add_argument("--tv_weight", type=float, help="weight factor for total variation loss", default=1e0) 152 | parser.add_argument("--learning_rate", type=float, help="learning_rate", default=1e1) 153 | parser.add_argument("--num_steps", type=int, help="number of training steps", default=100) 154 | 155 | args = parser.parse_args() 156 | config = { 157 | "content_weight": args.content_weight, 158 | "style_weight": args.style_weight, 159 | "tv_weight": args.tv_weight, 160 | } 161 | style_img = torch.load(args.temp_save_style_img) 162 | content_img = torch.load(args.temp_save_content_img) 163 | final_img = neural_style_transfer_from_tensors(content_img, style_img, config, args.num_steps, args.learning_rate) 164 | print("nst done") 165 | 166 | torch.save(final_img, args.temp_load_final_img) 167 | 168 | -------------------------------------------------------------------------------- /output/Intermediate pictures from training saved here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/output/Intermediate pictures from training saved here -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-fast-style-transfer" 3 | description = "ComfyUI node for fast neural style transfer. 
This is a simple conversion based on this: [a/https://github.com/rrmina/fast-neural-style-pytorch](https://github.com/rrmina/fast-neural-style-pytorch)" 4 | version = "1.0.2" 5 | license = "LICENSE" 6 | 7 | [project.urls] 8 | Repository = "https://github.com/zeroxoxo/ComfyUI-Fast-Style-Transfer" 9 | # Used by Comfy Registry https://comfyregistry.org 10 | 11 | [tool.comfy] 12 | PublisherId = "zeroxoxo" 13 | DisplayName = "ComfyUI-Fast-Style-Transfer" 14 | Icon = "" 15 | -------------------------------------------------------------------------------- /temp/temp_files_go_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/temp/temp_files_go_here -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | from torchvision import datasets, transforms, models 5 | import random 6 | import numpy as np 7 | import time 8 | import argparse 9 | import cv2 10 | 11 | parser = argparse.ArgumentParser(description="Train a neural network with style transfer.") 12 | 13 | parser.add_argument('--train_image_size', type=int, default=256, help='Size of the training images') 14 | parser.add_argument('--dataset_path', type=str, default="dataset", help='Path to the dataset') 15 | parser.add_argument('--vgg_path', type=str, default="vgg", help='Path to the vgg model') 16 | parser.add_argument('--num_epochs', type=int, default=1, help='Number of epochs for training') 17 | parser.add_argument('--temp_save_style_img', type=str, default="temp/temp_save_style_img.pt", help='Path to the style image tensor') 18 | parser.add_argument('--batch_size', type=int, default=12, help='Batch size for training') 19 | parser.add_argument('--content_weight', type=float, default=8, help='Weight for content loss') 20 | parser.add_argument('--style_weight', type=float, default=50, help='Weight for style loss') 21 | parser.add_argument('--tv_weight', type=float, default=0.001, help='Weight for total variation loss') 22 | parser.add_argument('--adam_lr', type=float, default=0.001, help='Learning rate for Adam optimizer') 23 | parser.add_argument('--save_model_path', type=str, default="models/oil/", help='Path to save the trained model') 24 | parser.add_argument('--save_image_path', type=str, default="images/out/", help='Path to save the output images') 25 | parser.add_argument('--save_model_every', type=int, default=200, help='Save model every n batches') 26 | parser.add_argument('--seed', type=int, default=1234, help='Random seed') 27 | parser.add_argument('--pretrained_model', type=str, default='none', help='pretrained model') 28 | 29 | args = parser.parse_args() 30 | 31 | # GLOBAL SETTINGS 32 | TRAIN_IMAGE_SIZE = args.train_image_size 33 | DATASET_PATH = args.dataset_path 34 | NUM_EPOCHS = args.num_epochs 35 | STYLE_IMAGE_PATH = args.temp_save_style_img 36 | VGG_PATH = args.vgg_path 37 | BATCH_SIZE = args.batch_size 38 | CONTENT_WEIGHT = args.content_weight 39 | STYLE_WEIGHT = args.style_weight 40 | TV_WEIGHT = args.tv_weight 41 | ADAM_LR = args.adam_lr 42 | SAVE_MODEL_PATH = args.save_model_path 43 | SAVE_IMAGE_PATH = args.save_image_path 44 | SAVE_MODEL_EVERY = args.save_model_every 45 | SEED = args.seed 46 | MODEL = args.pretrained_model 47 | 48 | # Utils 49 | # Gram Matrix 50 | def gram(tensor): 51 | 
B, C, H, W = tensor.shape 52 | x = tensor.view(B, C, H*W) 53 | x_t = x.transpose(1, 2) 54 | gram_matrix = torch.bmm(x, x_t) / (C*H*W) 55 | gram_matrix = torch.clamp(gram_matrix, min=1e-6, max=1e6) 56 | return gram_matrix 57 | 58 | # Save image 59 | def saveimg(img, image_path): 60 | img = img.clip(0, 255) 61 | cv2.imwrite(image_path, img) 62 | 63 | # Preprocessing ~ Tensor to Image 64 | def ttoi(tensor): 65 | # Add the means 66 | #ttoi_t = transforms.Compose([ 67 | # transforms.Normalize([-103.939, -116.779, -123.68],[1,1,1])]) 68 | 69 | # Remove the batch_size dimension 70 | tensor = tensor.squeeze() 71 | #img = ttoi_t(tensor) 72 | img = tensor.cpu().numpy() 73 | 74 | # Transpose from [C, H, W] -> [H, W, C] 75 | img = img.transpose(1, 2, 0) 76 | return img 77 | 78 | # Alternative loss functions experiments 79 | def total_variation(y): 80 | return torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + \ 81 | torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])) 82 | 83 | # VGG 84 | class VGG16(nn.Module): 85 | def __init__(self, vgg_path="models/vgg16-00b39a1b.pth"): 86 | super(VGG16, self).__init__() 87 | # Load VGG Skeleton, Pretrained Weights 88 | vgg16_features = models.vgg16(pretrained=False) 89 | vgg16_features.load_state_dict(torch.load(vgg_path), strict=False) 90 | self.features = vgg16_features.features 91 | 92 | # Turn-off Gradient History 93 | for param in self.features.parameters(): 94 | param.requires_grad = False 95 | 96 | def forward(self, x): 97 | layers = {'3': 'relu1_2', '8': 'relu2_2', '15': 'relu3_3', '22': 'relu4_3'} 98 | features = {} 99 | for name, layer in self.features._modules.items(): 100 | x = layer(x) 101 | if name in layers: 102 | features[layers[name]] = x 103 | if (name=='22'): 104 | break 105 | 106 | return features 107 | 108 | # Transformer 109 | class TransformerNetworkClass(nn.Module): 110 | """Feedforward Transformation Network without Tanh 111 | reference: https://arxiv.org/abs/1603.08155 112 | exact architecture: https://cs.stanford.edu/people/jcjohns/papers/fast-style/fast-style-supp.pdf 113 | """ 114 | def __init__(self): 115 | super(TransformerNetworkClass, self).__init__() 116 | self.ConvBlock = nn.Sequential( 117 | ConvolutionalLayer(3, 32, 9, 1), 118 | nn.ReLU(), 119 | ConvolutionalLayer(32, 64, 3, 2), 120 | nn.ReLU(), 121 | ConvolutionalLayer(64, 128, 3, 2), 122 | nn.ReLU() 123 | ) 124 | self.ResidualBlock = nn.Sequential( 125 | ResidualLayer(128, 3), 126 | ResidualLayer(128, 3), 127 | ResidualLayer(128, 3), 128 | ResidualLayer(128, 3), 129 | ResidualLayer(128, 3) 130 | ) 131 | self.DeconvBlock = nn.Sequential( 132 | DeconvolutionalLayer(128, 64, 3, 2, 1), 133 | nn.ReLU(), 134 | DeconvolutionalLayer(64, 32, 3, 2, 1), 135 | nn.ReLU(), 136 | ConvolutionalLayer(32, 3, 9, 1, norm="None") 137 | ) 138 | 139 | def forward(self, x): 140 | x = self.ConvBlock(x) 141 | x = self.ResidualBlock(x) 142 | out = self.DeconvBlock(x) 143 | return out 144 | 145 | class ConvolutionalLayer(nn.Module): 146 | def __init__(self, in_channels, out_channels, kernel_size, stride, norm="instance"): 147 | super(ConvolutionalLayer, self).__init__() 148 | # Padding Layers 149 | padding_size = kernel_size // 2 150 | self.reflection_pad = nn.ReflectionPad2d(padding_size) 151 | 152 | # Convolution Layer 153 | self.conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size, stride) 154 | 155 | # Normalization Layers 156 | self.norm_type = norm 157 | 158 | if (norm=="instance"): 159 | self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True) 160 | elif (norm=="batch"): 
161 | self.norm_layer = nn.BatchNorm2d(out_channels, affine=True) 162 | 163 | def forward(self, x): 164 | x = self.reflection_pad(x) 165 | x = self.conv_layer(x) 166 | if (self.norm_type=="None"): 167 | out = x 168 | else: 169 | out = self.norm_layer(x) 170 | return out 171 | 172 | class ResidualLayer(nn.Module): 173 | """ 174 | Deep Residual Learning for Image Recognition 175 | 176 | https://arxiv.org/abs/1512.03385 177 | """ 178 | def __init__(self, channels=128, kernel_size=3): 179 | super(ResidualLayer, self).__init__() 180 | self.conv1 = ConvolutionalLayer(channels, channels, kernel_size, stride=1) 181 | self.relu = nn.ReLU() 182 | self.conv2 = ConvolutionalLayer(channels, channels, kernel_size, stride=1) 183 | 184 | def forward(self, x): 185 | identity = x # preserve residual 186 | out = self.relu(self.conv1(x)) # 1st conv layer + activation 187 | out = self.conv2(out) # 2nd conv layer 188 | out = out + identity # add residual 189 | return out 190 | 191 | class DeconvolutionalLayer(nn.Module): 192 | def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding, norm="instance"): 193 | super(DeconvolutionalLayer, self).__init__() 194 | 195 | # Transposed Convolution 196 | padding_size = kernel_size // 2 197 | self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding_size, output_padding) 198 | 199 | # Normalization Layers 200 | self.norm_type = norm 201 | 202 | if (norm=="instance"): 203 | self.norm_layer = nn.InstanceNorm2d(out_channels, affine=True) 204 | elif (norm=="batch"): 205 | self.norm_layer = nn.BatchNorm2d(out_channels, affine=True) 206 | 207 | def forward(self, x): 208 | x = self.conv_transpose(x) 209 | if (self.norm_type=="None"): 210 | out = x 211 | else: 212 | out = self.norm_layer(x) 213 | return out 214 | 215 | 216 | def train(): 217 | 218 | # Seeds 219 | torch.manual_seed(SEED) 220 | torch.cuda.manual_seed(SEED) 221 | np.random.seed(SEED) 222 | random.seed(SEED) 223 | 224 | # Device 225 | device = ("cuda" if torch.cuda.is_available() else "cpu") 226 | print(f"Train.py: Device is {device}") 227 | 228 | # Dataset and Dataloader 229 | transform = transforms.Compose([ 230 | transforms.Resize(TRAIN_IMAGE_SIZE), 231 | transforms.CenterCrop(TRAIN_IMAGE_SIZE), 232 | transforms.ToTensor(), 233 | transforms.Lambda(lambda x: x.mul(255)) 234 | ]) 235 | train_dataset = datasets.ImageFolder(DATASET_PATH, transform=transform) 236 | train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True) 237 | print(f"Train.py: train_loader length: {len(train_loader)}") 238 | 239 | # Load networks 240 | TransformerNetwork = TransformerNetworkClass().to(device) 241 | if MODEL != 'none': 242 | model_path = SAVE_MODEL_PATH + MODEL 243 | print(f"Loading {model_path} model") 244 | TransformerNetwork.load_state_dict(torch.load(model_path, map_location=device)) 245 | TransformerNetwork.to(device) 246 | VGG = VGG16(vgg_path=VGG_PATH).to(device) 247 | 248 | # Get Style Features 249 | imagenet_neg_mean = torch.tensor([-103.939, -116.779, -123.68], dtype=torch.float32).reshape(1,3,1,1).to(device) 250 | style_tensor = torch.load(STYLE_IMAGE_PATH).to(device) 251 | style_tensor = style_tensor.add(imagenet_neg_mean) 252 | B, C, H, W = style_tensor.shape 253 | style_features = VGG(style_tensor.expand([BATCH_SIZE, C, H, W])) 254 | style_gram = {} 255 | for key, value in style_features.items(): 256 | style_gram[key] = gram(value) 257 | 258 | # Optimizer settings 259 | optimizer = 
optim.AdamW(TransformerNetwork.parameters(), lr=ADAM_LR, fused=True, amsgrad=True) 260 | scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20, T_mult=2, eta_min=1e-6, last_epoch=-1) 261 | # Loss trackers 262 | content_loss_history = [] 263 | style_loss_history = [] 264 | total_loss_history = [] 265 | batch_content_loss_sum = 0 266 | batch_style_loss_sum = 0 267 | batch_total_loss_sum = 0 268 | # Optimization/Training Loop 269 | start_time = time.time() 270 | for epoch in range(NUM_EPOCHS): 271 | print("========Epoch {}/{}========".format(epoch+1, NUM_EPOCHS)) 272 | for batch_count, (content_batch, _) in enumerate(train_loader): 273 | optimizer.zero_grad() 274 | torch.cuda.empty_cache() 275 | # Generate images and get features 276 | content_batch = content_batch[:,[2,1,0]].to(device) 277 | generated_batch = TransformerNetwork(content_batch) 278 | content_features = VGG(content_batch.add(imagenet_neg_mean)) 279 | generated_features = VGG(generated_batch.add(imagenet_neg_mean)) 280 | 281 | # Content Loss 282 | MSELoss = nn.MSELoss().to(device) 283 | MAELoss = nn.L1Loss().to(device) 284 | content_loss = CONTENT_WEIGHT * MSELoss(generated_features['relu2_2'], content_features['relu2_2']) + MAELoss(generated_features['relu2_2'], content_features['relu2_2']) / 2 285 | batch_content_loss_sum += content_loss 286 | 287 | # Style Loss 288 | style_loss = 0.0 289 | for key, value in generated_features.items(): 290 | style_loss += MSELoss(gram(value), style_gram[key][:content_batch.shape[0]]) 291 | style_loss += MAELoss(gram(value), style_gram[key][:content_batch.shape[0]]) 292 | style_loss *= STYLE_WEIGHT/2 293 | batch_style_loss_sum += style_loss.item() 294 | 295 | # TV loss 296 | tv_loss = total_variation(generated_batch) * TV_WEIGHT 297 | 298 | # Total Loss 299 | total_loss = content_loss + style_loss + tv_loss 300 | batch_total_loss_sum += total_loss.item() 301 | 302 | total_loss.backward() 303 | optimizer.step() 304 | if batch_count % 50 == 0: 305 | scheduler.step() 306 | 307 | with torch.no_grad(): 308 | print(f'AdamW | iteration: {batch_count+1:03}, total loss={total_loss.item():12.4f}, content_loss={content_loss.item():12.4f}, style loss={style_loss.item():12.4f}, tv loss={tv_loss.item():12.4f}') 309 | 310 | # Save Model and Print Losses 311 | if (((batch_count)%SAVE_MODEL_EVERY == 0) or ((batch_count+1)==NUM_EPOCHS*len(train_loader))): 312 | # Print Losses 313 | print("========Iteration {}/{}========".format(batch_count+1, NUM_EPOCHS*len(train_loader))) 314 | print("\tContent Loss:\t{:.2f}".format(batch_content_loss_sum/(batch_count+1))) 315 | print("\tStyle Loss:\t{:.2f}".format(batch_style_loss_sum/(batch_count+1))) 316 | print("\tTotal Loss:\t{:.2f}".format(batch_total_loss_sum/(batch_count+1))) 317 | print("Time elapsed:\t{} seconds".format(time.time()-start_time)) 318 | 319 | # Save Model 320 | checkpoint_path = SAVE_MODEL_PATH + "checkpoint_" + str((batch_count+1)) + ".pth" 321 | torch.save(TransformerNetwork.state_dict(), checkpoint_path) 322 | print("Saved TransformerNetwork checkpoint file at {}".format(checkpoint_path)) 323 | 324 | # Save sample generated image 325 | sample_tensor = generated_batch[0].clone().detach().unsqueeze(dim=0) 326 | sample_image = ttoi(sample_tensor.clone().detach()) 327 | sample_image_path = SAVE_IMAGE_PATH + "sample0_" + str((batch_count+1)) + ".png" 328 | saveimg(sample_image, sample_image_path) 329 | print("Saved sample tranformed image at {}".format(sample_image_path)) 330 | 331 | # Save loss histories 332 | 
content_loss_history.append(batch_content_loss_sum/(batch_count+1)) 333 | style_loss_history.append(batch_style_loss_sum/(batch_count+1)) 334 | total_loss_history.append(batch_total_loss_sum/(batch_count+1)) 335 | 336 | 337 | 338 | 339 | 340 | 341 | stop_time = time.time() 342 | # Print loss histories 343 | print("Done Training the Transformer Network!") 344 | print("Training Time: {} seconds".format(stop_time-start_time)) 345 | print("========Content Loss========") 346 | print(content_loss_history) 347 | print("========Style Loss========") 348 | print(style_loss_history) 349 | print("========Total Loss========") 350 | print(total_loss_history) 351 | 352 | 353 | if __name__ == "__main__": 354 | train() 355 | 356 | -------------------------------------------------------------------------------- /vgg/vgg16-00b39a1b model goes here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zeroxoxo/ComfyUI-Fast-Style-Transfer/20363dc232300487bbbf736b3b493c06d7d84e23/vgg/vgg16-00b39a1b model goes here --------------------------------------------------------------------------------