├── models └── .gitkeep ├── requirements.txt ├── examples ├── cat.jpg ├── popova.jpg └── vangogh.jpg ├── workflows ├── wf_aespa.png ├── wf_cast.png ├── wf_efdm.png ├── wf_tssat.png ├── wf_microast.png ├── wf_unist_image.png ├── wf_aesfa_w_blend.png ├── wf_neural_neighbor.png ├── extra │ ├── extra_coral_color_transfer.png │ └── extra_coral_color_transfer.json ├── workflow_cast.json ├── workflow_aespa.json ├── workflow_microast.json ├── workflow_unist_image.json ├── workflow_tssat.json ├── workflow_neural_neighbor.json ├── workflow_unist_video.json ├── workflow_efdm.json └── workflow_aesfa_w_blend.json ├── example_outputs ├── aesfa.png ├── aespa.png ├── cast.png ├── efdm.png ├── tssat.png ├── microast.png ├── unist_image.png ├── neural_neighbor.png └── example_unist_video.png ├── pyproject.toml ├── module_efdm ├── sampler.py ├── efdm_model.py ├── function.py └── net.py ├── .github ├── workflows │ └── publish.yml └── FUNDING.yml ├── module_tssat ├── sampler.py └── tssat_model.py ├── constants.py ├── module_unist ├── transformer_modules.py ├── transformer_layers.py ├── unist_model.py ├── transformer_subLayers.py └── models.py ├── LICENSE ├── __init__.py ├── module_cast ├── cast_model.py └── net.py ├── module_extra └── function.py ├── module_neural_neighbor ├── imagePyramid.py ├── neural_neighbor_model.py ├── distance.py ├── vgg.py ├── zca.py ├── featureExtract.py └── colorization.py ├── module_aesfa ├── vgg19.py ├── model.py ├── aesfa_model.py ├── networks.py └── blocks.py ├── run_extra.py ├── module_microast ├── microast_model.py └── function.py ├── .gitignore ├── module_aespa ├── aespa_model.py └── utils.py └── README.md /models/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | scikit-image 4 | einops -------------------------------------------------------------------------------- /examples/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/examples/cat.jpg -------------------------------------------------------------------------------- /examples/popova.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/examples/popova.jpg -------------------------------------------------------------------------------- /examples/vangogh.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/examples/vangogh.jpg -------------------------------------------------------------------------------- /workflows/wf_aespa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/workflows/wf_aespa.png -------------------------------------------------------------------------------- /workflows/wf_cast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/workflows/wf_cast.png -------------------------------------------------------------------------------- /workflows/wf_efdm.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/workflows/wf_efdm.png -------------------------------------------------------------------------------- /workflows/wf_tssat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/workflows/wf_tssat.png -------------------------------------------------------------------------------- /example_outputs/aesfa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/example_outputs/aesfa.png -------------------------------------------------------------------------------- /example_outputs/aespa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/example_outputs/aespa.png -------------------------------------------------------------------------------- /example_outputs/cast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/example_outputs/cast.png -------------------------------------------------------------------------------- /example_outputs/efdm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/example_outputs/efdm.png -------------------------------------------------------------------------------- /example_outputs/tssat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/example_outputs/tssat.png -------------------------------------------------------------------------------- /workflows/wf_microast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/workflows/wf_microast.png -------------------------------------------------------------------------------- /example_outputs/microast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/example_outputs/microast.png -------------------------------------------------------------------------------- /workflows/wf_unist_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/workflows/wf_unist_image.png -------------------------------------------------------------------------------- /workflows/wf_aesfa_w_blend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/workflows/wf_aesfa_w_blend.png -------------------------------------------------------------------------------- /example_outputs/unist_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/example_outputs/unist_image.png -------------------------------------------------------------------------------- /workflows/wf_neural_neighbor.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/workflows/wf_neural_neighbor.png -------------------------------------------------------------------------------- /example_outputs/neural_neighbor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/example_outputs/neural_neighbor.png -------------------------------------------------------------------------------- /example_outputs/example_unist_video.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/example_outputs/example_unist_video.png -------------------------------------------------------------------------------- /workflows/extra/extra_coral_color_transfer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FuouM/ComfyUI-StyleTransferPlus/HEAD/workflows/extra/extra_coral_color_transfer.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comfyui-styletransferplus" 3 | description = "Nodes: Neural Neighbor, CAST, EFDM, MicroAST, UniST, AesPA, TSSAT, AesFA, ..." 4 | version = "1.0.3" 5 | license = {file = "LICENSE"} 6 | dependencies = ["torch", "torchvision", "scikit-image", "einops"] 7 | 8 | [project.urls] 9 | Repository = "https://github.com/FuouM/ComfyUI-StyleTransferPlus" 10 | # Used by Comfy Registry https://comfyregistry.org 11 | 12 | [tool.comfy] 13 | PublisherId = "fuoumarinas" 14 | DisplayName = "ComfyUI-StyleTransferPlus" 15 | Icon = "" 16 | -------------------------------------------------------------------------------- /module_efdm/sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils import data 3 | 4 | 5 | def InfiniteSampler(n): 6 | # i = 0 7 | i = n - 1 8 | order = np.random.permutation(n) 9 | while True: 10 | yield order[i] 11 | i += 1 12 | if i >= n: 13 | np.random.seed() 14 | order = np.random.permutation(n) 15 | i = 0 16 | 17 | 18 | class InfiniteSamplerWrapper(data.sampler.Sampler): 19 | def __init__(self, data_source): 20 | self.num_samples = len(data_source) 21 | 22 | def __iter__(self): 23 | return iter(InfiniteSampler(self.num_samples)) 24 | 25 | def __len__(self): 26 | return 2 ** 31 27 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to Comfy registry 2 | on: 3 | workflow_dispatch: 4 | push: 5 | branches: 6 | - main 7 | - master 8 | paths: 9 | - "pyproject.toml" 10 | 11 | permissions: 12 | issues: write 13 | 14 | jobs: 15 | publish-node: 16 | name: Publish Custom Node to registry 17 | runs-on: ubuntu-latest 18 | if: ${{ github.repository_owner == 'FuouM' }} 19 | steps: 20 | - name: Check out code 21 | uses: actions/checkout@v4 22 | - name: Publish Custom Node 23 | uses: Comfy-Org/publish-node-action@v1 24 | with: 25 | ## Add your own personal access token to your Github Repository secrets and reference it here. 
26 | personal_access_token: ${{ secrets.REGISTRY_ACCESS_TOKEN }} 27 | -------------------------------------------------------------------------------- /module_tssat/sampler.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Aug 1 11:38:11 2020 4 | 5 | @author: ZJU 6 | """ 7 | 8 | import numpy as np 9 | from torch.utils import data 10 | 11 | 12 | def InfiniteSampler(n): 13 | # i = 0 14 | i = n - 1 15 | order = np.random.permutation(n) 16 | while True: 17 | yield order[i] 18 | i += 1 19 | if i >= n: 20 | np.random.seed() 21 | order = np.random.permutation(n) 22 | i = 0 23 | 24 | 25 | class InfiniteSamplerWrapper(data.sampler.Sampler): 26 | def __init__(self, data_source): 27 | self.num_samples = len(data_source) 28 | 29 | def __iter__(self): 30 | return iter(InfiniteSampler(self.num_samples)) 31 | 32 | def __len__(self): 33 | return 2**31 34 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: # Replace with a single Patreon username 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: fuoumarinas # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 12 | polar: # Replace with a single Polar username 13 | buy_me_a_coffee: # Replace with a single Buy Me a Coffee username 14 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 15 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | CAST_TYPES = ["CAST", "UCAST"] 2 | CAST_DEFAULT = "CAST" 3 | 4 | CAST_NET_AE_PATH = "latest_net_AE.pth" 5 | CAST_NET_DEC_B_PATH = "latest_net_Dec_B.pth" 6 | CAST_VGG_PATH = "models/vgg_normalised.pth" 7 | 8 | EFDM_STYLE_TYPES = ["adain", "adamean", "adastd", "efdm", "hm"] 9 | EFDM_DEFAULT = "efdm" 10 | EFDM_PATH = "models/hm_decoder_iter_160000.pth" 11 | 12 | MICROAST_CONTENT_ENCODER_PATH = "models/microast/content_encoder_iter_160000.pth.tar" 13 | MICROAST_STYLE_ENCODER_PATH = "models/microast/style_encoder_iter_160000.pth.tar" 14 | MICROAST_DECODER_PATH = "models/microast/decoder_iter_160000.pth.tar" 15 | MICROAST_MODULATOR_PATH = "models/microast/modulator_iter_160000.pth.tar" 16 | 17 | UNIST_PATH = "models/unist/UniST_model.pt" 18 | UNIST_DEC_PATH = "models/unist/dec_r41.pth" 19 | UNIST_ENC_PATH = "models/unist/vgg_r41.pth" 20 | 21 | VGG_NORM_CONV51_PATH = "models/vgg_normalised_conv5_1.pth" 22 | AESPA_DEC_PATH = "models/aespa/dec_model.pth" 23 | AESPA_TRANSFORMER_PATH = "models/aespa/transformer_model.pth" 24 | 25 | TSSAT_DEC_PATH = "models/tssat/decoder_iter_160000.pth" 26 | 27 | AESFA_PATH = "models/aesfa/main.pth" -------------------------------------------------------------------------------- /module_unist/transformer_modules.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | __author__ = "Yu-Hsiang Huang" 8 | 9 | 10 | class ScaledDotProductAttention(nn.Module): 11 | """Scaled Dot-Product Attention""" 12 | 13 | def __init__(self, temperature, attn_dropout=0.1): 14 | super().__init__() 15 | self.temperature = temperature 16 | self.dropout = nn.Dropout(attn_dropout) 17 | 18 | def forward(self, q, k, v, mask=None): 19 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 20 | 21 | if mask is not None: 22 | attn = attn.masked_fill(mask == 0, -1e9) 23 | 24 | attn = self.dropout(F.softmax(attn, dim=-1)) 25 | output = torch.matmul(attn, v) 26 | 27 | return output, attn 28 | 29 | 30 | class Attention(nn.Module): 31 | def forward(self, query, key, value): 32 | scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1)) 33 | 34 | p_attn = F.softmax(scores, dim=-1) 35 | p_val = torch.matmul(p_attn, value) 36 | return p_val, p_attn 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Fuou Marinas 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | from .run import ( 2 | AesFA, 3 | AesFAStyleBlend, 4 | CAST, 5 | EFDM, 6 | TSSAT, 7 | AesPA, 8 | MicroAST, 9 | NeuralNeighbor, 10 | UniST, 11 | UniST_Video, 12 | ) 13 | from .run_extra import CoralColorTransfer 14 | 15 | NODE_CLASS_MAPPINGS = { 16 | "NeuralNeighbor": NeuralNeighbor, 17 | "CAST": CAST, 18 | "EFDM": EFDM, 19 | "MicroAST": MicroAST, 20 | "CoralColorTransfer": CoralColorTransfer, 21 | "UniST": UniST, 22 | "UniST_Video": UniST_Video, 23 | "AesPA": AesPA, 24 | "TSSAT": TSSAT, 25 | "AESFA": AesFA, 26 | "AesFAStyleBlend": AesFAStyleBlend, 27 | } 28 | 29 | NODE_DISPLAY_NAME_MAPPINGS = { 30 | "NeuralNeighbor": "Neural Neighbor", 31 | "CAST": "CAST", 32 | "EFDM": "EFDM", 33 | "MicroAST": "MicroAST", 34 | "CoralColorTransfer": "Coral Color Transfer", 35 | "UniST": "UniST", 36 | "UniST_Video": "UniST Video", 37 | "AesPA": "AesPA-Net", 38 | "TSSAT": "TSSAT", 39 | "AESFA": "AESFA", 40 | "AesFAStyleBlend": "AesFA Styles Blending", 41 | } 42 | 43 | 44 | __all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"] 45 | -------------------------------------------------------------------------------- /module_cast/cast_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . import net 4 | 5 | 6 | class MODEL_CAST: 7 | def __init__( 8 | self, vgg_path: str, net_ae_path: str, net_dec_b_path: str, device 9 | ) -> None: 10 | vgg = net.vgg 11 | vgg.load_state_dict(torch.load(vgg_path)) 12 | vgg = torch.nn.Sequential(*list(vgg.children())[:31]) 13 | 14 | self.netAE = net.ADAIN_Encoder(vgg) 15 | self.netDec_B = net.Decoder() 16 | 17 | self.netAE.load_state_dict(load_a_ckpt(net_ae_path)) 18 | self.netDec_B.load_state_dict(load_a_ckpt(net_dec_b_path)) 19 | 20 | self.netAE.to(device).eval() 21 | self.netDec_B.to(device).eval() 22 | 23 | 24 | def inference_ucast( 25 | src_img: torch.Tensor, 26 | style_img: torch.Tensor, 27 | netAE, 28 | netDec_B, 29 | device, 30 | ): 31 | real_A = src_img.permute(0, 3, 1, 2).to(device) 32 | real_B = style_img.permute(0, 3, 1, 2).to(device) 33 | 34 | with torch.no_grad(): 35 | real_A_feat = netAE.forward(real_A, real_B) 36 | fake_B = netDec_B.forward(real_A_feat) 37 | 38 | if device.type != "cpu": 39 | torch.cuda.empty_cache() 40 | 41 | return fake_B 42 | 43 | 44 | def load_a_ckpt(ckpt_path: str): 45 | state_dict = torch.load(ckpt_path) 46 | if hasattr(state_dict, "_metadata"): 47 | del state_dict._metadata 48 | return state_dict 49 | -------------------------------------------------------------------------------- /module_extra/function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def coral(source: torch.Tensor, target: torch.Tensor): 4 | # assume both source and target are 3D array (C, H, W) 5 | # Note: flatten -> f 6 | 7 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) 8 | source_f_norm = ( 9 | source_f - source_f_mean.expand_as(source_f) 10 | ) / source_f_std.expand_as(source_f) 11 | source_f_cov_eye = torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) 12 | 13 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) 14 | target_f_norm = ( 15 | target_f - target_f_mean.expand_as(target_f) 16 | ) / target_f_std.expand_as(target_f) 17 | target_f_cov_eye = torch.mm(target_f_norm, target_f_norm.t()) + 
torch.eye(3) 18 | 19 | source_f_norm_transfer = torch.mm( 20 | _mat_sqrt(target_f_cov_eye), 21 | torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), source_f_norm), 22 | ) 23 | 24 | source_f_transfer = source_f_norm_transfer * target_f_std.expand_as( 25 | source_f_norm 26 | ) + target_f_mean.expand_as(source_f_norm) 27 | 28 | return source_f_transfer.view(source.size()) 29 | 30 | def _calc_feat_flatten_mean_std(feat): 31 | # takes 3D feat (C, H, W), return mean and std of array within channels 32 | assert feat.size()[0] == 3 33 | assert isinstance(feat, torch.FloatTensor) 34 | feat_flatten = feat.view(3, -1) 35 | mean = feat_flatten.mean(dim=-1, keepdim=True) 36 | std = feat_flatten.std(dim=-1, keepdim=True) 37 | return feat_flatten, mean, std 38 | 39 | def _mat_sqrt(x): 40 | U, D, V = torch.svd(x) 41 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) -------------------------------------------------------------------------------- /module_neural_neighbor/imagePyramid.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def dec_lap_pyr(x, levs): 5 | """constructs batch of 'levs' level laplacian pyramids from x 6 | Inputs: 7 | x -- BxCxHxW pytorch tensor 8 | levs -- integer number of pyramid levels to construct 9 | Outputs: 10 | pyr -- a list of pytorch tensors, each representing a pyramid level, 11 | pyr[0] contains the finest level, pyr[-1] the coarsest 12 | """ 13 | pyr = [] 14 | cur = x # Initialize approx. coefficients with original image 15 | for i in range(levs): 16 | # Construct and store detail coefficients from current approx. coefficients 17 | h = cur.size(2) 18 | w = cur.size(3) 19 | x_small = F.interpolate(cur, (h // 2, w // 2), mode="bilinear") 20 | x_back = F.interpolate(x_small, (h, w), mode="bilinear") 21 | lap = cur - x_back 22 | pyr.append(lap) 23 | 24 | # Store new approx. coefficients 25 | cur = x_small 26 | 27 | pyr.append(cur) 28 | 29 | return pyr 30 | 31 | 32 | def syn_lap_pyr(pyr): 33 | """collapse batch of laplacian pyramids stored in list of pytorch tensors 34 | 'pyr' into a single tensor. 35 | Inputs: 36 | pyr -- list of pytorch tensors, where pyr[i] has size BxCx(H/(2**i)x(W/(2**i)) 37 | Outpus: 38 | x -- a BxCxHxW pytorch tensor 39 | """ 40 | cur = pyr[-1] 41 | levs = len(pyr) 42 | 43 | for i in range(0, levs - 1)[::-1]: 44 | # Create new approximation coefficients from current approx. 
and detail coefficients 45 | # at next finest pyramid level 46 | up_x = pyr[i].size(2) 47 | up_y = pyr[i].size(3) 48 | cur = pyr[i] + F.interpolate(cur, (up_x, up_y), mode="bilinear") 49 | x = cur 50 | 51 | return x 52 | -------------------------------------------------------------------------------- /module_neural_neighbor/neural_neighbor_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .stylize import produce_stylization 5 | from .vgg import Vgg16Pretrained 6 | 7 | 8 | def resize(src: torch.Tensor, target_size: int, scale_long=True): 9 | # [B, C, H, W] 10 | src_h, src_w = src.shape[2:] 11 | if scale_long: 12 | factor = target_size / max(src_h, src_w) 13 | else: 14 | factor = target_size / min(src_h, src_w) 15 | 16 | new_h = int(src_h * factor) 17 | new_w = int(src_w * factor) 18 | 19 | new_x = F.interpolate(src, (new_h, new_w), mode="bilinear", align_corners=True) 20 | 21 | return new_x 22 | 23 | 24 | def inference_neural_neighbor( 25 | src_img: torch.Tensor, 26 | style_img: torch.Tensor, 27 | device, 28 | size=512, 29 | scale_long=True, 30 | flip=False, 31 | content_loss=False, 32 | colorize=True, 33 | alpha=0.75, 34 | max_iter=200, 35 | ): 36 | # [B, H, W, C] 37 | content_im_orig = resize( 38 | src_img.permute(0, 3, 1, 2).contiguous(), size, scale_long 39 | ).to(device) 40 | style_im_orig = resize( 41 | style_img.permute(0, 3, 1, 2).contiguous(), size, scale_long 42 | ).to(device) 43 | content_weight = 1 - alpha 44 | 45 | max_scls = 4 if size == 512 else 5 46 | 47 | cnn = Vgg16Pretrained().to(device) 48 | 49 | def phi(x, y, z): 50 | return cnn.forward(x, inds=y, concat=z) 51 | 52 | torch.cuda.synchronize() 53 | output = produce_stylization( 54 | content_im_orig, 55 | style_im_orig, 56 | phi, 57 | max_iter=max_iter, 58 | lr=2e-3, 59 | content_weight=content_weight, 60 | max_scls=max_scls, 61 | flip_aug=flip, 62 | content_loss=content_loss, 63 | dont_colorize=not colorize, 64 | device=device, 65 | ) 66 | torch.cuda.synchronize() 67 | 68 | output = torch.clip(output, 0, 1).permute(0, 2, 3, 1) 69 | 70 | return output 71 | -------------------------------------------------------------------------------- /module_aesfa/vgg19.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | vgg = nn.Sequential( 4 | nn.Conv2d(3, 3, (1, 1)), 5 | nn.ReflectionPad2d((1, 1, 1, 1)), 6 | nn.Conv2d(3, 64, (3, 3)), 7 | nn.ReLU(), # relu1-1 8 | nn.ReflectionPad2d((1, 1, 1, 1)), 9 | nn.Conv2d(64, 64, (3, 3)), 10 | nn.ReLU(), # relu1-2 11 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 12 | nn.ReflectionPad2d((1, 1, 1, 1)), 13 | nn.Conv2d(64, 128, (3, 3)), 14 | nn.ReLU(), # relu2-1 15 | nn.ReflectionPad2d((1, 1, 1, 1)), 16 | nn.Conv2d(128, 128, (3, 3)), 17 | nn.ReLU(), # relu2-2 18 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 19 | nn.ReflectionPad2d((1, 1, 1, 1)), 20 | nn.Conv2d(128, 256, (3, 3)), 21 | nn.ReLU(), # relu3-1 22 | nn.ReflectionPad2d((1, 1, 1, 1)), 23 | nn.Conv2d(256, 256, (3, 3)), 24 | nn.ReLU(), # relu3-2 25 | nn.ReflectionPad2d((1, 1, 1, 1)), 26 | nn.Conv2d(256, 256, (3, 3)), 27 | nn.ReLU(), # relu3-3 28 | nn.ReflectionPad2d((1, 1, 1, 1)), 29 | nn.Conv2d(256, 256, (3, 3)), 30 | nn.ReLU(), # relu3-4 31 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 32 | nn.ReflectionPad2d((1, 1, 1, 1)), 33 | nn.Conv2d(256, 512, (3, 3)), 34 | nn.ReLU(), # relu4-1, this is the last layer used 35 | nn.ReflectionPad2d((1, 1, 
1, 1)), 36 | nn.Conv2d(512, 512, (3, 3)), 37 | nn.ReLU(), # relu4-2 38 | nn.ReflectionPad2d((1, 1, 1, 1)), 39 | nn.Conv2d(512, 512, (3, 3)), 40 | nn.ReLU(), # relu4-3 41 | nn.ReflectionPad2d((1, 1, 1, 1)), 42 | nn.Conv2d(512, 512, (3, 3)), 43 | nn.ReLU(), # relu4-4 44 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 45 | nn.ReflectionPad2d((1, 1, 1, 1)), 46 | nn.Conv2d(512, 512, (3, 3)), 47 | nn.ReLU(), # relu5-1 48 | nn.ReflectionPad2d((1, 1, 1, 1)), 49 | nn.Conv2d(512, 512, (3, 3)), 50 | nn.ReLU(), # relu5-2 51 | nn.ReflectionPad2d((1, 1, 1, 1)), 52 | nn.Conv2d(512, 512, (3, 3)), 53 | nn.ReLU(), # relu5-3 54 | nn.ReflectionPad2d((1, 1, 1, 1)), 55 | nn.Conv2d(512, 512, (3, 3)), 56 | nn.ReLU(), # relu5-4 57 | ) 58 | 59 | -------------------------------------------------------------------------------- /module_aesfa/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from . import networks 5 | 6 | 7 | class AesFA_test(nn.Module): 8 | def __init__( 9 | self, alpha_in, alpha_out, style_kernel, input_nc, nf, output_nc, freq_ratio 10 | ): 11 | super(AesFA_test, self).__init__() 12 | 13 | self.netE = networks.define_network( 14 | "Encoder", 15 | alpha_in, 16 | alpha_out, 17 | style_kernel, 18 | input_nc, 19 | nf, 20 | output_nc, 21 | freq_ratio, 22 | ) 23 | self.netS = networks.define_network( 24 | "Encoder", 25 | alpha_in, 26 | alpha_out, 27 | style_kernel, 28 | input_nc, 29 | nf, 30 | output_nc, 31 | freq_ratio, 32 | ) 33 | self.netG = networks.define_network( 34 | "Generator", 35 | alpha_in, 36 | alpha_out, 37 | style_kernel, 38 | input_nc, 39 | nf, 40 | output_nc, 41 | freq_ratio, 42 | ) 43 | 44 | def forward(self, real_A, real_B): 45 | with torch.no_grad(): 46 | content_A = self.netE.forward_test(real_A, "content") 47 | style_B = self.netS.forward_test(real_B, "style") 48 | # if freq: 49 | # trs_AtoB, trs_AtoB_high, trs_AtoB_low = self.netG(content_A, style_B) 50 | # end = time.time() 51 | # during = end - start 52 | # return trs_AtoB, trs_AtoB_high, trs_AtoB_low, during 53 | # else: 54 | trs_AtoB = self.netG.forward_test(content_A, style_B) 55 | return trs_AtoB 56 | 57 | def style_blending(self, real_A, real_B_1, real_B_2): 58 | with torch.no_grad(): 59 | content_A = self.netE.forward_test(real_A, "content") 60 | style_B1_h = self.netS.forward_test(real_B_1, "style")[0] 61 | style_B2_l = self.netS.forward_test(real_B_2, "style")[1] 62 | style_B = style_B1_h, style_B2_l 63 | 64 | trs_AtoB = self.netG.forward_test(content_A, style_B) 65 | 66 | return trs_AtoB 67 | -------------------------------------------------------------------------------- /module_unist/transformer_layers.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .transformer_subLayers import MultiHeadAttention, PositionwiseFeedForward 4 | 5 | 6 | class EncoderLayer(nn.Module): 7 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 8 | super(EncoderLayer, self).__init__() 9 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 10 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 11 | 12 | def forward(self, enc_input, slf_attn_mask=None): 13 | enc_output, enc_slf_attn = self.slf_attn( 14 | enc_input, enc_input, enc_input, mask=slf_attn_mask 15 | ) 16 | enc_output = self.pos_ffn(enc_output) 17 | return enc_output 18 | 19 | 20 | class EncoderLayer_cross(nn.Module): 21 | def __init__(self, d_model, 
d_inner, n_head, d_k, d_v, dropout=0.1): 22 | super(EncoderLayer_cross, self).__init__() 23 | self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 24 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 25 | 26 | def forward(self, dec_input, enc_input, slf_attn_mask=None): 27 | enc_output, enc_slf_attn = self.slf_attn( 28 | dec_input, enc_input, enc_input, mask=slf_attn_mask 29 | ) 30 | enc_output = self.pos_ffn(enc_output) 31 | return enc_output 32 | 33 | 34 | class DecoderLayer_style(nn.Module): 35 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 36 | super(DecoderLayer_style, self).__init__() 37 | self.enc_attn1 = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 38 | self.enc_attn2 = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout) 39 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 40 | 41 | # style decoder no need mask 42 | def forward(self, dec_input, enc_input): 43 | dec_output, dec_slf_attn = self.enc_attn1( 44 | dec_input, enc_input, enc_input, mask=None 45 | ) 46 | dec_output, dec_enc_attn = self.enc_attn2( 47 | dec_output, enc_input, enc_input, mask=None 48 | ) 49 | dec_output = self.pos_ffn(dec_output) 50 | return dec_output 51 | -------------------------------------------------------------------------------- /run_extra.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import torch 4 | 5 | from .module_extra.function import coral 6 | 7 | base_dir = Path(__file__).resolve().parent 8 | 9 | 10 | class CoralColorTransfer: 11 | @classmethod 12 | def INPUT_TYPES(s): 13 | return { 14 | "required": { 15 | "src_img": ("IMAGE",), 16 | "style_img": ("IMAGE",), 17 | }, 18 | } 19 | 20 | RETURN_TYPES = ("IMAGE",) 21 | RETURN_NAMES = ("res_img",) 22 | FUNCTION = "todo" 23 | CATEGORY = "StyleTransferPlus/Extra" 24 | 25 | def todo( 26 | self, 27 | src_img: torch.Tensor, 28 | style_img: torch.Tensor, 29 | ): 30 | print(f"{src_img.shape=}") 31 | print(f"{style_img.shape=}") 32 | 33 | device = torch.device("cpu") 34 | 35 | src_img = src_img.permute(0, 3, 1, 2).to(device) # [B, C, H, W] 36 | style_img = style_img.permute(0, 3, 1, 2).to(device) 37 | 38 | result = [] 39 | if src_img.shape[0] == 1 and style_img.shape[0] == 1: 40 | # Case 1: Single source image, single style image 41 | result = [coral(src_img[0], style_img[0]).unsqueeze(0)] 42 | elif src_img.shape[0] == 1: 43 | # Case 2: Single source image, multiple style images 44 | num_frames = style_img.shape[0] 45 | for i in range(num_frames): 46 | transferred = coral(src_img[0], style_img[i]) 47 | result.append(transferred.unsqueeze(0)) 48 | elif style_img.shape[0] == 1: 49 | # Case 3: Multiple source images, single style image 50 | num_frames = src_img.shape[0] 51 | for i in range(num_frames): 52 | transferred = coral(src_img[i], style_img[0]) 53 | result.append(transferred.unsqueeze(0)) 54 | else: 55 | # Case 4: Multiple source images, multiple style images 56 | num_frames = min(src_img.shape[0], style_img.shape[0]) 57 | for i in range(num_frames): 58 | transferred = coral(src_img[i], style_img[i]) 59 | result.append(transferred.unsqueeze(0)) 60 | 61 | res_tensor = torch.cat(result, dim=0).permute(0, 2, 3, 1) 62 | 63 | return (res_tensor,) 64 | -------------------------------------------------------------------------------- /module_microast/microast_model.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn.functional as F 3 | from torchvision import transforms 4 | 5 | from . import net_microAST 6 | 7 | 8 | class MODEL_MICROAST: 9 | def __init__( 10 | self, 11 | content_dec_path: str, 12 | style_enc_path: str, 13 | modulator_path: str, 14 | decoder_path: str, 15 | device, 16 | ) -> None: 17 | content_encoder = net_microAST.Encoder() 18 | style_encoder = net_microAST.Encoder() 19 | modulator = net_microAST.Modulator() 20 | decoder = net_microAST.Decoder() 21 | 22 | content_encoder.eval() 23 | style_encoder.eval() 24 | modulator.eval() 25 | decoder.eval() 26 | 27 | content_encoder.load_state_dict(torch.load(content_dec_path)) 28 | style_encoder.load_state_dict(torch.load(style_enc_path)) 29 | modulator.load_state_dict(torch.load(modulator_path)) 30 | decoder.load_state_dict(torch.load(decoder_path)) 31 | 32 | self.network = net_microAST.TestNet( 33 | content_encoder, style_encoder, modulator, decoder 34 | ) 35 | self.network.to(device) 36 | 37 | 38 | def inference_microast( 39 | src_img: torch.Tensor, 40 | style_img: torch.Tensor, 41 | size: int, 42 | do_crop: bool, 43 | alpha: float, 44 | device, 45 | network, 46 | ): 47 | content_tf = test_transform(size, do_crop) 48 | 49 | content = content_tf(src_img).to(device) 50 | style = content_tf(style_img).to(device) 51 | style = F.interpolate( 52 | style, 53 | size=(content.shape[2], content.shape[3]), 54 | mode="bilinear", 55 | align_corners=False, 56 | ) 57 | 58 | # print(f"Resized: {content.shape=}") 59 | # print(f"Resized: {style.shape=}") 60 | 61 | torch.cuda.synchronize() 62 | output = network(content, style, alpha) 63 | torch.cuda.synchronize() 64 | 65 | return output 66 | 67 | 68 | def test_transform(size, crop): 69 | transform_list = [] 70 | if size != 0: 71 | transform_list.append(transforms.Resize(size)) 72 | if crop: 73 | transform_list.append(transforms.CenterCrop(size)) 74 | transform = transforms.Compose(transform_list) 75 | # [B, C, H, W] 76 | return transform 77 | -------------------------------------------------------------------------------- /module_neural_neighbor/distance.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def center(x): 5 | """Subtract the mean of 'x' over leading dimension""" 6 | return x - torch.mean(x, 0, keepdim=True) 7 | 8 | 9 | def pairwise_distances_cos(x, y): 10 | """Compute all pairwise cosine distances between rows of matrix 'x' and matrix 'y' 11 | Inputs: 12 | x -- NxD pytorch tensor 13 | y -- MxD pytorch tensor 14 | Outputs: 15 | d -- NxM pytorch tensor where d[i,j] is the cosine distance between 16 | the vector at row i of matrix 'x' and the vector at row j of 17 | matrix 'y' 18 | """ 19 | assert x.size(1) == y.size( 20 | 1 21 | ), "can only compute distance between vectors of same length" 22 | assert (len(x.size()) == 2) and (len(y.size()) == 2), ( 23 | "pairwise distance computation" " assumes input tensors are matrices" 24 | ) 25 | 26 | x_norm = torch.sqrt((x**2).sum(1).view(-1, 1)) 27 | y_norm = torch.sqrt((y**2).sum(1).view(-1, 1)) 28 | y_t = torch.transpose(y / y_norm, 0, 1) 29 | 30 | d = 1.0 - torch.mm(x / x_norm, y_t) 31 | return d 32 | 33 | 34 | def pairwise_distances_sq_l2(x, y): 35 | """Compute all pairwise squared l2 distances between rows of matrix 'x' and matrix 'y' 36 | Inputs: 37 | x -- NxD pytorch tensor 38 | y -- MxD pytorch tensor 39 | Outputs: 40 | d -- NxM pytorch tensor where d[i,j] is the squared l2 distance between 41 | the vector at row i of matrix 'x' and the vector at row j of 42 | 
matrix 'y' 43 | """ 44 | assert x.size(1) == y.size( 45 | 1 46 | ), "can only compute distance between vectors of same length" 47 | assert (len(x.size()) == 2) and (len(y.size()) == 2), ( 48 | "pairwise distance computation" " assumes input tensors are matrices" 49 | ) 50 | 51 | x_norm = (x**2).sum(1).view(-1, 1) 52 | y_t = torch.transpose(y, 0, 1) 53 | y_norm = (y**2).sum(1).view(1, -1) 54 | 55 | d = -2.0 * torch.mm(x, y_t) 56 | d += x_norm 57 | d += y_norm 58 | 59 | return d 60 | 61 | 62 | def pairwise_distances_l2(x, y): 63 | """Compute all pairwise l2 distances between rows of 'x' and 'y', 64 | thresholds minimum of squared l2 distance for stability of sqrt 65 | """ 66 | d = torch.clamp(pairwise_distances_sq_l2(x, y), min=1e-8) 67 | return torch.sqrt(d) 68 | 69 | 70 | def pairwise_distances_l2_center(x, y): 71 | """subtracts mean row from 'x' and 'y' before computing pairwise l2 distance between all rows""" 72 | return pairwise_distances_l2(center(x), center(y)) 73 | 74 | 75 | def pairwise_distances_cos_center(x, y): 76 | """subtracts mean row from 'x' and 'y' before computing pairwise cosine distance between all rows""" 77 | return pairwise_distances_cos(center(x), center(y)) 78 | -------------------------------------------------------------------------------- /module_neural_neighbor/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | 5 | 6 | class Vgg16Pretrained(torch.nn.Module): 7 | def __init__(self, requires_grad=False): 8 | super(Vgg16Pretrained, self).__init__() 9 | 10 | vgg_pretrained_features = models.vgg16( 11 | weights=models.VGG16_Weights.DEFAULT 12 | ).features 13 | 14 | self.vgg_layers = vgg_pretrained_features 15 | self.slice1 = torch.nn.Sequential() 16 | self.slice2 = torch.nn.Sequential() 17 | self.slice3 = torch.nn.Sequential() 18 | self.slice4 = torch.nn.Sequential() 19 | for x in range(1): 20 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 21 | for x in range(1, 9): 22 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 23 | for x in range(9, 16): 24 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 25 | for x in range(16, 23): 26 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 27 | if not requires_grad: 28 | for param in self.parameters(): 29 | param.requires_grad = False 30 | 31 | def forward(self, x_in, inds=[1, 3, 6, 8, 11, 13, 15, 22, 29], concat=True): 32 | x = x_in.clone() # prevent accidentally modifying input in place 33 | # Preprocess input according to original imagenet training 34 | mean = [0.485, 0.456, 0.406] 35 | std = [0.229, 0.224, 0.225] 36 | for i in range(3): 37 | x[:, i : (i + 1), :, :] = (x[:, i : (i + 1), :, :] - mean[i]) / std[i] 38 | 39 | # Get hidden state at layers specified by 'inds' 40 | l2 = [] 41 | if -1 in inds: 42 | l2.append(x_in) 43 | 44 | # Only need to run network until we get to the max depth we want outputs from 45 | for i in range(max(inds) + 1): 46 | x = self.vgg_layers[i].forward(x) 47 | if i in inds: 48 | l2.append(x) 49 | 50 | # Concatenate hidden states if desired (after upsampling to spatial size of largest output) 51 | if concat: 52 | if len(l2) > 1: 53 | zi_list = [] 54 | max_h = l2[0].size(2) 55 | max_w = l2[0].size(3) 56 | for zi in l2: 57 | if len(zi_list) == 0: 58 | zi_list.append(zi) 59 | else: 60 | zi_list.append( 61 | F.interpolate(zi, (max_h, max_w), mode="bilinear") 62 | ) 63 | 64 | z = torch.cat(zi_list, 1) 65 | else: # don't bother 
doing anything if only returning one hidden state 66 | z = l2[0] 67 | else: # Otherwise return list of hidden states 68 | z = l2 69 | 70 | return z 71 | -------------------------------------------------------------------------------- /module_microast/function.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def calc_mean_std(feat, eps=1e-5): 5 | # eps is a small value added to the variance to avoid divide-by-zero. 6 | size = feat.size() 7 | assert len(size) == 4 8 | N, C = size[:2] 9 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 10 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 11 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 12 | return feat_mean, feat_std 13 | 14 | 15 | def normalization(content_feat): 16 | size = content_feat.size() 17 | content_mean, content_std = calc_mean_std(content_feat) 18 | 19 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand( 20 | size 21 | ) 22 | return normalized_feat 23 | 24 | 25 | def adaptive_instance_normalization(content_feat, style_feat): 26 | assert content_feat.size()[:2] == style_feat.size()[:2] 27 | size = content_feat.size() 28 | style_mean, style_std = calc_mean_std(style_feat) 29 | content_mean, content_std = calc_mean_std(content_feat) 30 | 31 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand( 32 | size 33 | ) 34 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 35 | 36 | 37 | def _calc_feat_flatten_mean_std(feat): 38 | # takes 3D feat (C, H, W), return mean and std of array within channels 39 | assert feat.size()[0] == 3 40 | assert isinstance(feat, torch.FloatTensor) 41 | feat_flatten = feat.view(3, -1) 42 | mean = feat_flatten.mean(dim=-1, keepdim=True) 43 | std = feat_flatten.std(dim=-1, keepdim=True) 44 | return feat_flatten, mean, std 45 | 46 | 47 | def _mat_sqrt(x): 48 | U, D, V = torch.svd(x) 49 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) 50 | 51 | 52 | def coral(source, target): 53 | # assume both source and target are 3D array (C, H, W) 54 | # Note: flatten -> f 55 | 56 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source) 57 | source_f_norm = ( 58 | source_f - source_f_mean.expand_as(source_f) 59 | ) / source_f_std.expand_as(source_f) 60 | source_f_cov_eye = torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) 61 | 62 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target) 63 | target_f_norm = ( 64 | target_f - target_f_mean.expand_as(target_f) 65 | ) / target_f_std.expand_as(target_f) 66 | target_f_cov_eye = torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) 67 | 68 | source_f_norm_transfer = torch.mm( 69 | _mat_sqrt(target_f_cov_eye), 70 | torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), source_f_norm), 71 | ) 72 | 73 | source_f_transfer = source_f_norm_transfer * target_f_std.expand_as( 74 | source_f_norm 75 | ) + target_f_mean.expand_as(source_f_norm) 76 | 77 | return source_f_transfer.view(source.size()) 78 | -------------------------------------------------------------------------------- /module_aesfa/aesfa_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | 4 | from .model import AesFA_test 5 | 6 | 7 | class MODEL_AESFA: 8 | def __init__(self, checkpoint_path: str, device) -> None: 9 | self.aesfa_model = AesFA_test( 10 | alpha_in=0.5, # input ratio of low-frequency channel 
11 | alpha_out=0.5, # output ratio of low-frequency channel 12 | style_kernel=3, # size of style kernel 13 | input_nc=3, # of input image channel 14 | nf=64, # of feature map channel after Encoder first layer 15 | output_nc=3, # of output image channel 16 | freq_ratio=[1, 1], # [high, low] ratio at the last layer 17 | ) 18 | 19 | dict_model = torch.load(checkpoint_path) 20 | 21 | self.aesfa_model.netE.load_state_dict(dict_model["netE"]) 22 | self.aesfa_model.netS.load_state_dict(dict_model["netS"]) 23 | self.aesfa_model.netG.load_state_dict(dict_model["netG"]) 24 | 25 | self.aesfa_model.to(device) 26 | 27 | 28 | def inference_aesfa( 29 | src_img: torch.Tensor, 30 | style_img: torch.Tensor, 31 | device, 32 | size: int, 33 | do_crop: bool, 34 | model: MODEL_AESFA, 35 | ): 36 | content_tf = test_transform( 37 | size, do_crop, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 38 | ) 39 | content = content_tf(src_img).to(device) 40 | style = content_tf(style_img).to(device) 41 | 42 | stylized_image = model.aesfa_model.forward(content, style) 43 | output = un_normalize_batch(stylized_image) 44 | 45 | return output 46 | 47 | 48 | def inference_aesfa_style_blend( 49 | src_img: torch.Tensor, 50 | style_hi: torch.Tensor, 51 | style_lo: torch.Tensor, 52 | device, 53 | size: int, 54 | do_crop: bool, 55 | model: MODEL_AESFA, 56 | ): 57 | content_tf = test_transform( 58 | size, do_crop, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 59 | ) 60 | content = content_tf(src_img).to(device) 61 | style_high = content_tf(style_hi).to(device) 62 | style_low = content_tf(style_lo).to(device) 63 | 64 | stylized_image = model.aesfa_model.style_blending(content, style_high, style_low) 65 | output = un_normalize_batch(stylized_image) 66 | 67 | return output 68 | 69 | 70 | def test_transform(size, crop, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): 71 | transform_list = [] 72 | 73 | if crop and size != 0: 74 | transform_list.append(transforms.Resize(size)) 75 | transform_list.append(transforms.CenterCrop(size)) 76 | elif size != 0: 77 | transform_list.append(transforms.Resize((size, size))) 78 | 79 | transform_list.append(transforms.Normalize(mean, std)) 80 | 81 | return transforms.Compose(transform_list) 82 | 83 | 84 | def un_normalize_batch(tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): 85 | for t, m, s in zip(tensor, mean, std): 86 | t.mul_(s).add_(m) 87 | return tensor 88 | -------------------------------------------------------------------------------- /module_neural_neighbor/zca.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def whiten(x, ui, u, s): 5 | """ 6 | Applies whitening as described in: 7 | https://openaccess.thecvf.com/content_ICCV_2019/papers/Chiu_Understanding_Generalized_Whitening_and_Coloring_Transform_for_Universal_Style_Transfer_ICCV_2019_paper.pdf 8 | x -- N x D pytorch tensor 9 | ui -- D x D transposed eigenvectors of whitening covariance 10 | u -- D x D eigenvectors of whitening covariance 11 | s -- D x 1 eigenvalues of whitening covariance 12 | """ 13 | def tps(x): 14 | return x.transpose(1, 0) 15 | return tps(torch.matmul(u, torch.matmul(ui, tps(x)) / s)) 16 | 17 | 18 | def colorize(x, ui, u, s): 19 | """ 20 | Applies "coloring transform" as described in: 21 | https://openaccess.thecvf.com/content_ICCV_2019/papers/Chiu_Understanding_Generalized_Whitening_and_Coloring_Transform_for_Universal_Style_Transfer_ICCV_2019_paper.pdf 22 | x -- N x D pytorch tensor 23 | ui -- D x D transposed eigenvectors of coloring covariance 24 | u -- D 
x D eigenvectors of coloring covariance 25 | s -- D x 1 eigenvalues of coloring covariance 26 | """ 27 | def tps(x): 28 | return x.transpose(1, 0) 29 | return tps(torch.matmul(u, torch.matmul(ui, tps(x)) * s)) 30 | 31 | 32 | def zca(content, style): 33 | """ 34 | Matches the mean and covariance of 'content' to those of 'style' 35 | content -- N x D pytorch tensor of content feature vectors 36 | style -- N x D pytorch tensor of style feature vectors 37 | """ 38 | mu_c = content.mean(0, keepdim=True) 39 | mu_s = style.mean(0, keepdim=True) 40 | 41 | content = content - mu_c 42 | style = style - mu_s 43 | 44 | cov_c = torch.matmul(content.transpose(1, 0), content) / float(content.size(0)) 45 | cov_s = torch.matmul(style.transpose(1, 0), style) / float(style.size(0)) 46 | 47 | u_c, sig_c, _ = torch.svd(cov_c + torch.eye(cov_c.size(0)).cuda() * 1e-4) 48 | u_s, sig_s, _ = torch.svd(cov_s + torch.eye(cov_s.size(0)).cuda() * 1e-4) 49 | 50 | sig_c = sig_c.unsqueeze(1) 51 | sig_s = sig_s.unsqueeze(1) 52 | 53 | u_c_i = u_c.transpose(1, 0) 54 | u_s_i = u_s.transpose(1, 0) 55 | 56 | scl_c = torch.sqrt(torch.clamp(sig_c, 1e-8, 1e8)) 57 | scl_s = torch.sqrt(torch.clamp(sig_s, 1e-8, 1e8)) 58 | 59 | whiten_c = whiten(content, u_c_i, u_c, scl_c) 60 | color_c = colorize(whiten_c, u_s_i, u_s, scl_s) + mu_s 61 | 62 | return color_c, cov_s 63 | 64 | 65 | def zca_tensor(content, style): 66 | """ 67 | Matches the mean and covariance of 'content' to those of 'style' 68 | content -- B x D x H x W pytorch tensor of content feature vectors 69 | style -- B x D x H x W pytorch tensor of style feature vectors 70 | """ 71 | content_rs = content.permute(0, 2, 3, 1).contiguous().view(-1, content.size(1)) 72 | style_rs = style.permute(0, 2, 3, 1).contiguous().view(-1, style.size(1)) 73 | 74 | cs, cov_s = zca(content_rs, style_rs) 75 | 76 | cs = cs.view( 77 | content.size(0), content.size(2), content.size(3), content.size(1) 78 | ).permute(0, 3, 1, 2) 79 | return cs.contiguous(), cov_s 80 | -------------------------------------------------------------------------------- /module_neural_neighbor/featureExtract.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def get_feat_norms(x): 5 | """Makes l2 norm of x[i,:,j,k] = 1 for all i,j,k. Clamps before sqrt for 6 | stability 7 | """ 8 | return torch.clamp(x.pow(2).sum(1, keepdim=True), 1e-8, 1e8).sqrt() 9 | 10 | 11 | def phi_cat(x, phi, layer_l): 12 | """Extract conv features from 'x' at list of VGG16 layers 'layer_l'. 
Then 13 | normalize features from each conv block based on # of channels, resize, 14 | and concatenate into hypercolumns 15 | Inputs: 16 | x -- Bx3xHxW pytorch tensor, presumed to contain rgb images 17 | phi -- lambda function calling a pretrained Vgg16Pretrained model 18 | layer_l -- layer indexes to form hypercolumns out of 19 | Outputs: 20 | feats -- BxCxHxW pytorch tensor of hypercolumns extracted from 'x' 21 | C depends on 'layer_l' 22 | """ 23 | h = x.size(2) 24 | w = x.size(3) 25 | 26 | feats = phi(x, layer_l, False) 27 | # Normalize each layer by # channels so # of channels doesn't dominate 28 | # cosine distance 29 | feats = [f / f.size(1) for f in feats] 30 | 31 | # Scale layers' features to target size and concatenate 32 | feats = torch.cat( 33 | [ 34 | F.interpolate(f, (h // 4, w // 4), mode="bilinear", align_corners=True) 35 | for f in feats 36 | ], 37 | 1, 38 | ) 39 | 40 | return feats 41 | 42 | 43 | def extract_feats(im, phi, flip_aug=False): 44 | """Extract hypercolumns from 'im' using pretrained VGG16 (passed as phi), 45 | if speficied, extract hypercolumns from rotations of 'im' as well 46 | Inputs: 47 | im -- a Bx3xHxW pytorch tensor, presumed to contain rgb images 48 | phi -- a lambda function calling a pretrained Vgg16Pretrained model 49 | flip_aug -- whether to extract hypercolumns from rotations of 'im' 50 | as well 51 | Outputs: 52 | feats -- a tensor of hypercolumns extracted from 'im', spatial 53 | index is presumed to no longer matter 54 | """ 55 | # In the original paper used all layers, but dropping conv5 block increases 56 | # speed without harming quality 57 | layer_l = [22, 20, 18, 15, 13, 11, 8, 6, 3, 1] 58 | feats = phi_cat(im, phi, layer_l) 59 | 60 | # If specified, extract features from 90, 180, 270 degree rotations of 'im' 61 | if flip_aug: 62 | aug_list = [ 63 | torch.flip(im, [2]).transpose(2, 3), 64 | torch.flip(im, [2, 3]), 65 | torch.flip(im, [3]).transpose(2, 3), 66 | ] 67 | 68 | for i, im_aug in enumerate(aug_list): 69 | feats_new = phi_cat(im_aug, phi, layer_l) 70 | 71 | # Code never looks at patches of features, so fine to just stick 72 | # features from rotated images in adjacent spatial indexes, since 73 | # they will only be accessed in isolation 74 | if i == 1: 75 | feats = torch.cat([feats, feats_new], 2) 76 | else: 77 | feats = torch.cat([feats, feats_new.transpose(2, 3)], 2) 78 | 79 | return feats 80 | -------------------------------------------------------------------------------- /module_unist/unist_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.transforms as transforms 4 | 5 | from .models import Transformer 6 | 7 | 8 | class video_Style_transfer(nn.Module): 9 | def __init__(self, ckpt_path: str, encoder_path: str, decoder_path: str): 10 | super(video_Style_transfer, self).__init__() 11 | self.ckpt_path = ckpt_path 12 | self.model = Transformer(encoder_path, decoder_path) 13 | self.load_weights() 14 | 15 | def load_weights(self): 16 | # print("Loading model from checkpoint") 17 | ckpt = torch.load(self.ckpt_path, map_location="cpu") 18 | self.model.load_state_dict(get_keys(ckpt, "model"), strict=False) 19 | 20 | def forward( 21 | self, 22 | content_frames, 23 | style_images, 24 | content_type="image", 25 | id_loss="transfer", 26 | tab=None, 27 | ): 28 | transfer_result = self.model.forward( 29 | content_frames, style_images, content_type, id_loss, tab 30 | ) 31 | return transfer_result 32 | 33 | 34 | def inference_unist( 35 | 
src: torch.Tensor, 36 | src_style: torch.Tensor, 37 | device, 38 | size: int, 39 | do_crop: bool, 40 | network: video_Style_transfer, 41 | content_type="image", 42 | ): 43 | content_tf = test_transform(size, do_crop) 44 | content = content_tf(src).unsqueeze(0).to(device) # [1, B, C, H, W] 45 | 46 | style = content_tf(src_style).to(device) # [B, C, H, W] 47 | 48 | if content_type == "video": 49 | style = match_shape(style, content[0]) 50 | 51 | y_hat = network.forward(content, style, content_type, tab="inference") 52 | output = un_normalize_batch(y_hat) # [B, C, H, W] 53 | 54 | return output 55 | 56 | 57 | def test_transform(size, crop, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): 58 | transform_list = [] 59 | 60 | if crop and size != 0: 61 | transform_list.append(transforms.Resize(size)) 62 | transform_list.append(transforms.CenterCrop(size)) 63 | elif size != 0: 64 | transform_list.append(transforms.Resize((size, size))) 65 | 66 | transform_list.append(transforms.Normalize(mean, std)) 67 | 68 | transform = transforms.Compose(transform_list) 69 | # [B, C, H, W] 70 | return transform 71 | 72 | 73 | def un_normalize_batch(tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): 74 | for t, m, s in zip(tensor, mean, std): 75 | t.mul_(s).add_(m) 76 | return tensor 77 | 78 | 79 | def match_shape(a: torch.Tensor, b: torch.Tensor): 80 | B_a, C, H, W = a.shape 81 | B_b = b.shape[0] 82 | 83 | if B_a == B_b: 84 | return a 85 | elif B_a == 1: 86 | return a.expand(B_b, -1, -1, -1) 87 | elif B_a < B_b: 88 | repeat_factor = B_b // B_a 89 | remainder = B_b % B_a 90 | return torch.cat([a.repeat(repeat_factor, 1, 1, 1), a[:remainder]], dim=0) 91 | else: # B_a > B_b 92 | return a[:B_b] 93 | 94 | 95 | def get_keys(d, name): 96 | if "state_dict" in d: 97 | d = d["state_dict"] 98 | d_filt = {k[len(name) + 1 :]: v for k, v in d.items() if k[: len(name)] == name} 99 | return d_filt 100 | -------------------------------------------------------------------------------- /workflows/workflow_cast.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 6, 3 | "last_link_id": 10, 4 | "nodes": [ 5 | { 6 | "id": 2, 7 | "type": "LoadImage", 8 | "pos": [ 9 | 309, 10 | 187 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 314 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "IMAGE", 22 | "type": "IMAGE", 23 | "links": [ 24 | 8 25 | ], 26 | "shape": 3, 27 | "slot_index": 0 28 | }, 29 | { 30 | "name": "MASK", 31 | "type": "MASK", 32 | "links": null, 33 | "shape": 3 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "LoadImage" 38 | }, 39 | "widgets_values": [ 40 | "006.jpg", 41 | "image" 42 | ] 43 | }, 44 | { 45 | "id": 3, 46 | "type": "LoadImage", 47 | "pos": [ 48 | 315, 49 | 570 50 | ], 51 | "size": { 52 | "0": 315, 53 | "1": 314 54 | }, 55 | "flags": {}, 56 | "order": 1, 57 | "mode": 0, 58 | "outputs": [ 59 | { 60 | "name": "IMAGE", 61 | "type": "IMAGE", 62 | "links": [ 63 | 9 64 | ], 65 | "shape": 3, 66 | "slot_index": 0 67 | }, 68 | { 69 | "name": "MASK", 70 | "type": "MASK", 71 | "links": null, 72 | "shape": 3 73 | } 74 | ], 75 | "properties": { 76 | "Node name for S&R": "LoadImage" 77 | }, 78 | "widgets_values": [ 79 | "S4.jpg", 80 | "image" 81 | ] 82 | }, 83 | { 84 | "id": 4, 85 | "type": "SaveImage", 86 | "pos": [ 87 | 1050, 88 | 327 89 | ], 90 | "size": { 91 | "0": 315, 92 | "1": 270 93 | }, 94 | "flags": {}, 95 | "order": 3, 96 | "mode": 0, 97 | "inputs": [ 98 | { 99 | "name": "images", 100 | "type": "IMAGE", 
101 | "link": 10 102 | } 103 | ], 104 | "properties": {}, 105 | "widgets_values": [ 106 | "StyleTransfer" 107 | ] 108 | }, 109 | { 110 | "id": 6, 111 | "type": "CAST", 112 | "pos": [ 113 | 767, 114 | 326 115 | ], 116 | "size": { 117 | "0": 216.59999084472656, 118 | "1": 78 119 | }, 120 | "flags": {}, 121 | "order": 2, 122 | "mode": 0, 123 | "inputs": [ 124 | { 125 | "name": "src_img", 126 | "type": "IMAGE", 127 | "link": 8 128 | }, 129 | { 130 | "name": "style_img", 131 | "type": "IMAGE", 132 | "link": 9 133 | } 134 | ], 135 | "outputs": [ 136 | { 137 | "name": "res_img", 138 | "type": "IMAGE", 139 | "links": [ 140 | 10 141 | ], 142 | "shape": 3 143 | } 144 | ], 145 | "properties": { 146 | "Node name for S&R": "CAST" 147 | }, 148 | "widgets_values": [ 149 | "CAST" 150 | ] 151 | } 152 | ], 153 | "links": [ 154 | [ 155 | 8, 156 | 2, 157 | 0, 158 | 6, 159 | 0, 160 | "IMAGE" 161 | ], 162 | [ 163 | 9, 164 | 3, 165 | 0, 166 | 6, 167 | 1, 168 | "IMAGE" 169 | ], 170 | [ 171 | 10, 172 | 6, 173 | 0, 174 | 4, 175 | 0, 176 | "IMAGE" 177 | ] 178 | ], 179 | "groups": [], 180 | "config": {}, 181 | "extra": { 182 | "ds": { 183 | "scale": 1.100000000000002, 184 | "offset": [ 185 | -106.49900776658754, 186 | -117.32656624605369 187 | ] 188 | } 189 | }, 190 | "version": 0.4 191 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Checkpoint files 2 | *.pth 3 | *.tar 4 | *.pt 5 | *.t7 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | share/python-wheels/ 30 | *.egg-info/ 31 | .installed.cfg 32 | *.egg 33 | MANIFEST 34 | 35 | # PyInstaller 36 | # Usually these files are written by a python script from a template 37 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 38 | *.manifest 39 | *.spec 40 | 41 | # Installer logs 42 | pip-log.txt 43 | pip-delete-this-directory.txt 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | .pybuilder/ 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | # For a library or package, you might want to ignore these files since the code is 93 | # intended to run in multiple environments; otherwise, check them in: 94 | # .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 
101 | #Pipfile.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
168 | #.idea/ 169 | -------------------------------------------------------------------------------- /workflows/extra/extra_coral_color_transfer.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 8, 3 | "last_link_id": 23, 4 | "nodes": [ 5 | { 6 | "id": 5, 7 | "type": "CoralColorTransfer", 8 | "pos": [ 9 | 843, 10 | 329 11 | ], 12 | "size": { 13 | "0": 216.59999084472656, 14 | "1": 46 15 | }, 16 | "flags": {}, 17 | "order": 2, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "src_img", 22 | "type": "IMAGE", 23 | "link": 23 24 | }, 25 | { 26 | "name": "style_img", 27 | "type": "IMAGE", 28 | "link": 22 29 | } 30 | ], 31 | "outputs": [ 32 | { 33 | "name": "res_img", 34 | "type": "IMAGE", 35 | "links": [ 36 | 21 37 | ], 38 | "shape": 3, 39 | "slot_index": 0 40 | } 41 | ], 42 | "properties": { 43 | "Node name for S&R": "CoralColorTransfer" 44 | } 45 | }, 46 | { 47 | "id": 2, 48 | "type": "LoadImage", 49 | "pos": [ 50 | 319, 51 | 222 52 | ], 53 | "size": { 54 | "0": 315, 55 | "1": 314 56 | }, 57 | "flags": {}, 58 | "order": 0, 59 | "mode": 0, 60 | "outputs": [ 61 | { 62 | "name": "IMAGE", 63 | "type": "IMAGE", 64 | "links": [ 65 | 23 66 | ], 67 | "shape": 3, 68 | "slot_index": 0 69 | }, 70 | { 71 | "name": "MASK", 72 | "type": "MASK", 73 | "links": null, 74 | "shape": 3 75 | } 76 | ], 77 | "properties": { 78 | "Node name for S&R": "LoadImage" 79 | }, 80 | "widgets_values": [ 81 | "006.jpg", 82 | "image" 83 | ] 84 | }, 85 | { 86 | "id": 3, 87 | "type": "LoadImage", 88 | "pos": [ 89 | 315, 90 | 583 91 | ], 92 | "size": { 93 | "0": 315, 94 | "1": 314 95 | }, 96 | "flags": {}, 97 | "order": 1, 98 | "mode": 0, 99 | "outputs": [ 100 | { 101 | "name": "IMAGE", 102 | "type": "IMAGE", 103 | "links": [ 104 | 22 105 | ], 106 | "shape": 3, 107 | "slot_index": 0 108 | }, 109 | { 110 | "name": "MASK", 111 | "type": "MASK", 112 | "links": null, 113 | "shape": 3 114 | } 115 | ], 116 | "properties": { 117 | "Node name for S&R": "LoadImage" 118 | }, 119 | "widgets_values": [ 120 | "S4.jpg", 121 | "image" 122 | ] 123 | }, 124 | { 125 | "id": 4, 126 | "type": "SaveImage", 127 | "pos": [ 128 | 845, 129 | 430 130 | ], 131 | "size": [ 132 | 435.2039984095661, 133 | 302.78571714680663 134 | ], 135 | "flags": {}, 136 | "order": 3, 137 | "mode": 0, 138 | "inputs": [ 139 | { 140 | "name": "images", 141 | "type": "IMAGE", 142 | "link": 21 143 | } 144 | ], 145 | "properties": {}, 146 | "widgets_values": [ 147 | "StyleTransfer" 148 | ] 149 | } 150 | ], 151 | "links": [ 152 | [ 153 | 21, 154 | 5, 155 | 0, 156 | 4, 157 | 0, 158 | "IMAGE" 159 | ], 160 | [ 161 | 22, 162 | 3, 163 | 0, 164 | 5, 165 | 1, 166 | "IMAGE" 167 | ], 168 | [ 169 | 23, 170 | 2, 171 | 0, 172 | 5, 173 | 0, 174 | "IMAGE" 175 | ] 176 | ], 177 | "groups": [], 178 | "config": {}, 179 | "extra": { 180 | "ds": { 181 | "scale": 1.1, 182 | "offset": [ 183 | -100.20813064097217, 184 | -119.37827913027836 185 | ] 186 | } 187 | }, 188 | "version": 0.4 189 | } -------------------------------------------------------------------------------- /workflows/workflow_aespa.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 8, 3 | "last_link_id": 19, 4 | "nodes": [ 5 | { 6 | "id": 4, 7 | "type": "SaveImage", 8 | "pos": [ 9 | 1050, 10 | 327 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 270 15 | }, 16 | "flags": {}, 17 | "order": 3, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "images", 22 | "type": "IMAGE", 23 | "link": 19 24 | } 25 | ], 26 | 
"properties": {}, 27 | "widgets_values": [ 28 | "StyleTransfer" 29 | ] 30 | }, 31 | { 32 | "id": 2, 33 | "type": "LoadImage", 34 | "pos": [ 35 | 309, 36 | 187 37 | ], 38 | "size": { 39 | "0": 315, 40 | "1": 314 41 | }, 42 | "flags": {}, 43 | "order": 0, 44 | "mode": 0, 45 | "outputs": [ 46 | { 47 | "name": "IMAGE", 48 | "type": "IMAGE", 49 | "links": [ 50 | 17 51 | ], 52 | "shape": 3, 53 | "slot_index": 0 54 | }, 55 | { 56 | "name": "MASK", 57 | "type": "MASK", 58 | "links": null, 59 | "shape": 3 60 | } 61 | ], 62 | "properties": { 63 | "Node name for S&R": "LoadImage" 64 | }, 65 | "widgets_values": [ 66 | "006.jpg", 67 | "image" 68 | ] 69 | }, 70 | { 71 | "id": 3, 72 | "type": "LoadImage", 73 | "pos": [ 74 | 315, 75 | 570 76 | ], 77 | "size": { 78 | "0": 315, 79 | "1": 314 80 | }, 81 | "flags": {}, 82 | "order": 1, 83 | "mode": 0, 84 | "outputs": [ 85 | { 86 | "name": "IMAGE", 87 | "type": "IMAGE", 88 | "links": [ 89 | 18 90 | ], 91 | "shape": 3, 92 | "slot_index": 0 93 | }, 94 | { 95 | "name": "MASK", 96 | "type": "MASK", 97 | "links": null, 98 | "shape": 3 99 | } 100 | ], 101 | "properties": { 102 | "Node name for S&R": "LoadImage" 103 | }, 104 | "widgets_values": [ 105 | "S4.jpg", 106 | "image" 107 | ] 108 | }, 109 | { 110 | "id": 8, 111 | "type": "AesPA", 112 | "pos": [ 113 | 674, 114 | 349 115 | ], 116 | "size": { 117 | "0": 315, 118 | "1": 102 119 | }, 120 | "flags": {}, 121 | "order": 2, 122 | "mode": 0, 123 | "inputs": [ 124 | { 125 | "name": "src_img", 126 | "type": "IMAGE", 127 | "link": 17 128 | }, 129 | { 130 | "name": "style_img", 131 | "type": "IMAGE", 132 | "link": 18 133 | } 134 | ], 135 | "outputs": [ 136 | { 137 | "name": "out_img", 138 | "type": "IMAGE", 139 | "links": [ 140 | 19 141 | ], 142 | "shape": 3, 143 | "slot_index": 0 144 | } 145 | ], 146 | "properties": { 147 | "Node name for S&R": "AesPA" 148 | }, 149 | "widgets_values": [ 150 | true, 151 | 1024 152 | ] 153 | } 154 | ], 155 | "links": [ 156 | [ 157 | 17, 158 | 2, 159 | 0, 160 | 8, 161 | 0, 162 | "IMAGE" 163 | ], 164 | [ 165 | 18, 166 | 3, 167 | 0, 168 | 8, 169 | 1, 170 | "IMAGE" 171 | ], 172 | [ 173 | 19, 174 | 8, 175 | 0, 176 | 4, 177 | 0, 178 | "IMAGE" 179 | ] 180 | ], 181 | "groups": [], 182 | "config": {}, 183 | "extra": { 184 | "ds": { 185 | "scale": 1.1000000000000016, 186 | "offset": [ 187 | -196.07907685850128, 188 | -119.07837620860246 189 | ] 190 | } 191 | }, 192 | "version": 0.4 193 | } -------------------------------------------------------------------------------- /workflows/workflow_microast.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 7, 3 | "last_link_id": 15, 4 | "nodes": [ 5 | { 6 | "id": 4, 7 | "type": "SaveImage", 8 | "pos": [ 9 | 1050, 10 | 327 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 270 15 | }, 16 | "flags": {}, 17 | "order": 3, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "images", 22 | "type": "IMAGE", 23 | "link": 15 24 | } 25 | ], 26 | "properties": {}, 27 | "widgets_values": [ 28 | "StyleTransfer" 29 | ] 30 | }, 31 | { 32 | "id": 2, 33 | "type": "LoadImage", 34 | "pos": [ 35 | 309, 36 | 187 37 | ], 38 | "size": { 39 | "0": 315, 40 | "1": 314 41 | }, 42 | "flags": {}, 43 | "order": 0, 44 | "mode": 0, 45 | "outputs": [ 46 | { 47 | "name": "IMAGE", 48 | "type": "IMAGE", 49 | "links": [ 50 | 13 51 | ], 52 | "shape": 3, 53 | "slot_index": 0 54 | }, 55 | { 56 | "name": "MASK", 57 | "type": "MASK", 58 | "links": null, 59 | "shape": 3 60 | } 61 | ], 62 | "properties": { 63 | "Node name for S&R": "LoadImage" 
64 | }, 65 | "widgets_values": [ 66 | "006.jpg", 67 | "image" 68 | ] 69 | }, 70 | { 71 | "id": 3, 72 | "type": "LoadImage", 73 | "pos": [ 74 | 315, 75 | 570 76 | ], 77 | "size": { 78 | "0": 315, 79 | "1": 314 80 | }, 81 | "flags": {}, 82 | "order": 1, 83 | "mode": 0, 84 | "outputs": [ 85 | { 86 | "name": "IMAGE", 87 | "type": "IMAGE", 88 | "links": [ 89 | 14 90 | ], 91 | "shape": 3, 92 | "slot_index": 0 93 | }, 94 | { 95 | "name": "MASK", 96 | "type": "MASK", 97 | "links": null, 98 | "shape": 3 99 | } 100 | ], 101 | "properties": { 102 | "Node name for S&R": "LoadImage" 103 | }, 104 | "widgets_values": [ 105 | "S4.jpg", 106 | "image" 107 | ] 108 | }, 109 | { 110 | "id": 7, 111 | "type": "MicroAST", 112 | "pos": [ 113 | 680, 114 | 292 115 | ], 116 | "size": { 117 | "0": 315, 118 | "1": 102 119 | }, 120 | "flags": {}, 121 | "order": 2, 122 | "mode": 0, 123 | "inputs": [ 124 | { 125 | "name": "src_img", 126 | "type": "IMAGE", 127 | "link": 13 128 | }, 129 | { 130 | "name": "style_img", 131 | "type": "IMAGE", 132 | "link": 14 133 | } 134 | ], 135 | "outputs": [ 136 | { 137 | "name": "res_img", 138 | "type": "IMAGE", 139 | "links": [ 140 | 15 141 | ], 142 | "shape": 3, 143 | "slot_index": 0 144 | } 145 | ], 146 | "properties": { 147 | "Node name for S&R": "MicroAST" 148 | }, 149 | "widgets_values": [ 150 | false, 151 | 512 152 | ] 153 | } 154 | ], 155 | "links": [ 156 | [ 157 | 13, 158 | 2, 159 | 0, 160 | 7, 161 | 0, 162 | "IMAGE" 163 | ], 164 | [ 165 | 14, 166 | 3, 167 | 0, 168 | 7, 169 | 1, 170 | "IMAGE" 171 | ], 172 | [ 173 | 15, 174 | 7, 175 | 0, 176 | 4, 177 | 0, 178 | "IMAGE" 179 | ] 180 | ], 181 | "groups": [], 182 | "config": {}, 183 | "extra": { 184 | "ds": { 185 | "scale": 1.1000000000000008, 186 | "offset": [ 187 | -166.99103191074192, 188 | -98.00415725314276 189 | ] 190 | } 191 | }, 192 | "version": 0.4 193 | } -------------------------------------------------------------------------------- /workflows/workflow_unist_image.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 13, 3 | "last_link_id": 21, 4 | "nodes": [ 5 | { 6 | "id": 12, 7 | "type": "LoadImage", 8 | "pos": [ 9 | 192, 10 | 179 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 314 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "IMAGE", 22 | "type": "IMAGE", 23 | "links": [ 24 | 20 25 | ], 26 | "shape": 3, 27 | "slot_index": 0 28 | }, 29 | { 30 | "name": "MASK", 31 | "type": "MASK", 32 | "links": null, 33 | "shape": 3 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "LoadImage" 38 | }, 39 | "widgets_values": [ 40 | "006.jpg", 41 | "image" 42 | ] 43 | }, 44 | { 45 | "id": 3, 46 | "type": "LoadImage", 47 | "pos": [ 48 | 195, 49 | 548 50 | ], 51 | "size": { 52 | "0": 315, 53 | "1": 314 54 | }, 55 | "flags": {}, 56 | "order": 1, 57 | "mode": 0, 58 | "outputs": [ 59 | { 60 | "name": "IMAGE", 61 | "type": "IMAGE", 62 | "links": [ 63 | 19 64 | ], 65 | "shape": 3, 66 | "slot_index": 0 67 | }, 68 | { 69 | "name": "MASK", 70 | "type": "MASK", 71 | "links": null, 72 | "shape": 3 73 | } 74 | ], 75 | "properties": { 76 | "Node name for S&R": "LoadImage" 77 | }, 78 | "widgets_values": [ 79 | "S4.jpg", 80 | "image" 81 | ] 82 | }, 83 | { 84 | "id": 13, 85 | "type": "SaveImage", 86 | "pos": [ 87 | 572, 88 | 415 89 | ], 90 | "size": { 91 | "0": 315, 92 | "1": 270 93 | }, 94 | "flags": {}, 95 | "order": 3, 96 | "mode": 0, 97 | "inputs": [ 98 | { 99 | "name": "images", 100 | "type": "IMAGE", 101 | "link": 21 102 | } 103 
| ], 104 | "properties": {}, 105 | "widgets_values": [ 106 | "StyleTransfer" 107 | ] 108 | }, 109 | { 110 | "id": 11, 111 | "type": "UniST", 112 | "pos": [ 113 | 568, 114 | 246 115 | ], 116 | "size": { 117 | "0": 315, 118 | "1": 122 119 | }, 120 | "flags": {}, 121 | "order": 2, 122 | "mode": 0, 123 | "inputs": [ 124 | { 125 | "name": "src_img", 126 | "type": "IMAGE", 127 | "link": 20 128 | }, 129 | { 130 | "name": "style_img", 131 | "type": "IMAGE", 132 | "link": 19 133 | } 134 | ], 135 | "outputs": [ 136 | { 137 | "name": "out_img", 138 | "type": "IMAGE", 139 | "links": [ 140 | 21 141 | ], 142 | "shape": 3, 143 | "slot_index": 0 144 | } 145 | ], 146 | "properties": { 147 | "Node name for S&R": "UniST" 148 | }, 149 | "widgets_values": [ 150 | true, 151 | 1024 152 | ] 153 | } 154 | ], 155 | "links": [ 156 | [ 157 | 19, 158 | 3, 159 | 0, 160 | 11, 161 | 1, 162 | "IMAGE" 163 | ], 164 | [ 165 | 20, 166 | 12, 167 | 0, 168 | 11, 169 | 0, 170 | "IMAGE" 171 | ], 172 | [ 173 | 21, 174 | 11, 175 | 0, 176 | 13, 177 | 0, 178 | "IMAGE" 179 | ] 180 | ], 181 | "groups": [], 182 | "config": {}, 183 | "extra": { 184 | "ds": { 185 | "scale": 1.331000000000001, 186 | "offset": [ 187 | -38.8281700653258, 188 | -204.83558799978883 189 | ] 190 | } 191 | }, 192 | "version": 0.4 193 | } -------------------------------------------------------------------------------- /workflows/workflow_tssat.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 9, 3 | "last_link_id": 23, 4 | "nodes": [ 5 | { 6 | "id": 4, 7 | "type": "SaveImage", 8 | "pos": [ 9 | 1050, 10 | 327 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 270 15 | }, 16 | "flags": {}, 17 | "order": 3, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "images", 22 | "type": "IMAGE", 23 | "link": 23 24 | } 25 | ], 26 | "properties": {}, 27 | "widgets_values": [ 28 | "StyleTransfer" 29 | ] 30 | }, 31 | { 32 | "id": 2, 33 | "type": "LoadImage", 34 | "pos": [ 35 | 309, 36 | 187 37 | ], 38 | "size": { 39 | "0": 315, 40 | "1": 314 41 | }, 42 | "flags": {}, 43 | "order": 0, 44 | "mode": 0, 45 | "outputs": [ 46 | { 47 | "name": "IMAGE", 48 | "type": "IMAGE", 49 | "links": [ 50 | 21 51 | ], 52 | "shape": 3, 53 | "slot_index": 0 54 | }, 55 | { 56 | "name": "MASK", 57 | "type": "MASK", 58 | "links": null, 59 | "shape": 3 60 | } 61 | ], 62 | "properties": { 63 | "Node name for S&R": "LoadImage" 64 | }, 65 | "widgets_values": [ 66 | "006.jpg", 67 | "image" 68 | ] 69 | }, 70 | { 71 | "id": 3, 72 | "type": "LoadImage", 73 | "pos": [ 74 | 315, 75 | 570 76 | ], 77 | "size": { 78 | "0": 315, 79 | "1": 314 80 | }, 81 | "flags": {}, 82 | "order": 1, 83 | "mode": 0, 84 | "outputs": [ 85 | { 86 | "name": "IMAGE", 87 | "type": "IMAGE", 88 | "links": [ 89 | 22 90 | ], 91 | "shape": 3, 92 | "slot_index": 0 93 | }, 94 | { 95 | "name": "MASK", 96 | "type": "MASK", 97 | "links": null, 98 | "shape": 3 99 | } 100 | ], 101 | "properties": { 102 | "Node name for S&R": "LoadImage" 103 | }, 104 | "widgets_values": [ 105 | "S4.jpg", 106 | "image" 107 | ] 108 | }, 109 | { 110 | "id": 9, 111 | "type": "TSSAT", 112 | "pos": [ 113 | 704.2608950403187, 114 | 332.7147398449658 115 | ], 116 | "size": { 117 | "0": 315, 118 | "1": 126 119 | }, 120 | "flags": {}, 121 | "order": 2, 122 | "mode": 0, 123 | "inputs": [ 124 | { 125 | "name": "src_img", 126 | "type": "IMAGE", 127 | "link": 21 128 | }, 129 | { 130 | "name": "style_img", 131 | "type": "IMAGE", 132 | "link": 22 133 | } 134 | ], 135 | "outputs": [ 136 | { 137 | "name": "out_img", 138 | 
"type": "IMAGE", 139 | "links": [ 140 | 23 141 | ], 142 | "shape": 3, 143 | "slot_index": 0 144 | } 145 | ], 146 | "properties": { 147 | "Node name for S&R": "TSSAT" 148 | }, 149 | "widgets_values": [ 150 | false, 151 | 1024, 152 | 1 153 | ] 154 | } 155 | ], 156 | "links": [ 157 | [ 158 | 21, 159 | 2, 160 | 0, 161 | 9, 162 | 0, 163 | "IMAGE" 164 | ], 165 | [ 166 | 22, 167 | 3, 168 | 0, 169 | 9, 170 | 1, 171 | "IMAGE" 172 | ], 173 | [ 174 | 23, 175 | 9, 176 | 0, 177 | 4, 178 | 0, 179 | "IMAGE" 180 | ] 181 | ], 182 | "groups": [], 183 | "config": {}, 184 | "extra": { 185 | "ds": { 186 | "scale": 1.2100000000000002, 187 | "offset": [ 188 | -289.7051144856484, 189 | -164.0565891162732 190 | ] 191 | } 192 | }, 193 | "version": 0.4 194 | } -------------------------------------------------------------------------------- /workflows/workflow_neural_neighbor.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 4, 3 | "last_link_id": 3, 4 | "nodes": [ 5 | { 6 | "id": 3, 7 | "type": "LoadImage", 8 | "pos": [ 9 | 315, 10 | 570 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 314 15 | }, 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "IMAGE", 22 | "type": "IMAGE", 23 | "links": [ 24 | 2 25 | ], 26 | "shape": 3 27 | }, 28 | { 29 | "name": "MASK", 30 | "type": "MASK", 31 | "links": null, 32 | "shape": 3 33 | } 34 | ], 35 | "properties": { 36 | "Node name for S&R": "LoadImage" 37 | }, 38 | "widgets_values": [ 39 | "S4.jpg", 40 | "image" 41 | ] 42 | }, 43 | { 44 | "id": 2, 45 | "type": "LoadImage", 46 | "pos": [ 47 | 309, 48 | 187 49 | ], 50 | "size": { 51 | "0": 315, 52 | "1": 314 53 | }, 54 | "flags": {}, 55 | "order": 1, 56 | "mode": 0, 57 | "outputs": [ 58 | { 59 | "name": "IMAGE", 60 | "type": "IMAGE", 61 | "links": [ 62 | 1 63 | ], 64 | "shape": 3 65 | }, 66 | { 67 | "name": "MASK", 68 | "type": "MASK", 69 | "links": null, 70 | "shape": 3 71 | } 72 | ], 73 | "properties": { 74 | "Node name for S&R": "LoadImage" 75 | }, 76 | "widgets_values": [ 77 | "006.jpg", 78 | "image" 79 | ] 80 | }, 81 | { 82 | "id": 4, 83 | "type": "SaveImage", 84 | "pos": [ 85 | 1050, 86 | 327 87 | ], 88 | "size": [ 89 | 315, 90 | 270 91 | ], 92 | "flags": {}, 93 | "order": 3, 94 | "mode": 0, 95 | "inputs": [ 96 | { 97 | "name": "images", 98 | "type": "IMAGE", 99 | "link": 3 100 | } 101 | ], 102 | "properties": {}, 103 | "widgets_values": [ 104 | "StyleTransfer" 105 | ] 106 | }, 107 | { 108 | "id": 1, 109 | "type": "NeuralNeighbor", 110 | "pos": [ 111 | 681, 112 | 330 113 | ], 114 | "size": { 115 | "0": 315, 116 | "1": 222 117 | }, 118 | "flags": {}, 119 | "order": 2, 120 | "mode": 0, 121 | "inputs": [ 122 | { 123 | "name": "src_img", 124 | "type": "IMAGE", 125 | "link": 1, 126 | "slot_index": 0 127 | }, 128 | { 129 | "name": "style_img", 130 | "type": "IMAGE", 131 | "link": 2, 132 | "slot_index": 1 133 | } 134 | ], 135 | "outputs": [ 136 | { 137 | "name": "res_img", 138 | "type": "IMAGE", 139 | "links": [ 140 | 3 141 | ], 142 | "shape": 3, 143 | "slot_index": 0 144 | } 145 | ], 146 | "properties": { 147 | "Node name for S&R": "NeuralNeighbor" 148 | }, 149 | "widgets_values": [ 150 | 1024, 151 | true, 152 | false, 153 | false, 154 | true, 155 | 0.75, 156 | 200 157 | ] 158 | } 159 | ], 160 | "links": [ 161 | [ 162 | 1, 163 | 2, 164 | 0, 165 | 1, 166 | 0, 167 | "IMAGE" 168 | ], 169 | [ 170 | 2, 171 | 3, 172 | 0, 173 | 1, 174 | 1, 175 | "IMAGE" 176 | ], 177 | [ 178 | 3, 179 | 1, 180 | 0, 181 | 4, 182 | 0, 183 | "IMAGE" 184 | ] 185 | 
], 186 | "groups": [], 187 | "config": {}, 188 | "extra": { 189 | "ds": { 190 | "scale": 0.9090909090909091, 191 | "offset": [ 192 | 226.81640833417435, 193 | 19.300671759087805 194 | ] 195 | } 196 | }, 197 | "version": 0.4 198 | } -------------------------------------------------------------------------------- /module_efdm/efdm_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchvision import transforms 4 | 5 | from . import net 6 | from .function import ( 7 | adaptive_instance_normalization, 8 | adaptive_mean_normalization, 9 | adaptive_std_normalization, 10 | coral, 11 | exact_feature_distribution_matching, 12 | histogram_matching, 13 | ) 14 | 15 | 16 | class MODEL_EFDM: 17 | def __init__(self, dec_path: str, vgg_path: str, device) -> None: 18 | self.decoder = net.decoder 19 | vgg = net.vgg 20 | 21 | self.decoder.load_state_dict(torch.load(dec_path)) 22 | vgg.load_state_dict(torch.load(vgg_path)) 23 | 24 | self.vgg = torch.nn.Sequential(*list(vgg.children())[:31]) 25 | 26 | self.vgg.eval().to(device) 27 | self.decoder.eval().to(device) 28 | 29 | 30 | def inference_efdm( 31 | src_img: torch.Tensor, 32 | style_img: torch.Tensor, 33 | device, 34 | decoder, 35 | vgg, 36 | alpha: float, 37 | size: int, 38 | style_type: str, 39 | do_crop=False, 40 | preserve_color=False, 41 | style_interpolation_weights: list[int] | None = None, 42 | ): 43 | if style_img.shape[0] > 1: 44 | if style_interpolation_weights is None: 45 | style_interpolation_weights = [ 46 | 1 / style_img.shape[0] for _ in range(style_img.shape[0]) 47 | ] 48 | else: 49 | style_interpolation_weights = [ 50 | w / sum(style_interpolation_weights) 51 | for w in style_interpolation_weights 52 | ] 53 | else: 54 | style_interpolation_weights = None 55 | 56 | content_tf = test_transform(size, do_crop) 57 | 58 | content = content_tf(src_img).to(device) 59 | style = content_tf(style_img).to(device) 60 | style = F.interpolate( 61 | style, 62 | size=(content.shape[2], content.shape[3]), 63 | mode="bilinear", 64 | align_corners=False, 65 | ) 66 | content = content.expand_as(style) 67 | 68 | # print(f"Resized: {content.shape=}") 69 | # print(f"Resized: {style.shape=}") 70 | 71 | if preserve_color: 72 | tmp_content = content.squeeze() 73 | tmp_styles = [coral(stl, tmp_content).unsqueeze(0) for stl in style] 74 | style = torch.cat(tmp_styles, dim=0) 75 | style = style.to(device) 76 | 77 | if device.type != "cpu": 78 | torch.cuda.empty_cache() 79 | 80 | output = style_transfer( 81 | vgg, 82 | decoder, 83 | content, 84 | style, 85 | device, 86 | alpha, 87 | style_interpolation_weights, 88 | style_type=style_type, 89 | ) 90 | 91 | if device.type != "cpu": 92 | torch.cuda.empty_cache() 93 | 94 | # [N, C, H, W] 95 | return output 96 | 97 | 98 | def test_transform(size, crop): 99 | transform_list = [] 100 | if size != 0: 101 | transform_list.append(transforms.Resize(size)) 102 | if crop: 103 | transform_list.append(transforms.CenterCrop(size)) 104 | transform = transforms.Compose(transform_list) 105 | # [B, C, H, W] 106 | return transform 107 | 108 | 109 | def style_transfer( 110 | vgg, 111 | decoder, 112 | content, 113 | style, 114 | device, 115 | alpha=1.0, 116 | interpolation_weights=None, 117 | style_type="adain", 118 | ): 119 | assert 0.0 <= alpha <= 1.0 120 | content_f = vgg(content) 121 | style_f = vgg(style) 122 | if interpolation_weights: 123 | _, C, H, W = content_f.size() 124 | feat = torch.FloatTensor(1, C, H, W).zero_().to(device) 125 
| if style_type == "adain": 126 | base_feat = adaptive_instance_normalization(content_f, style_f) 127 | elif style_type == "adamean": 128 | base_feat = adaptive_mean_normalization(content_f, style_f) 129 | elif style_type == "adastd": 130 | base_feat = adaptive_std_normalization(content_f, style_f) 131 | elif style_type == "efdm": 132 | base_feat = exact_feature_distribution_matching(content_f, style_f) 133 | elif style_type == "hm": 134 | feat = histogram_matching(content_f, style_f) 135 | else: 136 | raise NotImplementedError 137 | for i, w in enumerate(interpolation_weights): 138 | feat = feat + w * base_feat[i : i + 1] 139 | content_f = content_f[0:1] 140 | else: 141 | if style_type == "adain": 142 | feat = adaptive_instance_normalization(content_f, style_f) 143 | elif style_type == "adamean": 144 | feat = adaptive_mean_normalization(content_f, style_f) 145 | elif style_type == "adastd": 146 | feat = adaptive_std_normalization(content_f, style_f) 147 | elif style_type == "efdm": 148 | feat = exact_feature_distribution_matching(content_f, style_f) 149 | elif style_type == "hm": 150 | feat = histogram_matching(content_f, style_f) 151 | else: 152 | raise NotImplementedError 153 | feat = feat * alpha + content_f * (1 - alpha) 154 | return decoder(feat) 155 | -------------------------------------------------------------------------------- /module_efdm/function.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from skimage.exposure import match_histograms 4 | 5 | 6 | def calc_mean_std(feat, eps=1e-5): 7 | # eps is a small value added to the variance to avoid divide-by-zero. 8 | size = feat.size() 9 | assert len(size) == 4 10 | N, C = size[:2] 11 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 12 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 13 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 14 | return feat_mean, feat_std 15 | 16 | 17 | def adaptive_instance_normalization(content_feat, style_feat): 18 | assert content_feat.size()[:2] == style_feat.size()[:2] 19 | size = content_feat.size() 20 | style_mean, style_std = calc_mean_std(style_feat) 21 | content_mean, content_std = calc_mean_std(content_feat) 22 | 23 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand( 24 | size 25 | ) 26 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 27 | 28 | 29 | ## AdaMean 30 | def adaptive_mean_normalization(content_feat, style_feat): 31 | assert content_feat.size()[:2] == style_feat.size()[:2] 32 | size = content_feat.size() 33 | style_mean, style_std = calc_mean_std(style_feat) 34 | content_mean, content_std = calc_mean_std(content_feat) 35 | 36 | normalized_feat = content_feat - content_mean.expand(size) 37 | return normalized_feat + style_mean.expand(size) 38 | 39 | 40 | ## AdaStd 41 | def adaptive_std_normalization(content_feat, style_feat): 42 | assert content_feat.size()[:2] == style_feat.size()[:2] 43 | size = content_feat.size() 44 | style_mean, style_std = calc_mean_std(style_feat) 45 | content_mean, content_std = calc_mean_std(content_feat) 46 | 47 | normalized_feat = (content_feat) / content_std.expand(size) 48 | return normalized_feat * style_std.expand(size) 49 | 50 | 51 | ## EFDM 52 | def exact_feature_distribution_matching(content_feat, style_feat): 53 | assert content_feat.size() == style_feat.size() 54 | B, C, W, H = ( 55 | content_feat.size(0), 56 | content_feat.size(1), 57 | content_feat.size(2), 58 | content_feat.size(3), 59 | ) 60 | 
value_content, index_content = torch.sort( 61 | content_feat.view(B, C, -1) 62 | ) # sort conduct a deep copy here. 63 | value_style, _ = torch.sort( 64 | style_feat.view(B, C, -1) 65 | ) # sort conduct a deep copy here. 66 | inverse_index = index_content.argsort(-1) 67 | new_content = content_feat.view(B, C, -1) + ( 68 | value_style.gather(-1, inverse_index) - content_feat.view(B, C, -1).detach() 69 | ) 70 | 71 | return new_content.view(B, C, W, H) 72 | 73 | 74 | ## HM 75 | def histogram_matching(content_feat, style_feat): 76 | assert content_feat.size() == style_feat.size() 77 | B, C, W, H = ( 78 | content_feat.size(0), 79 | content_feat.size(1), 80 | content_feat.size(2), 81 | content_feat.size(3), 82 | ) 83 | x_view = content_feat.view(-1, W, H) 84 | image1_temp = match_histograms( 85 | np.array(x_view.detach().clone().cpu().float().transpose(0, 2)), 86 | np.array( 87 | style_feat.view(-1, W, H).detach().clone().cpu().float().transpose(0, 2) 88 | ), 89 | ) 90 | image1_temp = ( 91 | torch.from_numpy(image1_temp) 92 | .float() 93 | .to(content_feat.device) 94 | .transpose(0, 2) 95 | .view(B, C, W, H) 96 | ) 97 | return content_feat + (image1_temp - content_feat).detach() 98 | 99 | 100 | def _calc_feat_flatten_mean_std(feat): 101 | # takes 3D feat (C, H, W), return mean and std of array within channels 102 | assert feat.size()[0] == 3 103 | assert isinstance(feat, torch.FloatTensor) 104 | feat_flatten = feat.view(3, -1) 105 | mean = feat_flatten.mean(dim=-1, keepdim=True) 106 | std = feat_flatten.std(dim=-1, keepdim=True) 107 | return feat_flatten, mean, std 108 | 109 | 110 | def _mat_sqrt(x): 111 | U, D, V = torch.svd(x) 112 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t()) 113 | 114 | 115 | def coral(source: torch.Tensor, target: torch.Tensor): 116 | # assume both source and target are 3D array (C, H, W) 117 | # Note: flatten -> f 118 | 119 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source.to('cpu')) 120 | source_f_norm = ( 121 | source_f - source_f_mean.expand_as(source_f) 122 | ) / source_f_std.expand_as(source_f) 123 | source_f_cov_eye = torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3) 124 | 125 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target.to('cpu')) 126 | target_f_norm = ( 127 | target_f - target_f_mean.expand_as(target_f) 128 | ) / target_f_std.expand_as(target_f) 129 | target_f_cov_eye = torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3) 130 | 131 | source_f_norm_transfer = torch.mm( 132 | _mat_sqrt(target_f_cov_eye), 133 | torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), source_f_norm), 134 | ) 135 | 136 | source_f_transfer = source_f_norm_transfer * target_f_std.expand_as( 137 | source_f_norm 138 | ) + target_f_mean.expand_as(source_f_norm) 139 | 140 | return source_f_transfer.view(source.size()) 141 | -------------------------------------------------------------------------------- /module_unist/transformer_subLayers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | 7 | from .transformer_modules import ScaledDotProductAttention 8 | 9 | 10 | class MultiHeadAttention(nn.Module): 11 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1, granular=True): 12 | super().__init__() 13 | self.n_head = n_head 14 | self.d_k = d_k 15 | self.d_v = d_v 16 | self.query_embedding = nn.Conv2d(512, 512, kernel_size=1, padding=0) 17 | self.value_embedding = 
nn.Conv2d(512, 512, kernel_size=1, padding=0) 18 | self.key_embedding = nn.Conv2d(512, 512, kernel_size=1, padding=0) 19 | 20 | self.granular = granular 21 | 22 | # option b 23 | self.w_qs_h = nn.Linear(d_model, n_head * d_k, bias=False) 24 | self.w_ks_h = nn.Linear(d_model, n_head * d_k, bias=False) 25 | self.w_vs_h = nn.Linear(d_model, n_head * d_v, bias=False) 26 | self.fc_h = nn.Linear(n_head * d_v, d_model, bias=False) 27 | 28 | self.attention_h = ScaledDotProductAttention(temperature=d_k**0.5) 29 | 30 | self.w_qs_w = nn.Linear(d_model, n_head * d_k, bias=False) 31 | self.w_ks_w = nn.Linear(d_model, n_head * d_k, bias=False) 32 | self.w_vs_w = nn.Linear(d_model, n_head * d_v, bias=False) 33 | self.fc_w = nn.Linear(n_head * d_v, d_model, bias=False) 34 | 35 | self.attention_w = ScaledDotProductAttention(temperature=d_k**0.5) 36 | 37 | self.dropout1 = nn.Dropout(dropout) 38 | self.dropout2 = nn.Dropout(dropout) 39 | self.layer_norm1 = nn.LayerNorm(d_model, eps=1e-6) 40 | self.layer_norm2 = nn.LayerNorm(d_model, eps=1e-6) 41 | self.relu = nn.ReLU(inplace=True) 42 | 43 | def forward(self, q, k, v, mask=None): 44 | residual = q 45 | bt, l, _ = q.size() # noqa: E741 46 | 47 | q = rearrange(q, "b (m1 m2) c -> b c m1 m2", m1=int(math.sqrt(l))) 48 | k = rearrange(k, "b (m1 m2) c -> b c m1 m2", m1=int(math.sqrt(l))) 49 | v = rearrange(v, "b (m1 m2) c -> b c m1 m2", m1=int(math.sqrt(l))) 50 | 51 | q = self.query_embedding(q) 52 | k_prev = self.key_embedding(k) 53 | v = self.value_embedding(v) 54 | 55 | q = rearrange(q, "b c m1 m2 -> (b m1) m2 c") 56 | k = rearrange(k_prev, "b c m1 m2 -> (b m1) m2 c") 57 | v = rearrange(v, "b c m1 m2 -> (b m1) m2 c") 58 | 59 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 60 | q = self.w_qs_h(q) 61 | k = self.w_ks_h(k) 62 | v = self.w_vs_h(v) 63 | 64 | q = rearrange(q, "b len (head d_k) -> b len head d_k", d_k=d_k, head=n_head) 65 | k = rearrange(k, "b len (head d_k) -> b len head d_k", d_k=d_k, head=n_head) 66 | v = rearrange(v, "b len (head d_v) -> b len head d_v", d_v=d_v, head=n_head) 67 | 68 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 69 | 70 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 71 | 72 | if mask is not None: 73 | mask = mask.unsqueeze(1) 74 | 75 | q, attn = self.attention_h(q, k, v, mask=mask) 76 | 77 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 78 | q = self.dropout1(self.fc_h(q)) 79 | q = rearrange(q, "(bt h) w c ->bt (h w) c", h=int(math.sqrt(l))) 80 | q_mid = self.layer_norm1(q) 81 | 82 | q = rearrange(q_mid, "bt (h w) c ->(bt w) h c", h=int(math.sqrt(l))) 83 | if self.granular: 84 | k = rearrange(q_mid,'bt (h w) c ->(bt w) h c',h =int(math.sqrt(l))) 85 | else: 86 | k = rearrange(k_prev, "b c m1 m2 -> (b m2) m1 c") 87 | v = rearrange(q_mid, "bt (h w) c ->(bt w) h c", h=int(math.sqrt(l))) 88 | 89 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 90 | q = self.w_qs_w(q) 91 | k = self.w_ks_w(k) 92 | v = self.w_vs_w(v) 93 | 94 | q = rearrange(q, "b len (head d_k) -> b len head d_k", d_k=d_k, head=n_head) 95 | k = rearrange(k, "b len (head d_k) -> b len head d_k", d_k=d_k, head=n_head) 96 | v = rearrange(v, "b len (head d_v) -> b len head d_v", d_v=d_v, head=n_head) 97 | 98 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) # noqa: F841 99 | 100 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 101 | 102 | if mask is not None: 103 | mask = mask.unsqueeze(1) 104 | 105 | q, attn = self.attention_w(q, k, v, mask=mask) 106 | q = 
q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 107 | q = self.dropout2(self.fc_w(q)) 108 | q = rearrange(q, "(bt w) h c -> bt (h w) c", w=int(math.sqrt(l))) 109 | q += residual 110 | q = self.layer_norm2(q) 111 | 112 | return q, attn 113 | 114 | 115 | class PositionwiseFeedForward(nn.Module): 116 | def __init__(self, d_in, d_hid, dropout=0.1): 117 | super().__init__() 118 | self.w_1 = nn.Linear(d_in, d_hid) 119 | self.w_2 = nn.Linear(d_hid, d_in) 120 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 121 | self.dropout = nn.Dropout(dropout) 122 | 123 | def forward(self, x): 124 | residual = x 125 | x = self.w_2(F.relu(self.w_1(x))) 126 | x = self.dropout(x) 127 | x += residual 128 | 129 | x = self.layer_norm(x) 130 | 131 | return x 132 | -------------------------------------------------------------------------------- /module_aespa/aespa_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torchvision.transforms.functional as tf_f 4 | 5 | from .aespanet_models import ( 6 | AdaptiveMultiAttn_Transformer_v2, 7 | VGGDecoder, 8 | VGGEncoder, 9 | ) 10 | from .utils import adaptive_get_keys, adaptive_gram_weight, feature_wct_simple 11 | 12 | 13 | class MODEL_AESPA: 14 | def __init__( 15 | self, vgg_51_path: str, dec_path: str, transformer_path: str, device 16 | ) -> None: 17 | pretrained_vgg = torch.load(vgg_51_path) 18 | self.encoder = VGGEncoder(pretrained_vgg) 19 | self.decoder = VGGDecoder() 20 | self.transformer = AdaptiveMultiAttn_Transformer_v2( 21 | in_planes=512, 22 | out_planes=512, 23 | query_planes=512, 24 | key_planes=512 + 256 + 128 + 64, 25 | ) 26 | 27 | self.decoder.load_state_dict(torch.load(dec_path)["state_dict"]) 28 | self.transformer.load_state_dict(torch.load(transformer_path)["state_dict"]) 29 | 30 | self.encoder.eval().to(device) 31 | self.decoder.eval().to(device) 32 | self.transformer.eval().to(device) 33 | 34 | self.device = device 35 | 36 | def forward(self, content: torch.Tensor, style: torch.Tensor): 37 | gray_content = tf_f.rgb_to_grayscale(content).repeat(1, 3, 1, 1).to(self.device) 38 | gray_style = tf_f.rgb_to_grayscale(style).repeat(1, 3, 1, 1).to(self.device) 39 | 40 | style_weight = ( 41 | adaptive_gram_weight(style, 1, 8, self.encoder) 42 | + adaptive_gram_weight(style, 2, 8, self.encoder) 43 | + adaptive_gram_weight(style, 3, 8, self.encoder) 44 | ) / 3 45 | 46 | gray_style_weight = ( 47 | adaptive_gram_weight(gray_style, 1, 8, self.encoder) 48 | + adaptive_gram_weight(gray_style, 2, 8, self.encoder) 49 | + adaptive_gram_weight(gray_style, 3, 8, self.encoder) 50 | ) / 3 51 | 52 | style_adaptive_alpha = ( 53 | style_weight.unsqueeze(1).to(self.device) 54 | + gray_style_weight.unsqueeze(1).to(self.device) 55 | ) / 2 56 | 57 | content_skips = {} 58 | style_skips = {} 59 | 60 | _ = self.encoder.encode(content, content_skips) # content_feat 61 | _ = self.encoder.encode(style, style_skips) # style_feat 62 | 63 | if gray_content is not None: 64 | gray_content_skips = {} 65 | _ = self.encoder.encode( 66 | gray_content, gray_content_skips 67 | ) # gray_content_feat 68 | if gray_style is not None: 69 | gray_style_skips = {} 70 | _ = self.encoder.encode(gray_style, gray_style_skips) # gray_style_feat 71 | 72 | ( 73 | local_transformed_feature, 74 | _, 75 | _, 76 | _, 77 | _, 78 | ) = self.transformer.forward( 79 | content_skips["conv4_1"], 80 | style_skips["conv4_1"], 81 | content_skips["conv5_1"], 82 | style_skips["conv5_1"], 83 | adaptive_get_keys( 84 
| content_skips, 4, 4, target_feat=content_skips["conv4_1"] 85 | ), 86 | adaptive_get_keys(style_skips, 1, 4, target_feat=style_skips["conv4_1"]), 87 | adaptive_get_keys( 88 | content_skips, 5, 5, target_feat=content_skips["conv5_1"] 89 | ), 90 | adaptive_get_keys(style_skips, 1, 5, target_feat=style_skips["conv5_1"]), 91 | ) 92 | 93 | if gray_content is not None: 94 | global_transformed_feat = feature_wct_simple( 95 | gray_content_skips["conv4_1"], gray_style_skips["conv4_1"] 96 | ) 97 | else: 98 | global_transformed_feat = feature_wct_simple( 99 | content_skips["conv4_1"], style_skips["conv4_1"] 100 | ) 101 | 102 | transformed_feature = ( 103 | global_transformed_feat 104 | * (1 - style_adaptive_alpha.unsqueeze(-1).unsqueeze(-1)) 105 | + style_adaptive_alpha.unsqueeze(-1).unsqueeze(-1) 106 | * local_transformed_feature 107 | ) 108 | 109 | stylized_image = self.decoder.decode( 110 | transformed_feature, content_skips, style_skips 111 | ) 112 | 113 | return stylized_image 114 | 115 | 116 | def inference_aespa( 117 | src_img: torch.Tensor, 118 | style_img: torch.Tensor, 119 | device, 120 | size: int, 121 | do_crop: bool, 122 | model: MODEL_AESPA, 123 | ): 124 | content_tf = test_transform( 125 | size, do_crop, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 126 | ) 127 | content = content_tf(src_img).to(device) # [B, C, H, W] 128 | style = content_tf(style_img).to(device) 129 | 130 | stylized_image = model.forward(content, style) 131 | 132 | output = un_normalize_batch(stylized_image) 133 | 134 | return output 135 | 136 | 137 | def test_transform(size, crop, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): 138 | transform_list = [] 139 | 140 | if size != 0: 141 | # Ensure size is even 142 | size = (size // 2) * 2 143 | 144 | # Limit size to 1024 145 | if size > 1024: 146 | size = 1024 147 | 148 | if crop: 149 | transform_list.extend( 150 | [transforms.Resize(size), transforms.CenterCrop(size)] 151 | ) 152 | else: 153 | transform_list.append(transforms.Resize((size, size))) 154 | 155 | transform_list.append(transforms.Normalize(mean, std)) 156 | 157 | return transforms.Compose(transform_list) 158 | 159 | 160 | def un_normalize_batch(tensor, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]): 161 | for t, m, s in zip(tensor, mean, std): 162 | t.mul_(s).add_(m) 163 | return tensor 164 | -------------------------------------------------------------------------------- /module_tssat/tssat_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import tqdm 5 | from comfy.utils import ProgressBar 6 | from torchvision import transforms 7 | 8 | from . 
import net 9 | 10 | 11 | class MODEL_TSSAT: 12 | def __init__(self, dec_path: str, vgg_path: str, device) -> None: 13 | self.decoder = net.decoder 14 | vgg = net.vgg 15 | 16 | self.decoder.load_state_dict(torch.load(dec_path)) 17 | vgg.load_state_dict(torch.load(vgg_path)) 18 | 19 | self.enc_1 = nn.Sequential(*list(vgg.children())[:4]) # input -> relu1_1 20 | self.enc_2 = nn.Sequential(*list(vgg.children())[4:11]) # relu1_1 -> relu2_1 21 | self.enc_3 = nn.Sequential(*list(vgg.children())[11:18]) # relu2_1 -> relu3_1 22 | self.enc_4 = nn.Sequential(*list(vgg.children())[18:31]) # relu3_1 -> relu4_1 23 | 24 | self.enc_1.eval().to(device) 25 | self.enc_2.eval().to(device) 26 | self.enc_3.eval().to(device) 27 | self.enc_4.eval().to(device) 28 | self.decoder.eval().to(device) 29 | 30 | def encode(self, item): 31 | out = self.enc_1(item) 32 | out = self.enc_2(out) 33 | out = self.enc_3(out) 34 | out = self.enc_4(out) 35 | return out 36 | 37 | 38 | def inference_tssat( 39 | src_img: torch.Tensor, 40 | style_img: torch.Tensor, 41 | device, 42 | size: int, 43 | do_crop: False, 44 | max_steps: int, 45 | model: MODEL_TSSAT, 46 | ): 47 | content_tf = test_transform(size, do_crop) 48 | content = content_tf(src_img).to(device) # [1, C, H, W] 49 | style = content_tf(style_img).to(device) # [1, C, H, W] 50 | 51 | for _ in range(max_steps): 52 | Content4_1 = model.encode(content) 53 | Style4_1 = model.encode(style) 54 | content = model.decoder(TSSAT(Content4_1, Style4_1)) 55 | 56 | return content 57 | 58 | 59 | def forward(item, enc_1, enc_2, enc_3, enc_4): 60 | out = enc_1(item) 61 | out = enc_2(out) 62 | out = enc_3(out) 63 | out = enc_4(out) 64 | return out 65 | 66 | 67 | def test_transform(size, crop): 68 | transform_list = [] 69 | if size != 0: 70 | transform_list.append(transforms.Resize(size)) 71 | if crop: 72 | transform_list.append(transforms.CenterCrop(size)) 73 | transform = transforms.Compose(transform_list) 74 | # [B, C, H, W] 75 | return transform 76 | 77 | 78 | def TSSAT(cf, sf, patch_size=5, stride=1): # cf,sf Batch_size x C x H x W 79 | b, c, h, w = sf.size() # 2 x 256 x 64 x 64 80 | # print(cf.size()) 81 | kh, kw = patch_size, patch_size 82 | sh, sw = stride, stride 83 | 84 | # Create convolutional filters by style features 85 | sf_unfold = sf.unfold(2, kh, sh).unfold(3, kw, sw) 86 | patches = sf_unfold.permute(0, 2, 3, 1, 4, 5) 87 | patches = patches.reshape(b, -1, c, kh, kw) 88 | patches_norm = torch.norm(patches.reshape(*patches.shape[:2], -1), dim=2).reshape( 89 | b, -1, 1, 1, 1 90 | ) 91 | patches_norm = patches / patches_norm 92 | # patches size is 2 x 3844 x 256 x 3 x 3 93 | 94 | cf = adaptive_instance_normalization(cf, sf) 95 | 96 | for i in range(b): 97 | cf_temp = cf[i].unsqueeze(0) # [1 x 256 x 64 x 64] 98 | patches_norm_temp = patches_norm[i] # [3844, 256, 3, 3] 99 | patches_temp = patches[i] 100 | 101 | _, _, ch, cw = cf.size() 102 | pbar = ProgressBar(ch) 103 | for c_i in tqdm.tqdm(range(0, ch, patch_size), desc="Optimizing"): 104 | ################################################### 105 | if (c_i + patch_size) > ch: 106 | break 107 | elif (c_i + 2 * patch_size) > ch: 108 | ckh = ch - c_i 109 | else: 110 | ckh = patch_size 111 | ################################################### 112 | 113 | for c_j in range(0, cw, patch_size): 114 | ################################################### 115 | if (c_j + patch_size) > cw: 116 | break 117 | elif (c_j + 2 * patch_size) > cw: 118 | ckw = cw - c_j 119 | else: 120 | ckw = patch_size 121 | 
################################################### 122 | 123 | temp = cf_temp[:, :, c_i : c_i + ckh, c_j : c_j + ckw] 124 | conv_out = F.conv2d(temp, patches_norm_temp, stride=patch_size) 125 | index = conv_out.argmax(dim=1).squeeze() 126 | style_temp = patches_temp[index].unsqueeze(0) 127 | stylized_part = adaptive_instance_normalization(temp, style_temp) 128 | 129 | if c_j == 0: 130 | p = stylized_part 131 | else: 132 | p = torch.cat([p, stylized_part], 3) 133 | 134 | if c_i == 0: 135 | q = p 136 | else: 137 | q = torch.cat([q, p], 2) 138 | 139 | pbar.update_absolute(c_i, ch) 140 | 141 | if i == 0: 142 | out = q 143 | else: 144 | out = torch.cat([out, q], 0) 145 | 146 | return out 147 | 148 | 149 | def adaptive_instance_normalization(content_feat, style_feat): 150 | assert content_feat.size()[:2] == style_feat.size()[:2] 151 | size = content_feat.size() 152 | style_mean, style_std = calc_mean_std(style_feat) 153 | content_mean, content_std = calc_mean_std(content_feat) 154 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand( 155 | size 156 | ) 157 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 158 | 159 | 160 | def calc_mean_std(feat, eps=1e-5): 161 | # eps is a small value added to the variance to avoid divide-by-zero. 162 | size = feat.size() 163 | assert len(size) == 4 164 | N, C = size[:2] 165 | feat_var = feat.contiguous().view(N, C, -1).var(dim=2) + eps 166 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 167 | feat_mean = feat.contiguous().view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 168 | return feat_mean, feat_std 169 | -------------------------------------------------------------------------------- /module_cast/net.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | vgg = nn.Sequential( 4 | nn.Conv2d(3, 3, (1, 1)), 5 | nn.ReflectionPad2d((1, 1, 1, 1)), 6 | nn.Conv2d(3, 64, (3, 3)), 7 | nn.ReLU(), # relu1-1 8 | nn.ReflectionPad2d((1, 1, 1, 1)), 9 | nn.Conv2d(64, 64, (3, 3)), 10 | nn.ReLU(), # relu1-2 11 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 12 | nn.ReflectionPad2d((1, 1, 1, 1)), 13 | nn.Conv2d(64, 128, (3, 3)), 14 | nn.ReLU(), # relu2-1 15 | nn.ReflectionPad2d((1, 1, 1, 1)), 16 | nn.Conv2d(128, 128, (3, 3)), 17 | nn.ReLU(), # relu2-2 18 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 19 | nn.ReflectionPad2d((1, 1, 1, 1)), 20 | nn.Conv2d(128, 256, (3, 3)), 21 | nn.ReLU(), # relu3-1 22 | nn.ReflectionPad2d((1, 1, 1, 1)), 23 | nn.Conv2d(256, 256, (3, 3)), 24 | nn.ReLU(), # relu3-2 25 | nn.ReflectionPad2d((1, 1, 1, 1)), 26 | nn.Conv2d(256, 256, (3, 3)), 27 | nn.ReLU(), # relu3-3 28 | nn.ReflectionPad2d((1, 1, 1, 1)), 29 | nn.Conv2d(256, 256, (3, 3)), 30 | nn.ReLU(), # relu3-4 31 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 32 | nn.ReflectionPad2d((1, 1, 1, 1)), 33 | nn.Conv2d(256, 512, (3, 3)), 34 | nn.ReLU(), # relu4-1, this is the last layer used 35 | nn.ReflectionPad2d((1, 1, 1, 1)), 36 | nn.Conv2d(512, 512, (3, 3)), 37 | nn.ReLU(), # relu4-2 38 | nn.ReflectionPad2d((1, 1, 1, 1)), 39 | nn.Conv2d(512, 512, (3, 3)), 40 | nn.ReLU(), # relu4-3 41 | nn.ReflectionPad2d((1, 1, 1, 1)), 42 | nn.Conv2d(512, 512, (3, 3)), 43 | nn.ReLU(), # relu4-4 44 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 45 | nn.ReflectionPad2d((1, 1, 1, 1)), 46 | nn.Conv2d(512, 512, (3, 3)), 47 | nn.ReLU(), # relu5-1 48 | nn.ReflectionPad2d((1, 1, 1, 1)), 49 | nn.Conv2d(512, 512, (3, 3)), 50 | nn.ReLU(), # relu5-2 51 | nn.ReflectionPad2d((1, 1, 
1, 1)), 52 | nn.Conv2d(512, 512, (3, 3)), 53 | nn.ReLU(), # relu5-3 54 | nn.ReflectionPad2d((1, 1, 1, 1)), 55 | nn.Conv2d(512, 512, (3, 3)), 56 | nn.ReLU(), # relu5-4 57 | ) 58 | 59 | 60 | class ADAIN_Encoder(nn.Module): 61 | def __init__(self, encoder): 62 | super(ADAIN_Encoder, self).__init__() 63 | enc_layers = list(encoder.children()) 64 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 64 65 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 128 66 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 256 67 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 512 68 | 69 | self.mse_loss = nn.MSELoss() 70 | 71 | # fix the encoder 72 | for name in ["enc_1", "enc_2", "enc_3", "enc_4"]: 73 | for param in getattr(self, name).parameters(): 74 | param.requires_grad = False 75 | 76 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 77 | def encode_with_intermediate(self, input): 78 | results = [input] 79 | for i in range(4): 80 | func = getattr(self, "enc_{:d}".format(i + 1)) 81 | results.append(func(results[-1])) 82 | return results[1:] 83 | 84 | def calc_mean_std(self, feat, eps=1e-5): 85 | # eps is a small value added to the variance to avoid divide-by-zero. 86 | size = feat.size() 87 | assert len(size) == 4 88 | N, C = size[:2] # [N, C, H, W] 89 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 90 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 91 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 92 | return feat_mean, feat_std 93 | 94 | def adain(self, content_feat, style_feat): 95 | assert content_feat.size()[:2] == style_feat.size()[:2] 96 | size = content_feat.size() 97 | style_mean, style_std = self.calc_mean_std(style_feat) 98 | content_mean, content_std = self.calc_mean_std(content_feat) 99 | 100 | normalized_feat = ( 101 | content_feat - content_mean.expand(size) 102 | ) / content_std.expand(size) 103 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 104 | 105 | def forward(self, content, style, encoded_only=False): 106 | style_feats = self.encode_with_intermediate(style) 107 | content_feats = self.encode_with_intermediate(content) 108 | if encoded_only: 109 | return content_feats[-1], style_feats[-1] 110 | else: 111 | adain_feat = self.adain(content_feats[-1], style_feats[-1]) 112 | return adain_feat 113 | 114 | 115 | class Decoder(nn.Module): 116 | def __init__(self): 117 | super(Decoder, self).__init__() 118 | decoder = [ 119 | nn.ReflectionPad2d((1, 1, 1, 1)), 120 | nn.Conv2d(512, 256, (3, 3)), 121 | nn.ReLU(), # 256 122 | nn.Upsample(scale_factor=2, mode="nearest"), 123 | nn.ReflectionPad2d((1, 1, 1, 1)), 124 | nn.Conv2d(256, 256, (3, 3)), 125 | nn.ReLU(), 126 | nn.ReflectionPad2d((1, 1, 1, 1)), 127 | nn.Conv2d(256, 256, (3, 3)), 128 | nn.ReLU(), 129 | nn.ReflectionPad2d((1, 1, 1, 1)), 130 | nn.Conv2d(256, 256, (3, 3)), 131 | nn.ReLU(), 132 | nn.ReflectionPad2d((1, 1, 1, 1)), 133 | nn.Conv2d(256, 128, (3, 3)), 134 | nn.ReLU(), # 128 135 | nn.Upsample(scale_factor=2, mode="nearest"), 136 | nn.ReflectionPad2d((1, 1, 1, 1)), 137 | nn.Conv2d(128, 128, (3, 3)), 138 | nn.ReLU(), 139 | nn.ReflectionPad2d((1, 1, 1, 1)), 140 | nn.Conv2d(128, 64, (3, 3)), 141 | nn.ReLU(), # 64 142 | nn.Upsample(scale_factor=2, mode="nearest"), 143 | nn.ReflectionPad2d((1, 1, 1, 1)), 144 | nn.Conv2d(64, 64, (3, 3)), 145 | nn.ReLU(), 146 | nn.ReflectionPad2d((1, 1, 1, 1)), 147 | nn.Conv2d(64, 3, (3, 3)), 148 | ] 149 | self.decoder = nn.Sequential(*decoder) 150 | 151 | def forward(self, 
adain_feat): 152 | fake_image = self.decoder(adain_feat) 153 | 154 | return fake_image 155 | -------------------------------------------------------------------------------- /workflows/workflow_unist_video.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 17, 3 | "last_link_id": 33, 4 | "nodes": [ 5 | { 6 | "id": 16, 7 | "type": "VHS_VideoCombine", 8 | "pos": [ 9 | 544, 10 | 260 11 | ], 12 | "size": [ 13 | 315, 14 | 619 15 | ], 16 | "flags": {}, 17 | "order": 3, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "images", 22 | "type": "IMAGE", 23 | "link": 25 24 | }, 25 | { 26 | "name": "audio", 27 | "type": "AUDIO", 28 | "link": null 29 | }, 30 | { 31 | "name": "meta_batch", 32 | "type": "VHS_BatchManager", 33 | "link": null 34 | }, 35 | { 36 | "name": "vae", 37 | "type": "VAE", 38 | "link": null 39 | } 40 | ], 41 | "outputs": [ 42 | { 43 | "name": "Filenames", 44 | "type": "VHS_FILENAMES", 45 | "links": null, 46 | "shape": 3 47 | } 48 | ], 49 | "properties": { 50 | "Node name for S&R": "VHS_VideoCombine" 51 | }, 52 | "widgets_values": { 53 | "frame_rate": 20, 54 | "loop_count": 0, 55 | "filename_prefix": "StyleTransfer", 56 | "format": "video/h264-mp4", 57 | "pix_fmt": "yuv420p", 58 | "crf": 2, 59 | "save_metadata": true, 60 | "pingpong": false, 61 | "save_output": true, 62 | "videopreview": { 63 | "hidden": false, 64 | "paused": false, 65 | "params": { 66 | "filename": "StyleTransfer_00110.mp4", 67 | "subfolder": "", 68 | "type": "output", 69 | "format": "video/h264-mp4", 70 | "frame_rate": 20 71 | } 72 | } 73 | } 74 | }, 75 | { 76 | "id": 3, 77 | "type": "LoadImage", 78 | "pos": [ 79 | 152, 80 | 564 81 | ], 82 | "size": { 83 | "0": 315, 84 | "1": 314 85 | }, 86 | "flags": {}, 87 | "order": 0, 88 | "mode": 0, 89 | "outputs": [ 90 | { 91 | "name": "IMAGE", 92 | "type": "IMAGE", 93 | "links": [ 94 | 32 95 | ], 96 | "shape": 3, 97 | "slot_index": 0 98 | }, 99 | { 100 | "name": "MASK", 101 | "type": "MASK", 102 | "links": null, 103 | "shape": 3 104 | } 105 | ], 106 | "properties": { 107 | "Node name for S&R": "LoadImage" 108 | }, 109 | "widgets_values": [ 110 | "S4.jpg", 111 | "image" 112 | ] 113 | }, 114 | { 115 | "id": 14, 116 | "type": "VHS_LoadVideo", 117 | "pos": [ 118 | 238, 119 | 114 120 | ], 121 | "size": [ 122 | 235.1999969482422, 123 | 398.1134866403979 124 | ], 125 | "flags": {}, 126 | "order": 1, 127 | "mode": 0, 128 | "inputs": [ 129 | { 130 | "name": "meta_batch", 131 | "type": "VHS_BatchManager", 132 | "link": null 133 | }, 134 | { 135 | "name": "vae", 136 | "type": "VAE", 137 | "link": null 138 | } 139 | ], 140 | "outputs": [ 141 | { 142 | "name": "IMAGE", 143 | "type": "IMAGE", 144 | "links": [ 145 | 33 146 | ], 147 | "shape": 3, 148 | "slot_index": 0 149 | }, 150 | { 151 | "name": "frame_count", 152 | "type": "INT", 153 | "links": null, 154 | "shape": 3 155 | }, 156 | { 157 | "name": "audio", 158 | "type": "AUDIO", 159 | "links": null, 160 | "shape": 3 161 | }, 162 | { 163 | "name": "video_info", 164 | "type": "VHS_VIDEOINFO", 165 | "links": null, 166 | "shape": 3 167 | } 168 | ], 169 | "properties": { 170 | "Node name for S&R": "VHS_LoadVideo" 171 | }, 172 | "widgets_values": { 173 | "video": "cat_full.mp4", 174 | "force_rate": 0, 175 | "force_size": "Disabled", 176 | "custom_width": 512, 177 | "custom_height": 512, 178 | "frame_load_cap": 20, 179 | "skip_first_frames": 0, 180 | "select_every_nth": 1, 181 | "choose video to upload": "image", 182 | "videopreview": { 183 | "hidden": false, 184 | "paused": 
false, 185 | "params": { 186 | "frame_load_cap": 20, 187 | "skip_first_frames": 0, 188 | "force_rate": 0, 189 | "filename": "cat_full.mp4", 190 | "type": "input", 191 | "format": "video/mp4", 192 | "select_every_nth": 1 193 | } 194 | } 195 | } 196 | }, 197 | { 198 | "id": 15, 199 | "type": "UniST_Video", 200 | "pos": [ 201 | 542, 202 | 111 203 | ], 204 | "size": { 205 | "0": 315, 206 | "1": 102 207 | }, 208 | "flags": {}, 209 | "order": 2, 210 | "mode": 0, 211 | "inputs": [ 212 | { 213 | "name": "src_video", 214 | "type": "IMAGE", 215 | "link": 33 216 | }, 217 | { 218 | "name": "style_video", 219 | "type": "IMAGE", 220 | "link": 32 221 | } 222 | ], 223 | "outputs": [ 224 | { 225 | "name": "out_img", 226 | "type": "IMAGE", 227 | "links": [ 228 | 25 229 | ], 230 | "shape": 3, 231 | "slot_index": 0 232 | } 233 | ], 234 | "properties": { 235 | "Node name for S&R": "UniST_Video" 236 | }, 237 | "widgets_values": [ 238 | true, 239 | 768 240 | ] 241 | } 242 | ], 243 | "links": [ 244 | [ 245 | 25, 246 | 15, 247 | 0, 248 | 16, 249 | 0, 250 | "IMAGE" 251 | ], 252 | [ 253 | 32, 254 | 3, 255 | 0, 256 | 15, 257 | 1, 258 | "IMAGE" 259 | ], 260 | [ 261 | 33, 262 | 14, 263 | 0, 264 | 15, 265 | 0, 266 | "IMAGE" 267 | ] 268 | ], 269 | "groups": [], 270 | "config": {}, 271 | "extra": { 272 | "ds": { 273 | "scale": 1, 274 | "offset": [ 275 | 258.79825358315156, 276 | -25.984416674369626 277 | ] 278 | } 279 | }, 280 | "version": 0.4 281 | } -------------------------------------------------------------------------------- /module_neural_neighbor/colorization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from .imagePyramid import dec_lap_pyr as dec_pyr 5 | from .zca import zca_tensor 6 | 7 | 8 | def linear_2_oklab(x): 9 | """Converts pytorch tensor 'x' from Linear to OkLAB colorspace, described here: 10 | https://bottosson.github.io/posts/oklab/ 11 | Inputs: 12 | x -- pytorch tensor of size B x 3 x H x W, assumed to be in linear 13 | srgb colorspace, scaled between 0. and 1. 
14 | Returns: 15 | y -- pytorch tensor of size B x 3 x H x W in OkLAB colorspace 16 | """ 17 | assert x.size(1) == 3, "attempted to convert colorspace of tensor w/ > 3 channels" 18 | 19 | x = torch.clamp(x, 0.0, 1.0) 20 | 21 | r = x[:, 0:1, :, :] 22 | g = x[:, 1:2, :, :] 23 | b = x[:, 2:3, :, :] 24 | 25 | li = 0.4121656120 * r + 0.5362752080 * g + 0.0514575653 * b 26 | m = 0.2118591070 * r + 0.6807189584 * g + 0.1074065790 * b 27 | s = 0.0883097947 * r + 0.2818474174 * g + 0.6302613616 * b 28 | 29 | li = torch.pow(li, 1.0 / 3.0) 30 | m = torch.pow(m, 1.0 / 3.0) 31 | s = torch.pow(s, 1.0 / 3.0) 32 | 33 | L = 0.2104542553 * li + 0.7936177850 * m - 0.0040720468 * s 34 | A = 1.9779984951 * li - 2.4285922050 * m + 0.4505937099 * s 35 | B = 0.0259040371 * li + 0.7827717662 * m - 0.8086757660 * s 36 | 37 | y = torch.cat([L, A, B], 1) 38 | return y 39 | 40 | 41 | def oklab_2_linear(x): 42 | """Converts pytorch tensor 'x' from OkLAB to Linear colorspace, described here: 43 | https://bottosson.github.io/posts/oklab/ 44 | Inputs: 45 | x -- pytorch tensor of size B x 3 x H x W, assumed to be in OkLAB colorspace 46 | Returns: 47 | y -- pytorch tensor of size B x 3 x H x W in Linear sRGB colorspace 48 | """ 49 | assert x.size(1) == 3, "attempted to convert colorspace of tensor w/ > 3 channels" 50 | 51 | L = x[:, 0:1, :, :] 52 | A = x[:, 1:2, :, :] 53 | B = x[:, 2:3, :, :] 54 | 55 | li = L + 0.3963377774 * A + 0.2158037573 * B 56 | m = L - 0.1055613458 * A - 0.0638541728 * B 57 | s = L - 0.0894841775 * A - 1.2914855480 * B 58 | 59 | li = torch.pow(li, 3) 60 | m = torch.pow(m, 3) 61 | s = torch.pow(s, 3) 62 | 63 | r = 4.0767245293 * li - 3.3072168827 * m + 0.2307590544 * s 64 | g = -1.2681437731 * li + 2.6093323231 * m - 0.3411344290 * s 65 | b = -0.0041119885 * li - 0.7034763098 * m + 1.7068625689 * s 66 | 67 | y = torch.cat([r, g, b], 1) 68 | return torch.clamp(y, 0.0, 1.0) 69 | 70 | 71 | def get_pad(x): 72 | """ 73 | Applies 1 pixel of replication padding to x 74 | x -- B x D x H x W pytorch tensor 75 | """ 76 | return F.pad(x, (1, 1, 1, 1), mode="replicate") 77 | 78 | 79 | def filter(x): 80 | """ 81 | applies modified bilateral filter to AB channels of x, guided by L channel 82 | x -- B x 3 x H x W pytorch tensor containing an image in LAB colorspace 83 | """ 84 | 85 | h = x.size(2) 86 | w = x.size(3) 87 | 88 | # Seperate out luminance channel, don't use AB channels to measure similarity 89 | xl = x[:, :1, :, :] 90 | xab = x[:, 1:, :, :] 91 | xl_pad = get_pad(xl) 92 | 93 | xl_w = {} 94 | for i in range(3): 95 | for j in range(3): 96 | xl_w[str(i) + str(j)] = xl_pad[:, :, i : (i + h), j : (j + w)] 97 | 98 | # Iteratively apply in 3x3 window rather than use spatial kernel 99 | max_iters = 5 100 | cur = torch.zeros_like(xab) 101 | 102 | # comparison function for pixel intensity 103 | def comp(x, y): 104 | d = torch.abs(x - y) * 5.0 105 | return torch.pow(torch.exp(-1.0 * d), 2) 106 | 107 | # apply bilateral filtering to AB channels, guideded by L channel 108 | cur = xab.clone() 109 | for it in range(max_iters): 110 | cur_pad = get_pad(cur) 111 | xl_v = {} 112 | for i in range(3): 113 | for j in range(3): 114 | xl_v[str(i) + str(j)] = cur_pad[:, :, i : (i + h), j : (j + w)] 115 | 116 | denom = torch.zeros_like(xl) 117 | cur = cur * 0.0 118 | 119 | for i in range(3): 120 | for j in range(3): 121 | scl = comp(xl, xl_w[str(i) + str(j)]) 122 | cur = cur + xl_v[str(i) + str(j)] * scl 123 | denom = denom + scl 124 | 125 | cur = cur / denom 126 | # store result and return 127 | x[:, 1:, :, :] = cur 128 | return 
x 129 | 130 | 131 | def clamp_range(x, y): 132 | """ 133 | clamp the range of x to [min(y), max(y)] 134 | x -- pytorch tensor 135 | y -- pytorch tensor 136 | """ 137 | return torch.clamp(x, y.min(), y.max()) 138 | 139 | 140 | def color_match(c, s, o, moment_only=False): 141 | """ 142 | Constrain the low frequences of the AB channels of output image 'o' (containing hue and saturation) 143 | to be an affine transformation of 'c' matching the mean and covariance of the style image 's'. 144 | Compared to the raw output of optimization this is highly constrained, but in practice 145 | we find the benefit to robustness to be worth the reduced stylization. 146 | c -- B x 3 x H x W pytorch tensor containing content image 147 | s -- B x 3 x H x W pytorch tensor containing style image 148 | o -- B x 3 x H x W pytorch tensor containing initial output image 149 | moment_only -- boolean, prevents applying bilateral filter to AB channels of final output to match luminance's edges 150 | """ 151 | c = torch.clamp(c, 0.0, 1.0) 152 | s = torch.clamp(s, 0.0, 1.0) 153 | o = torch.clamp(o, 0.0, 1.0) 154 | 155 | x = linear_2_oklab(c) 156 | # x_flat = x.view(x.size(0), x.size(1), -1, 1) 157 | y = linear_2_oklab(s) 158 | o = linear_2_oklab(o) 159 | 160 | x_new = o.clone() 161 | for i in range(3): 162 | x_new[:, i : i + 1, :, :] = clamp_range( 163 | x_new[:, i : i + 1, :, :], y[:, i : i + 1, :, :] 164 | ) 165 | 166 | _, cov_s = zca_tensor(x_new, y) 167 | 168 | if moment_only or cov_s[1:, 1:].abs().max() < 6e-5: 169 | x_new[:, 1:, :, :] = o[:, 1:, :, :] 170 | x_new, _ = zca_tensor(x_new, y) 171 | else: 172 | x_new[:, 1:, :, :] = x[:, 1:, :, :] 173 | x_new[:, 1:, :, :] = zca_tensor(x_new[:, 1:, :, :], y[:, 1:, :, :])[0] 174 | x_new = filter(x_new) 175 | 176 | for i in range(3): 177 | x_new[:, i : i + 1, :, :] = clamp_range( 178 | x_new[:, i : i + 1, :, :], y[:, i : i + 1, :, :] 179 | ) 180 | 181 | # x_pyr = dec_pyr(x,4) 182 | # y_pyr = dec_pyr(y,4) 183 | x_new_pyr = dec_pyr(x_new, 4) 184 | o_pyr = dec_pyr(o, 4) 185 | x_new_pyr[:-1] = o_pyr[:-1] 186 | 187 | return oklab_2_linear(x_new) 188 | -------------------------------------------------------------------------------- /module_aespa/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def init_weights(net): 9 | for m in net.modules(): 10 | if isinstance(m, nn.Conv2d): 11 | m.weight.data.normal_(0, 0.02) 12 | if m.bias is not None: 13 | m.bias.data.zero_() 14 | elif isinstance(m, nn.ConvTranspose2d): 15 | m.weight.data.normal_(0, 0.02) 16 | if m.bias is not None: 17 | m.bias.data.zero_() 18 | elif isinstance(m, nn.Linear): 19 | m.weight.data.normal_(0, 0.02) 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | 23 | 24 | def feature_wct_simple(content_feat, style_feat, alpha=1): 25 | target_feature = Bw_wct_core(content_feat, style_feat) 26 | 27 | target_feature = target_feature.view_as(content_feat) 28 | target_feature = alpha * target_feature + (1 - alpha) * content_feat 29 | return target_feature 30 | 31 | 32 | def adaptive_get_keys(feat_skips, start_layer_idx, last_layer_idx, target_feat): 33 | B, C, th, tw = target_feat.shape 34 | results = [] 35 | target_conv_layer = "conv" + str(last_layer_idx) + "_1" 36 | _, _, h, w = feat_skips[target_conv_layer].shape 37 | for i in range(start_layer_idx, last_layer_idx + 1): 38 | target_conv_layer = "conv" + str(i) + "_1" 39 | if i == last_layer_idx: 40 | 
results.append(mean_variance_norm(feat_skips[target_conv_layer])) 41 | else: 42 | results.append( 43 | mean_variance_norm(F.interpolate(feat_skips[target_conv_layer], (h, w))) 44 | ) 45 | 46 | return F.interpolate(torch.cat(results, dim=1), (th, tw)) 47 | 48 | 49 | def adaptive_gram_weight(image, level, ratio, encoder): 50 | if level == 0: 51 | encoded_features = image 52 | else: 53 | encoded_features = encoder.get_features(image, level) # B x C x W x H 54 | global_gram = gram_matrix(encoded_features) 55 | 56 | B, C, w, h = encoded_features.size() 57 | target_w, target_h = w // ratio, h // ratio 58 | 59 | patches = extract_image_patches(encoded_features, target_w, target_h) 60 | _, patches_num, _, _, _ = patches.size() 61 | cos = torch.nn.CosineSimilarity(eps=1e-6) 62 | 63 | intra_gram_statistic = [] 64 | inter_gram_statistic = [] 65 | comb = torch.combinations(torch.arange(patches_num), r=2) 66 | if patches_num >= 10: 67 | sampling_num = int(comb.size(0) * 0.05) 68 | else: 69 | sampling_num = comb.size(0) 70 | for idx in range(B): 71 | if patches_num < 2: 72 | continue 73 | cos_gram = [] 74 | 75 | for patch in range(0, patches_num): 76 | cos_gram.append( 77 | cos(global_gram, gram_matrix(patches[idx][patch].unsqueeze(0))) 78 | .mean() 79 | .item() 80 | ) 81 | 82 | intra_gram_statistic.append(torch.tensor(cos_gram)) 83 | 84 | cos_gram = [] 85 | for idxes in random.choices(list(comb), k=sampling_num): 86 | cos_gram.append( 87 | cos( 88 | gram_matrix(patches[idx][idxes[0]].unsqueeze(0)), 89 | gram_matrix(patches[idx][idxes[1]].unsqueeze(0)), 90 | ) 91 | .mean() 92 | .item() 93 | ) 94 | 95 | inter_gram_statistic.append(torch.tensor(cos_gram)) 96 | 97 | intra_gram_statistic = torch.stack(intra_gram_statistic).mean(dim=1) 98 | inter_gram_statistic = torch.stack(inter_gram_statistic).mean(dim=1) 99 | results = (intra_gram_statistic + inter_gram_statistic) / 2 100 | 101 | ##For boosting value 102 | results = 1 / (1 + torch.exp(-10 * (results - 0.6))) 103 | 104 | return results 105 | 106 | 107 | def Bw_wct_core(content_feat, style_feat, weight=1, registers=None, device="cpu"): 108 | N, C, H, W = content_feat.size() 109 | cont_min = content_feat.min().item() 110 | cont_max = content_feat.max().item() 111 | 112 | whiten_cF, _, _ = SwitchWhiten2d(content_feat) 113 | _, wm_s, s_mean = SwitchWhiten2d(style_feat) 114 | 115 | targetFeature = torch.bmm(torch.inverse(wm_s), whiten_cF) 116 | targetFeature = targetFeature.view(N, C, H, W) 117 | targetFeature = targetFeature + s_mean.unsqueeze(2).expand_as(targetFeature) 118 | targetFeature.clamp_(cont_min, cont_max) 119 | 120 | return targetFeature 121 | 122 | 123 | def SwitchWhiten2d(x): 124 | N, C, H, W = x.size() 125 | 126 | in_data = x.view(N, C, -1) 127 | 128 | eye = in_data.data.new().resize_(C, C) 129 | eye = torch.nn.init.eye_(eye).view(1, C, C).expand(N, C, C) 130 | 131 | # calculate other statistics 132 | mean_in = in_data.mean(-1, keepdim=True) 133 | x_in = in_data - mean_in 134 | # (N x g) x C x C 135 | cov_in = torch.bmm(x_in, torch.transpose(x_in, 1, 2)).div(H * W) 136 | 137 | mean = mean_in 138 | cov = cov_in + 1e-5 * eye 139 | 140 | # perform whitening using Newton's iteration 141 | Ng, c, _ = cov.size() 142 | P = torch.eye(c).to(cov).expand(Ng, c, c) 143 | 144 | rTr = (cov * P).sum((1, 2), keepdim=True).reciprocal_() 145 | cov_N = cov * rTr 146 | for k in range(5): 147 | # P = torch.baddbmm(1.5, P, -0.5, torch.matrix_power(P, 3), cov_N) 148 | P = torch.baddbmm(P, torch.matrix_power(P, 3), cov_N, beta=1.5, alpha=-0.5) 149 | 150 | wm = 
P.mul_(rTr.sqrt()) 151 | x_hat = torch.bmm(wm, in_data - mean) 152 | 153 | return x_hat, wm, mean 154 | 155 | 156 | def mean_variance_norm(feat): 157 | size = feat.size() 158 | mean, std = calc_mean_std(feat) 159 | normalized_feat = (feat - mean.expand(size)) / std.expand(size) 160 | return normalized_feat 161 | 162 | 163 | def calc_mean_std(feat, eps=1e-5): 164 | # eps is a small value added to the variance to avoid divide-by-zero. 165 | size = feat.size() 166 | assert len(size) == 4 167 | N, C = size[:2] 168 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 169 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 170 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 171 | return feat_mean, feat_std 172 | 173 | 174 | def gram_matrix(y): 175 | (b, ch, h, w) = y.size() 176 | features = y.view(b, ch, w * h) 177 | features_t = features.transpose(1, 2) 178 | gram = features.bmm(features_t) / (ch * h * w) 179 | return gram 180 | 181 | 182 | def extract_image_patches(x, kernel, stride=1): 183 | b, c, h, w = x.shape 184 | 185 | # Extract patches 186 | patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride) 187 | patches = patches.contiguous().view(b, c, -1, kernel, kernel) 188 | patches = patches.permute(0, 2, 1, 3, 4).contiguous() 189 | 190 | # return patches.view(b, number_of_patches, c, h, w) 191 | return patches.view(b, -1, c, kernel, kernel) 192 | -------------------------------------------------------------------------------- /workflows/workflow_efdm.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 9, 3 | "last_link_id": 18, 4 | "nodes": [ 5 | { 6 | "id": 8, 7 | "type": "LoadImage", 8 | "pos": [ 9 | 242, 10 | 460 11 | ], 12 | "size": [ 13 | 315, 14 | 314.00001525878906 15 | ], 16 | "flags": {}, 17 | "order": 0, 18 | "mode": 0, 19 | "outputs": [ 20 | { 21 | "name": "IMAGE", 22 | "type": "IMAGE", 23 | "links": [ 24 | 12 25 | ], 26 | "shape": 3, 27 | "slot_index": 0 28 | }, 29 | { 30 | "name": "MASK", 31 | "type": "MASK", 32 | "links": null, 33 | "shape": 3 34 | } 35 | ], 36 | "properties": { 37 | "Node name for S&R": "LoadImage" 38 | }, 39 | "widgets_values": [ 40 | "popova.jpg", 41 | "image" 42 | ] 43 | }, 44 | { 45 | "id": 9, 46 | "type": "Reroute", 47 | "pos": [ 48 | 461, 49 | 279 50 | ], 51 | "size": [ 52 | 75, 53 | 26 54 | ], 55 | "flags": {}, 56 | "order": 4, 57 | "mode": 0, 58 | "inputs": [ 59 | { 60 | "name": "", 61 | "type": "*", 62 | "link": 18 63 | } 64 | ], 65 | "outputs": [ 66 | { 67 | "name": "", 68 | "type": "IMAGE", 69 | "links": [ 70 | 16 71 | ], 72 | "slot_index": 0 73 | } 74 | ], 75 | "properties": { 76 | "showOutputText": false, 77 | "horizontal": false 78 | } 79 | }, 80 | { 81 | "id": 2, 82 | "type": "LoadImage", 83 | "pos": [ 84 | 622, 85 | 519 86 | ], 87 | "size": { 88 | "0": 315, 89 | "1": 314 90 | }, 91 | "flags": {}, 92 | "order": 1, 93 | "mode": 0, 94 | "outputs": [ 95 | { 96 | "name": "IMAGE", 97 | "type": "IMAGE", 98 | "links": [ 99 | 9 100 | ], 101 | "shape": 3, 102 | "slot_index": 0 103 | }, 104 | { 105 | "name": "MASK", 106 | "type": "MASK", 107 | "links": null, 108 | "shape": 3 109 | } 110 | ], 111 | "properties": { 112 | "Node name for S&R": "LoadImage" 113 | }, 114 | "widgets_values": [ 115 | "006.jpg", 116 | "image" 117 | ] 118 | }, 119 | { 120 | "id": 4, 121 | "type": "SaveImage", 122 | "pos": [ 123 | 961, 124 | 282 125 | ], 126 | "size": { 127 | "0": 315, 128 | "1": 270 129 | }, 130 | "flags": {}, 131 | "order": 6, 132 | "mode": 0, 133 | "inputs": [ 134 | { 135 | "name": 
"images", 136 | "type": "IMAGE", 137 | "link": 11 138 | } 139 | ], 140 | "properties": {}, 141 | "widgets_values": [ 142 | "StyleTransfer" 143 | ] 144 | }, 145 | { 146 | "id": 7, 147 | "type": "ImageBatch", 148 | "pos": [ 149 | 249, 150 | 362 151 | ], 152 | "size": { 153 | "0": 210, 154 | "1": 46 155 | }, 156 | "flags": {}, 157 | "order": 3, 158 | "mode": 0, 159 | "inputs": [ 160 | { 161 | "name": "image1", 162 | "type": "IMAGE", 163 | "link": 13 164 | }, 165 | { 166 | "name": "image2", 167 | "type": "IMAGE", 168 | "link": 12 169 | } 170 | ], 171 | "outputs": [ 172 | { 173 | "name": "IMAGE", 174 | "type": "IMAGE", 175 | "links": [], 176 | "shape": 3, 177 | "slot_index": 0 178 | } 179 | ], 180 | "properties": { 181 | "Node name for S&R": "ImageBatch" 182 | } 183 | }, 184 | { 185 | "id": 3, 186 | "type": "LoadImage", 187 | "pos": [ 188 | -104, 189 | 293 190 | ], 191 | "size": { 192 | "0": 315, 193 | "1": 314 194 | }, 195 | "flags": {}, 196 | "order": 2, 197 | "mode": 0, 198 | "outputs": [ 199 | { 200 | "name": "IMAGE", 201 | "type": "IMAGE", 202 | "links": [ 203 | 13, 204 | 18 205 | ], 206 | "shape": 3, 207 | "slot_index": 0 208 | }, 209 | { 210 | "name": "MASK", 211 | "type": "MASK", 212 | "links": null, 213 | "shape": 3 214 | } 215 | ], 216 | "properties": { 217 | "Node name for S&R": "LoadImage" 218 | }, 219 | "widgets_values": [ 220 | "S4.jpg", 221 | "image" 222 | ] 223 | }, 224 | { 225 | "id": 6, 226 | "type": "EFDM", 227 | "pos": [ 228 | 623, 229 | 275 230 | ], 231 | "size": { 232 | "0": 315, 233 | "1": 198 234 | }, 235 | "flags": {}, 236 | "order": 5, 237 | "mode": 0, 238 | "inputs": [ 239 | { 240 | "name": "src_img", 241 | "type": "IMAGE", 242 | "link": 9 243 | }, 244 | { 245 | "name": "style_img", 246 | "type": "IMAGE", 247 | "link": 16 248 | } 249 | ], 250 | "outputs": [ 251 | { 252 | "name": "res_img", 253 | "type": "IMAGE", 254 | "links": [ 255 | 11 256 | ], 257 | "shape": 3, 258 | "slot_index": 0 259 | } 260 | ], 261 | "properties": { 262 | "Node name for S&R": "EFDM" 263 | }, 264 | "widgets_values": [ 265 | "0.7, 0.1", 266 | "efdm", 267 | 1, 268 | false, 269 | false, 270 | 512 271 | ] 272 | } 273 | ], 274 | "links": [ 275 | [ 276 | 9, 277 | 2, 278 | 0, 279 | 6, 280 | 0, 281 | "IMAGE" 282 | ], 283 | [ 284 | 11, 285 | 6, 286 | 0, 287 | 4, 288 | 0, 289 | "IMAGE" 290 | ], 291 | [ 292 | 12, 293 | 8, 294 | 0, 295 | 7, 296 | 1, 297 | "IMAGE" 298 | ], 299 | [ 300 | 13, 301 | 3, 302 | 0, 303 | 7, 304 | 0, 305 | "IMAGE" 306 | ], 307 | [ 308 | 16, 309 | 9, 310 | 0, 311 | 6, 312 | 1, 313 | "IMAGE" 314 | ], 315 | [ 316 | 18, 317 | 3, 318 | 0, 319 | 9, 320 | 0, 321 | "*" 322 | ] 323 | ], 324 | "groups": [], 325 | "config": {}, 326 | "extra": { 327 | "ds": { 328 | "scale": 1.1000000000000008, 329 | "offset": [ 330 | 173.09566993240344, 331 | -50.5718089715092 332 | ] 333 | } 334 | }, 335 | "version": 0.4 336 | } -------------------------------------------------------------------------------- /workflows/workflow_aesfa_w_blend.json: -------------------------------------------------------------------------------- 1 | { 2 | "last_node_id": 13, 3 | "last_link_id": 34, 4 | "nodes": [ 5 | { 6 | "id": 9, 7 | "type": "AESFA", 8 | "pos": [ 9 | 678, 10 | 313 11 | ], 12 | "size": { 13 | "0": 315, 14 | "1": 102 15 | }, 16 | "flags": {}, 17 | "order": 5, 18 | "mode": 0, 19 | "inputs": [ 20 | { 21 | "name": "src_img", 22 | "type": "IMAGE", 23 | "link": 33 24 | }, 25 | { 26 | "name": "style_img", 27 | "type": "IMAGE", 28 | "link": 30 29 | } 30 | ], 31 | "outputs": [ 32 | { 33 | "name": "out_img", 34 | 
"type": "IMAGE", 35 | "links": [], 36 | "shape": 3, 37 | "slot_index": 0 38 | } 39 | ], 40 | "properties": { 41 | "Node name for S&R": "AESFA" 42 | }, 43 | "widgets_values": [ 44 | true, 45 | 1024 46 | ] 47 | }, 48 | { 49 | "id": 11, 50 | "type": "LoadImage", 51 | "pos": [ 52 | 682, 53 | 672 54 | ], 55 | "size": { 56 | "0": 315, 57 | "1": 314 58 | }, 59 | "flags": {}, 60 | "order": 0, 61 | "mode": 0, 62 | "outputs": [ 63 | { 64 | "name": "IMAGE", 65 | "type": "IMAGE", 66 | "links": [ 67 | 25 68 | ], 69 | "shape": 3, 70 | "slot_index": 0 71 | }, 72 | { 73 | "name": "MASK", 74 | "type": "MASK", 75 | "links": null, 76 | "shape": 3 77 | } 78 | ], 79 | "properties": { 80 | "Node name for S&R": "LoadImage" 81 | }, 82 | "widgets_values": [ 83 | "popova.jpg", 84 | "image" 85 | ] 86 | }, 87 | { 88 | "id": 10, 89 | "type": "AesFAStyleBlend", 90 | "pos": [ 91 | 682, 92 | 501 93 | ], 94 | "size": { 95 | "0": 315, 96 | "1": 122 97 | }, 98 | "flags": {}, 99 | "order": 6, 100 | "mode": 0, 101 | "inputs": [ 102 | { 103 | "name": "src_img", 104 | "type": "IMAGE", 105 | "link": 34 106 | }, 107 | { 108 | "name": "style_hi", 109 | "type": "IMAGE", 110 | "link": 31 111 | }, 112 | { 113 | "name": "style_lo", 114 | "type": "IMAGE", 115 | "link": 25 116 | } 117 | ], 118 | "outputs": [ 119 | { 120 | "name": "out_img", 121 | "type": "IMAGE", 122 | "links": [ 123 | 26 124 | ], 125 | "shape": 3, 126 | "slot_index": 0 127 | } 128 | ], 129 | "properties": { 130 | "Node name for S&R": "AesFAStyleBlend" 131 | }, 132 | "widgets_values": [ 133 | true, 134 | 1024 135 | ] 136 | }, 137 | { 138 | "id": 12, 139 | "type": "Reroute", 140 | "pos": [ 141 | 539, 142 | 586 143 | ], 144 | "size": [ 145 | 75, 146 | 26 147 | ], 148 | "flags": {}, 149 | "order": 3, 150 | "mode": 0, 151 | "inputs": [ 152 | { 153 | "name": "", 154 | "type": "*", 155 | "link": 28 156 | } 157 | ], 158 | "outputs": [ 159 | { 160 | "name": "", 161 | "type": "IMAGE", 162 | "links": [ 163 | 30, 164 | 31 165 | ], 166 | "slot_index": 0 167 | } 168 | ], 169 | "properties": { 170 | "showOutputText": false, 171 | "horizontal": false 172 | } 173 | }, 174 | { 175 | "id": 3, 176 | "type": "LoadImage", 177 | "pos": [ 178 | 338, 179 | 668 180 | ], 181 | "size": { 182 | "0": 315, 183 | "1": 314 184 | }, 185 | "flags": {}, 186 | "order": 1, 187 | "mode": 0, 188 | "outputs": [ 189 | { 190 | "name": "IMAGE", 191 | "type": "IMAGE", 192 | "links": [ 193 | 28 194 | ], 195 | "shape": 3, 196 | "slot_index": 0 197 | }, 198 | { 199 | "name": "MASK", 200 | "type": "MASK", 201 | "links": null, 202 | "shape": 3 203 | } 204 | ], 205 | "properties": { 206 | "Node name for S&R": "LoadImage" 207 | }, 208 | "widgets_values": [ 209 | "S4.jpg", 210 | "image" 211 | ] 212 | }, 213 | { 214 | "id": 13, 215 | "type": "Reroute", 216 | "pos": [ 217 | 528.2286926786944, 218 | 402.29128892008794 219 | ], 220 | "size": [ 221 | 75, 222 | 26 223 | ], 224 | "flags": {}, 225 | "order": 4, 226 | "mode": 0, 227 | "inputs": [ 228 | { 229 | "name": "", 230 | "type": "*", 231 | "link": 32 232 | } 233 | ], 234 | "outputs": [ 235 | { 236 | "name": "", 237 | "type": "IMAGE", 238 | "links": [ 239 | 33, 240 | 34 241 | ], 242 | "slot_index": 0 243 | } 244 | ], 245 | "properties": { 246 | "showOutputText": false, 247 | "horizontal": false 248 | } 249 | }, 250 | { 251 | "id": 2, 252 | "type": "LoadImage", 253 | "pos": [ 254 | 174, 255 | 290 256 | ], 257 | "size": { 258 | "0": 315, 259 | "1": 314 260 | }, 261 | "flags": {}, 262 | "order": 2, 263 | "mode": 0, 264 | "outputs": [ 265 | { 266 | "name": "IMAGE", 267 | 
"type": "IMAGE", 268 | "links": [ 269 | 32 270 | ], 271 | "shape": 3, 272 | "slot_index": 0 273 | }, 274 | { 275 | "name": "MASK", 276 | "type": "MASK", 277 | "links": null, 278 | "shape": 3 279 | } 280 | ], 281 | "properties": { 282 | "Node name for S&R": "LoadImage" 283 | }, 284 | "widgets_values": [ 285 | "006.jpg", 286 | "image" 287 | ] 288 | }, 289 | { 290 | "id": 4, 291 | "type": "SaveImage", 292 | "pos": [ 293 | 1039, 294 | 287 295 | ], 296 | "size": [ 297 | 321.92869267869446, 298 | 328.691288920088 299 | ], 300 | "flags": {}, 301 | "order": 7, 302 | "mode": 0, 303 | "inputs": [ 304 | { 305 | "name": "images", 306 | "type": "IMAGE", 307 | "link": 26 308 | } 309 | ], 310 | "properties": {}, 311 | "widgets_values": [ 312 | "StyleTransfer" 313 | ] 314 | } 315 | ], 316 | "links": [ 317 | [ 318 | 25, 319 | 11, 320 | 0, 321 | 10, 322 | 2, 323 | "IMAGE" 324 | ], 325 | [ 326 | 26, 327 | 10, 328 | 0, 329 | 4, 330 | 0, 331 | "IMAGE" 332 | ], 333 | [ 334 | 28, 335 | 3, 336 | 0, 337 | 12, 338 | 0, 339 | "*" 340 | ], 341 | [ 342 | 30, 343 | 12, 344 | 0, 345 | 9, 346 | 1, 347 | "IMAGE" 348 | ], 349 | [ 350 | 31, 351 | 12, 352 | 0, 353 | 10, 354 | 1, 355 | "IMAGE" 356 | ], 357 | [ 358 | 32, 359 | 2, 360 | 0, 361 | 13, 362 | 0, 363 | "*" 364 | ], 365 | [ 366 | 33, 367 | 13, 368 | 0, 369 | 9, 370 | 0, 371 | "IMAGE" 372 | ], 373 | [ 374 | 34, 375 | 13, 376 | 0, 377 | 10, 378 | 0, 379 | "IMAGE" 380 | ] 381 | ], 382 | "groups": [], 383 | "config": {}, 384 | "extra": { 385 | "ds": { 386 | "scale": 0.9090909090909091, 387 | "offset": [ 388 | -109.12869267869443, 389 | -65.69128892008793 390 | ] 391 | } 392 | }, 393 | "version": 0.4 394 | } -------------------------------------------------------------------------------- /module_efdm/net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from skimage.exposure import match_histograms 5 | 6 | from .function import adaptive_instance_normalization as adain 7 | from .function import adaptive_mean_normalization as adamean 8 | from .function import adaptive_std_normalization as adastd 9 | from .function import calc_mean_std 10 | from .function import exact_feature_distribution_matching as efdm 11 | from .function import histogram_matching as hm 12 | 13 | decoder = nn.Sequential( 14 | nn.ReflectionPad2d((1, 1, 1, 1)), 15 | nn.Conv2d(512, 256, (3, 3)), 16 | nn.ReLU(), 17 | nn.Upsample(scale_factor=2, mode="nearest"), 18 | nn.ReflectionPad2d((1, 1, 1, 1)), 19 | nn.Conv2d(256, 256, (3, 3)), 20 | nn.ReLU(), 21 | nn.ReflectionPad2d((1, 1, 1, 1)), 22 | nn.Conv2d(256, 256, (3, 3)), 23 | nn.ReLU(), 24 | nn.ReflectionPad2d((1, 1, 1, 1)), 25 | nn.Conv2d(256, 256, (3, 3)), 26 | nn.ReLU(), 27 | nn.ReflectionPad2d((1, 1, 1, 1)), 28 | nn.Conv2d(256, 128, (3, 3)), 29 | nn.ReLU(), 30 | nn.Upsample(scale_factor=2, mode="nearest"), 31 | nn.ReflectionPad2d((1, 1, 1, 1)), 32 | nn.Conv2d(128, 128, (3, 3)), 33 | nn.ReLU(), 34 | nn.ReflectionPad2d((1, 1, 1, 1)), 35 | nn.Conv2d(128, 64, (3, 3)), 36 | nn.ReLU(), 37 | nn.Upsample(scale_factor=2, mode="nearest"), 38 | nn.ReflectionPad2d((1, 1, 1, 1)), 39 | nn.Conv2d(64, 64, (3, 3)), 40 | nn.ReLU(), 41 | nn.ReflectionPad2d((1, 1, 1, 1)), 42 | nn.Conv2d(64, 3, (3, 3)), 43 | ) 44 | 45 | vgg = nn.Sequential( 46 | nn.Conv2d(3, 3, (1, 1)), 47 | nn.ReflectionPad2d((1, 1, 1, 1)), 48 | nn.Conv2d(3, 64, (3, 3)), 49 | nn.ReLU(), # relu1-1 50 | nn.ReflectionPad2d((1, 1, 1, 1)), 51 | nn.Conv2d(64, 64, (3, 3)), 52 | nn.ReLU(), # relu1-2 53 | 
nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 54 | nn.ReflectionPad2d((1, 1, 1, 1)), 55 | nn.Conv2d(64, 128, (3, 3)), 56 | nn.ReLU(), # relu2-1 57 | nn.ReflectionPad2d((1, 1, 1, 1)), 58 | nn.Conv2d(128, 128, (3, 3)), 59 | nn.ReLU(), # relu2-2 60 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 61 | nn.ReflectionPad2d((1, 1, 1, 1)), 62 | nn.Conv2d(128, 256, (3, 3)), 63 | nn.ReLU(), # relu3-1 64 | nn.ReflectionPad2d((1, 1, 1, 1)), 65 | nn.Conv2d(256, 256, (3, 3)), 66 | nn.ReLU(), # relu3-2 67 | nn.ReflectionPad2d((1, 1, 1, 1)), 68 | nn.Conv2d(256, 256, (3, 3)), 69 | nn.ReLU(), # relu3-3 70 | nn.ReflectionPad2d((1, 1, 1, 1)), 71 | nn.Conv2d(256, 256, (3, 3)), 72 | nn.ReLU(), # relu3-4 73 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 74 | nn.ReflectionPad2d((1, 1, 1, 1)), 75 | nn.Conv2d(256, 512, (3, 3)), 76 | nn.ReLU(), # relu4-1, this is the last layer used 77 | nn.ReflectionPad2d((1, 1, 1, 1)), 78 | nn.Conv2d(512, 512, (3, 3)), 79 | nn.ReLU(), # relu4-2 80 | nn.ReflectionPad2d((1, 1, 1, 1)), 81 | nn.Conv2d(512, 512, (3, 3)), 82 | nn.ReLU(), # relu4-3 83 | nn.ReflectionPad2d((1, 1, 1, 1)), 84 | nn.Conv2d(512, 512, (3, 3)), 85 | nn.ReLU(), # relu4-4 86 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 87 | nn.ReflectionPad2d((1, 1, 1, 1)), 88 | nn.Conv2d(512, 512, (3, 3)), 89 | nn.ReLU(), # relu5-1 90 | nn.ReflectionPad2d((1, 1, 1, 1)), 91 | nn.Conv2d(512, 512, (3, 3)), 92 | nn.ReLU(), # relu5-2 93 | nn.ReflectionPad2d((1, 1, 1, 1)), 94 | nn.Conv2d(512, 512, (3, 3)), 95 | nn.ReLU(), # relu5-3 96 | nn.ReflectionPad2d((1, 1, 1, 1)), 97 | nn.Conv2d(512, 512, (3, 3)), 98 | nn.ReLU(), # relu5-4 99 | ) 100 | 101 | 102 | class Net(nn.Module): 103 | def __init__(self, encoder, decoder, style): 104 | super(Net, self).__init__() 105 | enc_layers = list(encoder.children()) 106 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1 107 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1 108 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1 109 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1 110 | self.decoder = decoder 111 | self.mse_loss = nn.MSELoss() 112 | self.style = style 113 | 114 | # fix the encoder 115 | for name in ["enc_1", "enc_2", "enc_3", "enc_4"]: 116 | for param in getattr(self, name).parameters(): 117 | param.requires_grad = False 118 | 119 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image 120 | def encode_with_intermediate(self, input): 121 | results = [input] 122 | for i in range(4): 123 | func = getattr(self, "enc_{:d}".format(i + 1)) 124 | results.append(func(results[-1])) 125 | return results[1:] 126 | 127 | # extract relu4_1 from input image 128 | def encode(self, input): 129 | for i in range(4): 130 | input = getattr(self, "enc_{:d}".format(i + 1))(input) 131 | return input 132 | 133 | def calc_content_loss(self, input, target): 134 | assert input.size() == target.size() 135 | assert target.requires_grad is False 136 | return self.mse_loss(input, target) 137 | 138 | def calc_style_loss(self, input, target): 139 | # ipdb.set_trace() 140 | assert input.size() == target.size() 141 | assert ( 142 | target.requires_grad is False 143 | ) ## first make sure which one require gradient and which one do not. 
144 | # print(input.requires_grad) ## True 145 | input_mean, input_std = calc_mean_std(input) 146 | target_mean, target_std = calc_mean_std(target) 147 | if self.style == "adain": 148 | return self.mse_loss(input_mean, target_mean) + self.mse_loss( 149 | input_std, target_std 150 | ) 151 | elif self.style == "adamean": 152 | return self.mse_loss(input_mean, target_mean) 153 | elif self.style == "adastd": 154 | return self.mse_loss(input_std, target_std) 155 | elif self.style == "efdm": 156 | B, C, W, H = input.size(0), input.size(1), input.size(2), input.size(3) 157 | value_content, index_content = torch.sort(input.view(B, C, -1)) 158 | value_style, index_style = torch.sort(target.view(B, C, -1)) 159 | inverse_index = index_content.argsort(-1) 160 | return self.mse_loss( 161 | input.view(B, C, -1), value_style.gather(-1, inverse_index) 162 | ) 163 | elif self.style == "hm": 164 | B, C, W, H = input.size(0), input.size(1), input.size(2), input.size(3) 165 | x_view = input.view(-1, W, H) 166 | image1_temp = match_histograms( 167 | np.array(x_view.detach().clone().cpu().float().transpose(0, 2)), 168 | np.array( 169 | target.view(-1, W, H).detach().clone().cpu().float().transpose(0, 2) 170 | ), 171 | multichannel=True, 172 | ) 173 | image1_temp = ( 174 | torch.from_numpy(image1_temp) 175 | .float() 176 | .to(input.device) 177 | .transpose(0, 2) 178 | .view(B, C, W, H) 179 | ) 180 | return self.mse_loss(input.reshape(B, C, -1), image1_temp.reshape(B, C, -1)) 181 | else: 182 | raise NotImplementedError 183 | 184 | def forward(self, content, style, alpha=1.0): 185 | assert 0 <= alpha <= 1 186 | # ipdb.set_trace() 187 | style_feats = self.encode_with_intermediate(style) 188 | content_feat = self.encode(content) 189 | # print(content_feat.requires_grad) False 190 | # print(style_feats[-1].requires_grad) False 191 | if self.style == "adain": 192 | t = adain(content_feat, style_feats[-1]) 193 | elif self.style == "adamean": 194 | t = adamean(content_feat, style_feats[-1]) 195 | elif self.style == "adastd": 196 | t = adastd(content_feat, style_feats[-1]) 197 | elif self.style == "efdm": 198 | t = efdm(content_feat, style_feats[-1]) 199 | elif self.style == "hm": 200 | t = hm(content_feat, style_feats[-1]) 201 | else: 202 | raise NotImplementedError 203 | t = alpha * t + (1 - alpha) * content_feat 204 | 205 | g_t = self.decoder(t) 206 | g_t_feats = self.encode_with_intermediate(g_t) 207 | 208 | loss_c = self.calc_content_loss( 209 | g_t_feats[-1], t 210 | ) ### final feature should be the same. 211 | loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0]) 212 | for i in range(1, 4): 213 | loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i]) 214 | return loss_c, loss_s 215 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ComfyUI-StyleTransferPlus 2 | Advance Non-Diffusion-based Style Transfer in ComfyUI 3 | 4 | Click name to jump to workflow 5 | 6 | 1. [**Neural Neighbor**](#neural-neighbor). Paper: [Neural Neighbor Style Transfer](https://github.com/nkolkin13/NeuralNeighborStyleTransfer) 7 | 2. [**CAST**](#cast). Paper: [Domain Enhanced Arbitrary Image Style Transfer via Contrastive Learning](https://github.com/zyxElsa/CAST_pytorch) 8 | 3. [**EFDM**](#efdm). Paper: [Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization](https://github.com/YBZh/EFDM) 9 | 4. [**MicroAST**](#microast). 
Paper: [Towards Super-Fast Ultra-Resolution Arbitrary Style Transfer](https://github.com/EndyWon/MicroAST) 10 | 5. [**UniST**](#unist). Paper: [Two Birds, One Stone: A Unified Framework for Joint Learning of Image and Video Style Transfers](https://github.com/NevSNev/UniST) 11 | 6. [**AesPA-Net**](#aespa-net). Paper: [Aesthetic Pattern-Aware Style Transfer Networks](https://github.com/Kibeom-Hong/AesPA-Net) 12 | 7. [**TSSAT**](#tssat). Paper: [Two-Stage Statistics-Aware Transformation for Artistic Style Transfer](https://github.com/HalbertCH/TSSAT) 13 | 8. [**AesFA**](#aesfa). Paper: [An Aesthetic Feature-Aware Arbitrary Neural Style Transfer](https://github.com/Sooyyoungg/AesFA) 14 | 15 | 16 | [**Extra nodes**](#extra-nodes) 17 | 18 | More will come soon 19 | 20 | ## Workflows 21 | 22 | All nodes support batched input (i.e video) but is generally not recommended. One should generate 1 or 2 style frames (start and end), then use [ComfyUI-EbSynth](https://github.com/FuouM/ComfyUI-EbSynth) to propagate the style to the entire video. 23 | 24 | | Neural Neighbor (2020-2022) - Slowest | CAST (2022-2023) - Fast | 25 | | :-----------------------------------------------------: | :------------------------------------------: | 26 | | ![neural_neighbor](example_outputs/neural_neighbor.png) | ![cast](example_outputs/cast.png) | 27 | | **EFDM (2022) - Fast** | **MicroAST (2023) - Fast** | 28 | | ![efdm](example_outputs/efdm.png) | ![microast](example_outputs/microast.png) | 29 | | **UniST (2023) - Medium - Only squares** | **AesPA-Net (2023) - Medium - Only squares** | 30 | | ![efdm](example_outputs/unist_image.png) | ![aespa](example_outputs/aespa.png) | 31 | | **TSSAT (2023) - Slow** | **AesFA (2024) - Fast - Only squares** | 32 | | ![tssat](example_outputs/tssat.png) | ![aesfa](example_outputs/aesfa.png) | 33 | 34 | ### Neural Neighbor 35 | 36 | Arguments: 37 | - `size`: Either `512` or `1024`, to be used for scaling the images. 1024 takes ~6-12 GB of VRAM. Defaults to `512` 38 | - `scale_long`: Scale by the longest (or shortest side) to `size`. Defaults to `True` 39 | - `flip`: Augment style image with rotations. Slows down algorithm and increases memory requirement. Generally improves content preservation but hurts stylization slightly. Defaults to `False` 40 | - `content_loss`: Defaults to `True` 41 | > Use experimental content loss. The most common failure mode of our method is that colors will shift within an object creating highly visible artifacts, if that happens this flag can usually fix it, but currently it has some drawbacks which is why it isn't enabled by default [...]. One advantage of using this flag though is that [content_weight] can typically be set all the way to 0.0 and the content will remain recognizable. 42 | - `colorize`: Whether to apply color correction. Defaults to `True` 43 | - `content_weight`: weight = 1.0 corresponds to maximum content preservation, weight = 0.0 is maximum stylization (Defaults to 0.75) 44 | - `max_iter`: How long each optimization step should take. `size=512` does 4 scaling steps * max_iter. `size=1024` does 5. The longer, the better and sharper the result. 45 | 46 | See more information on the original repository. No model is needed. 
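
For reference, the `colorize` correction works in the OkLAB colorspace (see `module_neural_neighbor/colorization.py`). Below is a minimal sketch of the conversion helpers it relies on; the import path assumes the snippet is run from the repository root, and the tensors follow the module's `B x 3 x H x W`, values-in-`[0, 1]` convention.

```python
import torch

# Assumed import path (run from the repository root).
from module_neural_neighbor.colorization import linear_2_oklab, oklab_2_linear

img = torch.rand(1, 3, 256, 256)          # B x 3 x H x W, values in [0, 1]

lab = linear_2_oklab(img)                  # L (lightness) + A/B (hue/saturation) channels
rgb = oklab_2_linear(lab)                  # back to clamped linear RGB

print((img - rgb).abs().max())             # round-trip error should be small
```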
47 | 48 | [workflow_neural_neighbor.json](workflows/workflow_neural_neighbor.json) 49 | 50 | ![wf_neural_neighbor](workflows/wf_neural_neighbor.png) 51 | 52 | ### CAST 53 | 54 | Download the CAST and UCAST models in the [Test](https://github.com/zyxElsa/CAST_pytorch?tab=readme-ov-file#test) section (Google Drive) (Unzip). 55 | 56 | Download the `vgg_normalized.pth` model in the [Train](https://github.com/zyxElsa/CAST_pytorch?tab=readme-ov-file#train) section 57 | 58 | ``` 59 | models/ 60 | │ .gitkeep 61 | │ vgg_normalised.pth 62 | ─CAST_model 63 | │ latest_net_AE.pth 64 | │ latest_net_Dec_A.pth # Safe to delete 65 | │ latest_net_Dec_B.pth 66 | │ test_opt.txt # Safe to delete 67 | │ 68 | └─UCAST_model 69 | latest_net_AE.pth 70 | latest_net_Dec_A.pth # Safe to delete 71 | latest_net_Dec_B.pth 72 | ``` 73 | 74 | [workflow_cast.json](workflows/workflow_cast.json) 75 | 76 | ![wf_cast](workflows/wf_cast.png) 77 | 78 | ### EFDM 79 | 80 | Download the `hm_decoder_iter_160000.pth` 81 | in https://github.com/YBZh/EFDM/tree/main/ArbitraryStyleTransfer/models and put it in 82 | `models/hm_decoder_iter_160000.pth` 83 | 84 | Download `vgg_normalized.pth` as instructed in [**CAST**](#cast) 85 | 86 | Arguments: 87 | - `style_interp_weights`: Weight for each style (only applicable for multiple styles). If empty, each style will have same weight (1/num_styles). Weights are automatically normalized so you can just enter anything 88 | - `model_arch`: `["adain", "adamean", "adastd", "efdm", "hm"]`, defaults to `"efdm"` 89 | - `style_strength`: From 0.00 to 1.00. The higher the stronger the style. Defaults to 1.00 90 | - `do_crop`: Whether to resize and then crop the content and style images to `size x size` or just resize. Defaults to False 91 | - `preserve_color`: Whether to preserve the color of the content image. Defaults to False 92 | - `size`: Size of the height of the images after resizing. The aspect ratio of content is kept, while the styles will be resized to match content's height and width. The higher the `size`, the better the result, but also the more VRAM it will take. You may try `use_cpu` to circumvent OOM. Tested with `size=2048` on 12GB VRAM. 93 | - `use_cpu`: Whether to use CPU or GPU for this. CPU takes very long! (default workflow = ~21 seconds). Defaults to False. 94 | 95 | [workflow_efdm.json](workflows/workflow_efdm.json) 96 | 97 | ![wf_efdm](workflows/wf_efdm.png) 98 | 99 | 100 | ### MicroAST 101 | 102 | Download models from https://github.com/EndyWon/MicroAST/tree/main/models and place them in the models folder 103 | 104 | ``` 105 | models/microast/content_encoder_iter_160000.pth.tar 106 | models/microast/decoder_iter_160000.pth.tar 107 | models/microast/modulator_iter_160000.pth.tar 108 | models/microast/style_encoder_iter_160000.pth.tar 109 | ``` 110 | 111 | Arguments: Similar to [EFDM](#efdm) 112 | 113 | [workflow_microast.json](workflows/workflow_microast.json) 114 | 115 | ![wf_microast](workflows/wf_microast.png) 116 | 117 | 118 | ### UniST 119 | 120 | Download `UniST_model.pt` in [Testing](https://github.com/NevSNev/UniST?tab=readme-ov-file#testing) and `vgg_r41.pth, dec_r41.pth` in [Training](https://github.com/NevSNev/UniST?tab=readme-ov-file#training) section and place them in 121 | 122 | ``` 123 | models/unist/UniST_model.pt 124 | models/unist/dec_r41.pth 125 | models/unist/vgg_r41.pth 126 | ``` 127 | 128 | The model only works with square images, so inputs are resized to `size x size`. `do_crop` will resize height, then crop to square. 
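
As a rough illustration of this resize-then-crop behaviour (an approximation of what `do_crop` does for EFDM, MicroAST and UniST based on the descriptions above, not the nodes' exact code):

```python
import torch
import torch.nn.functional as F

def resize_then_crop_square(img: torch.Tensor, size: int) -> torch.Tensor:
    # img: B x C x H x W. Resize so the height equals `size` (keeping the
    # aspect ratio), then center-crop the width down to `size`.
    b, c, h, w = img.shape
    new_w = max(size, round(w * size / h))   # guard against widths below `size`
    img = F.interpolate(img, size=(size, new_w), mode="bilinear", align_corners=False)
    left = (new_w - size) // 2
    return img[:, :, :, left:left + size]

frame = torch.rand(1, 3, 720, 1280)
square = resize_then_crop_square(frame, 512)
print(square.shape)                          # torch.Size([1, 3, 512, 512])
```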
129 | 130 | The Video node is "more native" than the Image node for batched images (video) inputs. The model works with `batch=3` (3 consecutive frames), so we split the video into such. 131 | 132 | | UniST | UniST Video | 133 | | :-----------------------------------------------------------: | :--------------------------------------------------------------: | 134 | | [workflow_unist_image.json](workflows/workflow_microast.json) | [workflow_unist_video.json](workflows/workflow_unist_video.json) | 135 | | ![wf_unist_image](workflows/wf_unist_image.png) | ![wf_microast](example_outputs/example_unist_video.png) | 136 | | | (Not a workflow-embedded image) | 137 | 138 | ### AesPA-Net 139 | 140 | Download models from [Usage](https://github.com/Kibeom-Hong/AesPA-Net?tab=readme-ov-file#usage) section. 141 | 142 | For the VGG model, download `vgg_normalised_conv5_1.pth` from https://github.com/pietrocarbo/deep-transfer/tree/master/models/autoencoder_vgg19/vgg19_5 143 | 144 | ``` 145 | models/aespa/dec_model.pth 146 | models/aespa/transformer_model.pth 147 | models/aespa/vgg_normalised_conv5_1.pth 148 | ``` 149 | 150 | The original authors provided a `vgg_normalised_conv5_1.t7` model, which can only be opened with torchfile, and pyTorch has dropped native support for it. It also doesn't work reliably on Windows. 151 | 152 | [workflow_aespa.json](workflows/workflow_aespa.json) 153 | 154 | ![wf_aespa](workflows/wf_aespa.png) 155 | 156 | 157 | ### TSSAT 158 | 159 | Download the `TSSAT-model.zip` in the [Model Testing](https://github.com/HalbertCH/TSSAT?tab=readme-ov-file#model-testing) section, and unzip it 160 | 161 | ``` 162 | models/tssat/decoder_iter_160000.pth 163 | 164 | # Outside folder, can be ignored from the zip if already downloaded from above 165 | models/vgg_normalised.pth 166 | ``` 167 | 168 | [workflow_tssat.json](workflows/workflow_tssat.json) 169 | 170 | ![wf_tssat](workflows/wf_tssat.png) 171 | 172 | 173 | ### AesFA 174 | 175 | Download `main.pth` in the [Getting Started](https://github.com/Sooyyoungg/AesFA?tab=readme-ov-file#getting-started) and place it in `models/aesfa/main.pth`. 176 | 177 | [workflow_aesfa_w_blend.json](workflows/workflow_aesfa_w_blend.json) 178 | 179 | ![wf_aesfa_w_blend](workflows/wf_aesfa_w_blend.png) 180 | 181 | Authors' Notes: 182 | 183 | > [...] style blending, i.e., using the low-frequency and high-frequency style information from different images. 184 | The style-transferred outputs finely keep the color information from the low-frequency image and change the texture information based on the high frequency image. 185 | 186 | > The vertical line-shape artifacts alongside the images are often observed. We reason that these appear because the content features are being convolved directly with the predicted aesthetic feature-aware kernels and biases in our model. In addition, the upsampling operation could be the ones that create artifacts. 
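
For intuition about the blend node's `style_hi` / `style_lo` inputs, the toy split below separates an image into a blurred low-frequency band (overall colour) and a high-frequency residual (texture). This is only an illustration of the idea from the authors' note above; AesFA itself separates frequencies internally with octave convolutions rather than a Gaussian blur.

```python
import torch
import torch.nn.functional as F

def toy_frequency_split(img: torch.Tensor, kernel_size: int = 21, sigma: float = 5.0):
    """Split a B x C x H x W image in [0, 1] into low- and high-frequency bands."""
    coords = torch.arange(kernel_size, dtype=img.dtype, device=img.device) - kernel_size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()
    # Depthwise Gaussian kernel: C x 1 x k x k
    kernel = (g[:, None] * g[None, :]).expand(img.shape[1], 1, kernel_size, kernel_size).contiguous()
    padded = F.pad(img, [kernel_size // 2] * 4, mode="reflect")
    low = F.conv2d(padded, kernel, groups=img.shape[1])   # rough colour layout
    high = img - low                                      # edges / texture detail
    return low, high

low_band, high_band = toy_frequency_split(torch.rand(1, 3, 256, 256))
```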
187 | 188 | ## Extra nodes 189 | 190 | ### Coral Color Transfer 191 | 192 | ![extra_coral_color_transfer](workflows/extra/extra_coral_color_transfer.png) 193 | 194 | ## Credits 195 | 196 | ``` 197 | @inproceedings{zhang2020cast, 198 | author = {Zhang, Yuxin and Tang, Fan and Dong, Weiming and Huang, Haibin and Ma, Chongyang and Lee, Tong-Yee and Xu, Changsheng}, 199 | title = {Domain Enhanced Arbitrary Image Style Transfer via Contrastive Learning}, 200 | booktitle = {ACM SIGGRAPH}, 201 | year = {2022} 202 | } 203 | 204 | @article{zhang2023unified, 205 | title={A Unified Arbitrary Style Transfer Framework via Adaptive Contrastive Learning}, 206 | author={Zhang, Yuxin and Tang, Fan and Dong, Weiming and Huang, Haibin and Ma, Chongyang and Lee, Tong-Yee and Xu, Changsheng}, 207 | journal={ACM Transactions on Graphics}, 208 | year={2023}, 209 | publisher={ACM New York, NY} 210 | } 211 | ``` 212 | 213 | ``` 214 | @inproceedings{zhang2021exact, 215 | title={Exact Feature Distribution Matching for Arbitrary Style Transfer and Domain Generalization}, 216 | author={Zhang, Yabin and Li, Minghan and Li, Ruihuang and Jia, Kui and Zhang, Lei}, 217 | booktitle={CVPR}, 218 | year={2022} 219 | } 220 | ``` 221 | 222 | ``` 223 | @inproceedings{wang2023microast, 224 | title={MicroAST: Towards Super-Fast Ultra-Resolution Arbitrary Style Transfer}, 225 | author={Wang, Zhizhong and Zhao, Lei and Zuo, Zhiwen and Li, Ailin and Chen, Haibo and Xing, Wei and Lu, Dongming}, 226 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 227 | year={2023} 228 | } 229 | ``` 230 | 231 | ``` 232 | @InProceedings{Gu_2023_ICCV, 233 | author = {Gu, Bohai and Fan, Heng and Zhang, Libo}, 234 | title = {Two Birds, One Stone: A Unified Framework for Joint Learning of Image and Video Style Transfers}, 235 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 236 | month = {October}, 237 | year = {2023}, 238 | pages = {23545-23554} 239 | } 240 | ``` 241 | 242 | ``` 243 | @article{Hong2023AesPANetAP, 244 | title={AesPA-Net: Aesthetic Pattern-Aware Style Transfer Networks}, 245 | author={Kibeom Hong and Seogkyu Jeon and Junsoo Lee and Namhyuk Ahn and Kunhee Kim and Pilhyeon Lee and Daesik Kim and Youngjung Uh and Hyeran Byun}, 246 | journal={ArXiv}, 247 | year={2023}, 248 | volume={abs/2307.09724}, 249 | url={https://api.semanticscholar.org/CorpusID:259982728} 250 | } 251 | ``` -------------------------------------------------------------------------------- /module_unist/models.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from einops import rearrange 6 | 7 | from .transformer_layers import DecoderLayer_style, EncoderLayer, EncoderLayer_cross 8 | 9 | 10 | class Encoder_vgg(nn.Module): 11 | def __init__(self): 12 | super(Encoder_vgg, self).__init__() 13 | # vgg 14 | # 256 x 256 15 | self.conv1 = nn.Conv2d(3, 3, 1, 1, 0) 16 | self.reflecPad1 = nn.ReflectionPad2d((1, 1, 1, 1)) 17 | # 256 x 256 18 | 19 | self.conv2 = nn.Conv2d(3, 64, 3, 1, 0) 20 | self.relu2 = nn.ReLU(inplace=True) 21 | # 256 x 256 22 | 23 | self.reflecPad3 = nn.ReflectionPad2d((1, 1, 1, 1)) 24 | self.conv3 = nn.Conv2d(64, 64, 3, 1, 0) 25 | self.relu3 = nn.ReLU(inplace=True) 26 | # 256 x 256 27 | 28 | self.maxPool = nn.MaxPool2d(kernel_size=2, stride=2) 29 | # 128 x 128 30 | 31 | self.reflecPad4 = nn.ReflectionPad2d((1, 1, 1, 1)) 32 | self.conv4 = nn.Conv2d(64, 128, 3, 1, 0) 33 | self.relu4 = 
nn.ReLU(inplace=True) 34 | # 128 x 128 35 | 36 | self.reflecPad5 = nn.ReflectionPad2d((1, 1, 1, 1)) 37 | self.conv5 = nn.Conv2d(128, 128, 3, 1, 0) 38 | self.relu5 = nn.ReLU(inplace=True) 39 | # 128 x 128 x 128 40 | 41 | self.maxPool2 = nn.MaxPool2d(kernel_size=2, stride=2) 42 | # 64 x 64 43 | 44 | self.reflecPad6 = nn.ReflectionPad2d((1, 1, 1, 1)) 45 | self.conv6 = nn.Conv2d(128, 256, 3, 1, 0) 46 | self.relu6 = nn.ReLU(inplace=True) 47 | # 64 x 64 48 | 49 | self.reflecPad7 = nn.ReflectionPad2d((1, 1, 1, 1)) 50 | self.conv7 = nn.Conv2d(256, 256, 3, 1, 0) 51 | self.relu7 = nn.ReLU(inplace=True) 52 | # 64 x 64 53 | 54 | self.reflecPad8 = nn.ReflectionPad2d((1, 1, 1, 1)) 55 | self.conv8 = nn.Conv2d(256, 256, 3, 1, 0) 56 | self.relu8 = nn.ReLU(inplace=True) 57 | # 64 x 64 58 | 59 | self.reflecPad9 = nn.ReflectionPad2d((1, 1, 1, 1)) 60 | self.conv9 = nn.Conv2d(256, 256, 3, 1, 0) 61 | self.relu9 = nn.ReLU(inplace=True) 62 | # 64 x 64 63 | 64 | self.maxPool3 = nn.MaxPool2d(kernel_size=2, stride=2) 65 | # 32 x 32 66 | 67 | self.reflecPad10 = nn.ReflectionPad2d((1, 1, 1, 1)) 68 | self.conv10 = nn.Conv2d(256, 512, 3, 1, 0) 69 | self.relu10 = nn.ReLU(inplace=True) 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = self.reflecPad1(x) 74 | x = self.relu2(self.conv2(x)) 75 | x = self.reflecPad3(x) 76 | x = self.relu3(self.conv3(x)) 77 | 78 | x = self.maxPool(x) 79 | 80 | x = self.reflecPad4(x) 81 | x = self.relu4(self.conv4(x)) 82 | x = self.reflecPad5(x) 83 | x = self.relu5(self.conv5(x)) 84 | 85 | x = self.maxPool2(x) 86 | 87 | x = self.reflecPad6(x) 88 | x = self.relu6(self.conv6(x)) 89 | x = self.reflecPad7(x) 90 | x = self.relu7(self.conv7(x)) 91 | x = self.reflecPad8(x) 92 | x = self.relu8(self.conv8(x)) 93 | x = self.reflecPad9(x) 94 | x = self.relu9(self.conv9(x)) 95 | 96 | x = self.maxPool3(x) 97 | 98 | x = self.reflecPad10(x) 99 | x = self.relu10(self.conv10(x)) 100 | 101 | return x 102 | 103 | 104 | class Decoder_vgg(nn.Module): 105 | def __init__(self): 106 | super(Decoder_vgg, self).__init__() 107 | # decoder 108 | self.reflecPad11 = nn.ReflectionPad2d((1, 1, 1, 1)) 109 | self.conv11 = nn.Conv2d(512, 256, 3, 1, 0) 110 | self.relu11 = nn.ReLU(inplace=True) 111 | # 32 x 32 112 | 113 | self.unpool = nn.UpsamplingNearest2d(scale_factor=2) 114 | # 64 x 64 115 | 116 | self.reflecPad12 = nn.ReflectionPad2d((1, 1, 1, 1)) 117 | self.conv12 = nn.Conv2d(256, 256, 3, 1, 0) 118 | self.relu12 = nn.ReLU(inplace=True) 119 | # 64 x 64 120 | 121 | self.reflecPad13 = nn.ReflectionPad2d((1, 1, 1, 1)) 122 | self.conv13 = nn.Conv2d(256, 256, 3, 1, 0) 123 | self.relu13 = nn.ReLU(inplace=True) 124 | # 64 x 64 125 | 126 | self.reflecPad14 = nn.ReflectionPad2d((1, 1, 1, 1)) 127 | self.conv14 = nn.Conv2d(256, 256, 3, 1, 0) 128 | self.relu14 = nn.ReLU(inplace=True) 129 | # 64 x 64 130 | 131 | self.reflecPad15 = nn.ReflectionPad2d((1, 1, 1, 1)) 132 | self.conv15 = nn.Conv2d(256, 128, 3, 1, 0) 133 | self.relu15 = nn.ReLU(inplace=True) 134 | # 64 x 64 135 | 136 | self.unpool2 = nn.UpsamplingNearest2d(scale_factor=2) 137 | # 128 x 128 138 | 139 | self.reflecPad16 = nn.ReflectionPad2d((1, 1, 1, 1)) 140 | self.conv16 = nn.Conv2d(128, 128, 3, 1, 0) 141 | self.relu16 = nn.ReLU(inplace=True) 142 | # 128 x 128 143 | 144 | self.reflecPad17 = nn.ReflectionPad2d((1, 1, 1, 1)) 145 | self.conv17 = nn.Conv2d(128, 64, 3, 1, 0) 146 | self.relu17 = nn.ReLU(inplace=True) 147 | # 128 x 128 148 | 149 | self.unpool3 = nn.UpsamplingNearest2d(scale_factor=2) 150 | # 256 x 256 151 | 152 | self.reflecPad18 = nn.ReflectionPad2d((1, 
1, 1, 1)) 153 | self.conv18 = nn.Conv2d(64, 64, 3, 1, 0) 154 | self.relu18 = nn.ReLU(inplace=True) 155 | # 256 x 256 156 | 157 | self.reflecPad19 = nn.ReflectionPad2d((1, 1, 1, 1)) 158 | self.conv19 = nn.Conv2d(64, 3, 3, 1, 0) 159 | 160 | def forward(self, x): 161 | # 128x128x128 162 | out = self.reflecPad11(x) 163 | out = self.conv11(out) 164 | out = self.relu11(out) 165 | 166 | out = self.unpool(out) 167 | 168 | out = self.reflecPad12(out) 169 | out = self.conv12(out) 170 | out = self.relu12(out) 171 | out = self.reflecPad13(out) 172 | out = self.conv13(out) 173 | out = self.relu13(out) 174 | out = self.reflecPad14(out) 175 | out = self.conv14(out) 176 | out = self.relu14(out) 177 | out = self.reflecPad15(out) 178 | out = self.conv15(out) 179 | out = self.relu15(out) 180 | out = self.unpool2(out) 181 | out = self.reflecPad16(out) 182 | out = self.conv16(out) 183 | out = self.relu16(out) 184 | out = self.reflecPad17(out) 185 | out = self.conv17(out) 186 | out = self.relu17(out) 187 | out = self.unpool3(out) 188 | out = self.reflecPad18(out) 189 | out = self.conv18(out) 190 | out = self.relu18(out) 191 | out = self.reflecPad19(out) 192 | out = self.conv19(out) 193 | return out 194 | 195 | 196 | class Encoder_video(nn.Module): 197 | """For Content Enc.""" 198 | 199 | def __init__(self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1): 200 | super().__init__() 201 | self.layer_stack = nn.ModuleList( 202 | [ 203 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 204 | for _ in range(n_layers) 205 | ] 206 | ) 207 | 208 | def forward(self, enc_output): 209 | for enc_layer in self.layer_stack: 210 | enc_output = enc_layer(enc_output) 211 | return enc_output 212 | 213 | 214 | class Encoder_image(nn.Module): 215 | """For Style Enc.""" 216 | 217 | def __init__(self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1): 218 | super().__init__() 219 | self.layer_stack = nn.ModuleList( 220 | [ 221 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 222 | for _ in range(n_layers) 223 | ] 224 | ) 225 | 226 | def forward(self, enc_output): 227 | for enc_layer in self.layer_stack: 228 | enc_output = enc_layer(enc_output) 229 | return enc_output 230 | 231 | 232 | class Encoder_cross(nn.Module): 233 | def __init__(self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1): 234 | super().__init__() 235 | self.layer_stack = nn.ModuleList( 236 | [ 237 | EncoderLayer_cross(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 238 | for _ in range(n_layers) 239 | ] 240 | ) 241 | 242 | def forward(self, dec_output, enc_output): 243 | # print(f"{dec_output.shape=}") 244 | # print(f"{enc_output.shape=}") 245 | for enc_layer in self.layer_stack: 246 | enc_output = enc_layer(dec_output, enc_output) 247 | return enc_output 248 | 249 | 250 | class Decoder_style_tranfer(nn.Module): 251 | def __init__(self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1): 252 | super().__init__() 253 | 254 | self.layer_stack = nn.ModuleList( 255 | [ 256 | DecoderLayer_style(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 257 | for _ in range(n_layers) 258 | ] 259 | ) 260 | 261 | def forward(self, dec_output, enc_output): 262 | # print(f"{dec_output.shape=}") 263 | # print(f"{enc_output.shape=}") 264 | for dec_layer in self.layer_stack: 265 | dec_output = dec_layer(dec_output, enc_output) 266 | 267 | _, l, _ = dec_output.size() 268 | dec_output = rearrange( 269 | dec_output, "b (m1 m2) c -> b c m1 m2", m1=int(math.sqrt(l)) 270 | ) 271 | 272 | return dec_output 273 | 274 
| 275 | class Transformer(nn.Module): 276 | def __init__( 277 | self, 278 | encoder_path: str, 279 | decoder_path: str, 280 | d_model=512, 281 | d_inner=4096, 282 | n_layers=3, 283 | n_head=8, 284 | d_k=64, 285 | d_v=64, 286 | dropout=0.1, 287 | ): 288 | super().__init__() 289 | 290 | self.encoder_frames = Encoder_video( 291 | d_model=d_model, 292 | d_inner=d_inner, 293 | n_layers=2, 294 | n_head=n_head, 295 | d_k=d_k, 296 | d_v=d_v, 297 | dropout=dropout, 298 | ) 299 | 300 | self.encoder_style = Encoder_image( 301 | d_model=d_model, 302 | d_inner=d_inner, 303 | n_layers=1, 304 | n_head=n_head, 305 | d_k=d_k, 306 | d_v=d_v, 307 | dropout=dropout, 308 | ) 309 | 310 | self.image_cross_video = Encoder_cross( 311 | d_model=d_model, 312 | d_inner=d_inner, 313 | n_layers=1, 314 | n_head=n_head, 315 | d_k=d_k, 316 | d_v=d_v, 317 | dropout=dropout, 318 | ) 319 | 320 | self.video_cross_image = Encoder_cross( 321 | d_model=d_model, 322 | d_inner=d_inner, 323 | n_layers=1, 324 | n_head=n_head, 325 | d_k=d_k, 326 | d_v=d_v, 327 | dropout=dropout, 328 | ) 329 | 330 | self.style_decoder = Decoder_style_tranfer( 331 | d_model=d_model, 332 | d_inner=d_inner, 333 | n_layers=n_layers, 334 | n_head=n_head, 335 | d_k=d_k, 336 | d_v=d_v, 337 | dropout=dropout, 338 | ) 339 | 340 | self.encoder = Encoder_vgg() 341 | 342 | self.decoder = Decoder_vgg() 343 | 344 | for p in self.parameters(): 345 | if p.dim() > 1: 346 | nn.init.xavier_uniform_(p) 347 | 348 | print("Loading pretrained AutoEncoder") 349 | encoder_ckpt = torch.load(encoder_path, map_location="cpu") 350 | decoder_ckpt = torch.load(decoder_path, map_location="cpu") 351 | self.encoder.load_state_dict(encoder_ckpt, strict=True) 352 | self.decoder.load_state_dict(decoder_ckpt, strict=True) 353 | 354 | def forward( 355 | self, 356 | content, 357 | style_images, 358 | content_type="image", 359 | id_loss="transfer", 360 | tab=None, 361 | ): 362 | if id_loss == "transfer": 363 | b, t, c, h, w = content.size() 364 | content = content.view(b * t, c, h, w) 365 | enc_content = self.encoder(content) 366 | enc_content = rearrange(enc_content, "bt c h w ->bt (h w) c") 367 | enc_content = self.encoder_frames(enc_content) 368 | if content_type == "image": 369 | content_self_attn = self.image_cross_video(enc_content, enc_content) 370 | elif content_type == "video": 371 | content_self_attn = self.video_cross_image(enc_content, enc_content) 372 | enc_style = self.encoder.forward(style_images) 373 | enc_style = rearrange(enc_style, "b c h w ->b (h w) c") 374 | enc_style_image = self.encoder_style.forward(enc_style) 375 | dec_output = self.style_decoder.forward(content_self_attn, enc_style_image) 376 | style_result = self.decoder.forward(dec_output) 377 | return style_result 378 | -------------------------------------------------------------------------------- /module_aesfa/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | # import torch.nn.functional as F 5 | from .blocks import AdaOctConv, Oct_Conv_aftup, Oct_conv_lreLU, Oct_conv_up, OctConv 6 | 7 | # from blocks import * 8 | 9 | 10 | def define_network( 11 | net_type, alpha_in, alpha_out, style_kernel, input_nc, nf, output_nc, freq_ratio 12 | ): 13 | net = None 14 | sk = style_kernel 15 | 16 | if net_type == "Encoder": 17 | net = Encoder( 18 | in_dim=input_nc, 19 | nf=nf, 20 | style_kernel=[sk, sk], 21 | alpha_in=alpha_in, 22 | alpha_out=alpha_out, 23 | ) 24 | elif net_type == "Generator": 25 | net = Decoder( 26 | nf=nf, 27 | 
out_dim=output_nc, 28 | style_channel=256, 29 | style_kernel=[sk, sk, 3], 30 | alpha_in=alpha_in, 31 | freq_ratio=freq_ratio, 32 | alpha_out=alpha_out, 33 | ) 34 | return net 35 | 36 | 37 | class Encoder(nn.Module): 38 | def __init__(self, in_dim, nf=64, style_kernel=[3, 3], alpha_in=0.5, alpha_out=0.5): 39 | super(Encoder, self).__init__() 40 | 41 | self.conv = nn.Conv2d( 42 | in_channels=in_dim, out_channels=nf, kernel_size=7, stride=1, padding=3 43 | ) 44 | 45 | self.OctConv1_1 = OctConv( 46 | in_channels=nf, 47 | out_channels=nf, 48 | kernel_size=3, 49 | stride=2, 50 | padding=1, 51 | groups=64, 52 | alpha_in=alpha_in, 53 | alpha_out=alpha_out, 54 | type="first", 55 | ) 56 | self.OctConv1_2 = OctConv( 57 | in_channels=nf, 58 | out_channels=2 * nf, 59 | kernel_size=1, 60 | alpha_in=alpha_in, 61 | alpha_out=alpha_out, 62 | type="normal", 63 | ) 64 | self.OctConv1_3 = OctConv( 65 | in_channels=2 * nf, 66 | out_channels=2 * nf, 67 | kernel_size=3, 68 | stride=1, 69 | padding=1, 70 | alpha_in=alpha_in, 71 | alpha_out=alpha_out, 72 | type="normal", 73 | ) 74 | 75 | self.OctConv2_1 = OctConv( 76 | in_channels=2 * nf, 77 | out_channels=2 * nf, 78 | kernel_size=3, 79 | stride=2, 80 | padding=1, 81 | groups=128, 82 | alpha_in=alpha_in, 83 | alpha_out=alpha_out, 84 | type="normal", 85 | ) 86 | self.OctConv2_2 = OctConv( 87 | in_channels=2 * nf, 88 | out_channels=4 * nf, 89 | kernel_size=1, 90 | alpha_in=alpha_in, 91 | alpha_out=alpha_out, 92 | type="normal", 93 | ) 94 | self.OctConv2_3 = OctConv( 95 | in_channels=4 * nf, 96 | out_channels=4 * nf, 97 | kernel_size=3, 98 | stride=1, 99 | padding=1, 100 | alpha_in=alpha_in, 101 | alpha_out=alpha_out, 102 | type="normal", 103 | ) 104 | 105 | self.pool_h = nn.AdaptiveAvgPool2d((style_kernel[0], style_kernel[0])) 106 | self.pool_l = nn.AdaptiveAvgPool2d((style_kernel[1], style_kernel[1])) 107 | 108 | self.relu = Oct_conv_lreLU() 109 | 110 | def forward(self, x): 111 | enc_feat = [] 112 | out = self.conv(x) 113 | 114 | out = self.OctConv1_1(out) 115 | out = self.relu(out) 116 | out = self.OctConv1_2(out) 117 | out = self.relu(out) 118 | out = self.OctConv1_3(out) 119 | out = self.relu(out) 120 | enc_feat.append(out) 121 | 122 | out = self.OctConv2_1(out) 123 | out = self.relu(out) 124 | out = self.OctConv2_2(out) 125 | out = self.relu(out) 126 | out = self.OctConv2_3(out) 127 | out = self.relu(out) 128 | enc_feat.append(out) 129 | 130 | out_high, out_low = out 131 | out_sty_h = self.pool_h(out_high) 132 | out_sty_l = self.pool_l(out_low) 133 | out_sty = out_sty_h, out_sty_l 134 | 135 | return out, out_sty, enc_feat 136 | 137 | def forward_test(self, x, cond): 138 | out = self.conv(x) 139 | 140 | out = self.OctConv1_1(out) 141 | out = self.relu(out) 142 | out = self.OctConv1_2(out) 143 | out = self.relu(out) 144 | out = self.OctConv1_3(out) 145 | out = self.relu(out) 146 | 147 | out = self.OctConv2_1(out) 148 | out = self.relu(out) 149 | out = self.OctConv2_2(out) 150 | out = self.relu(out) 151 | out = self.OctConv2_3(out) 152 | out = self.relu(out) 153 | 154 | if cond == "style": 155 | out_high, out_low = out 156 | out_sty_h = self.pool_h(out_high) 157 | out_sty_l = self.pool_l(out_low) 158 | return out_sty_h, out_sty_l 159 | else: 160 | return out 161 | 162 | 163 | class Decoder(nn.Module): 164 | def __init__( 165 | self, 166 | nf=64, 167 | out_dim=3, 168 | style_channel=512, 169 | style_kernel=[3, 3, 3], 170 | alpha_in=0.5, 171 | alpha_out=0.5, 172 | freq_ratio=[1, 1], 173 | pad_type="reflect", 174 | ): 175 | super(Decoder, self).__init__() 176 | 177 
| group_div = [1, 2, 4, 8] 178 | self.up_oct = Oct_conv_up(scale_factor=2) 179 | 180 | self.AdaOctConv1_1 = AdaOctConv( 181 | in_channels=4 * nf, 182 | out_channels=4 * nf, 183 | group_div=group_div[0], 184 | style_channels=style_channel, 185 | kernel_size=style_kernel, 186 | stride=1, 187 | padding=1, 188 | oct_groups=4 * nf, 189 | alpha_in=alpha_in, 190 | alpha_out=alpha_out, 191 | type="normal", 192 | ) 193 | self.OctConv1_2 = OctConv( 194 | in_channels=4 * nf, 195 | out_channels=2 * nf, 196 | kernel_size=1, 197 | stride=1, 198 | alpha_in=alpha_in, 199 | alpha_out=alpha_out, 200 | type="normal", 201 | ) 202 | self.oct_conv_aftup_1 = Oct_Conv_aftup( 203 | in_channels=2 * nf, 204 | out_channels=2 * nf, 205 | kernel_size=3, 206 | stride=1, 207 | padding=1, 208 | pad_type=pad_type, 209 | alpha_in=alpha_in, 210 | alpha_out=alpha_out, 211 | ) 212 | 213 | self.AdaOctConv2_1 = AdaOctConv( 214 | in_channels=2 * nf, 215 | out_channels=2 * nf, 216 | group_div=group_div[1], 217 | style_channels=style_channel, 218 | kernel_size=style_kernel, 219 | stride=1, 220 | padding=1, 221 | oct_groups=2 * nf, 222 | alpha_in=alpha_in, 223 | alpha_out=alpha_out, 224 | type="normal", 225 | ) 226 | self.OctConv2_2 = OctConv( 227 | in_channels=2 * nf, 228 | out_channels=nf, 229 | kernel_size=1, 230 | stride=1, 231 | alpha_in=alpha_in, 232 | alpha_out=alpha_out, 233 | type="normal", 234 | ) 235 | self.oct_conv_aftup_2 = Oct_Conv_aftup( 236 | nf, nf, 3, 1, 1, pad_type, alpha_in, alpha_out 237 | ) 238 | 239 | self.AdaOctConv3_1 = AdaOctConv( 240 | in_channels=nf, 241 | out_channels=nf, 242 | group_div=group_div[2], 243 | style_channels=style_channel, 244 | kernel_size=style_kernel, 245 | stride=1, 246 | padding=1, 247 | oct_groups=nf, 248 | alpha_in=alpha_in, 249 | alpha_out=alpha_out, 250 | type="normal", 251 | ) 252 | self.OctConv3_2 = OctConv( 253 | in_channels=nf, 254 | out_channels=nf // 2, 255 | kernel_size=1, 256 | stride=1, 257 | alpha_in=alpha_in, 258 | alpha_out=alpha_out, 259 | type="last", 260 | freq_ratio=freq_ratio, 261 | ) 262 | 263 | self.conv4 = nn.Conv2d(in_channels=nf // 2, out_channels=out_dim, kernel_size=1) 264 | 265 | def forward(self, content, style): 266 | out = self.AdaOctConv1_1(content, style) 267 | out = self.OctConv1_2(out) 268 | out = self.up_oct(out) 269 | out = self.oct_conv_aftup_1(out) 270 | 271 | out = self.AdaOctConv2_1(out, style) 272 | out = self.OctConv2_2(out) 273 | out = self.up_oct(out) 274 | out = self.oct_conv_aftup_2(out) 275 | 276 | out = self.AdaOctConv3_1(out, style) 277 | out = self.OctConv3_2(out) 278 | out, out_high, out_low = out 279 | 280 | out = self.conv4(out) 281 | out_high = self.conv4(out_high) 282 | out_low = self.conv4(out_low) 283 | 284 | return out, out_high, out_low 285 | 286 | def forward_test(self, content, style): 287 | out = self.AdaOctConv1_1(content, style, "test") 288 | out = self.OctConv1_2(out) 289 | out = self.up_oct(out) 290 | out = self.oct_conv_aftup_1(out) 291 | 292 | out = self.AdaOctConv2_1(out, style, "test") 293 | out = self.OctConv2_2(out) 294 | out = self.up_oct(out) 295 | out = self.oct_conv_aftup_2(out) 296 | 297 | out = self.AdaOctConv3_1(out, style, "test") 298 | out = self.OctConv3_2(out) 299 | 300 | out = self.conv4(out[0]) 301 | return out 302 | 303 | 304 | ############## Contrastive Loss function ############## 305 | def calc_mean_std(feat, eps=1e-5): 306 | # eps is a small value added to the variance to avoid divide-by-zero. 
    size = feat.size()
    assert len(size) == 4
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std


def calc_content_loss(input, target):
    assert input.size() == target.size()
    mse_loss = nn.MSELoss()
    return mse_loss(input, target)


def calc_style_loss(input, target):
    assert input.size() == target.size()
    mse_loss = nn.MSELoss()
    input_mean, input_std = calc_mean_std(input)
    target_mean, target_std = calc_mean_std(target)

    loss = mse_loss(input_mean, target_mean) + mse_loss(input_std, target_std)
    return loss


class EFDM_loss(nn.Module):
    def __init__(self):
        super(EFDM_loss, self).__init__()
        self.mse_loss = nn.MSELoss()

    def efdm_single(self, style, trans):
        B, C, W, H = style.size(0), style.size(1), style.size(2), style.size(3)

        value_style, index_style = torch.sort(style.view(B, C, -1))
        value_trans, index_trans = torch.sort(trans.view(B, C, -1))
        inverse_index = index_trans.argsort(-1)

        return self.mse_loss(trans.view(B, C, -1), value_style.gather(-1, inverse_index))

    def forward(self, style_E, style_S, translate_E, translate_S, neg_idx):
        loss = 0.0
        batch = style_E[0][0].shape[0]
        for b in range(batch):
            poss_loss = 0.0
            neg_loss = 0.0

            # Positive loss
            for i in range(len(style_E)):
                poss_loss += self.efdm_single(style_E[i][0][b].unsqueeze(0), translate_E[i][0][b].unsqueeze(0)) \
                    + self.efdm_single(style_E[i][1][b].unsqueeze(0), translate_E[i][1][b].unsqueeze(0))
            for i in range(len(style_S)):
                poss_loss += self.efdm_single(style_S[i][0][b].unsqueeze(0), translate_S[i][0][b].unsqueeze(0)) \
                    + self.efdm_single(style_S[i][1][b].unsqueeze(0), translate_S[i][1][b].unsqueeze(0))

            # Negative loss
            for nb in neg_idx[b]:
                for i in range(len(style_E)):
                    neg_loss += self.efdm_single(style_E[i][0][nb].unsqueeze(0), translate_E[i][0][b].unsqueeze(0)) \
                        + self.efdm_single(style_E[i][1][nb].unsqueeze(0), translate_E[i][1][b].unsqueeze(0))
                for i in range(len(style_S)):
                    neg_loss += self.efdm_single(style_S[i][0][nb].unsqueeze(0), translate_S[i][0][b].unsqueeze(0)) \
                        + self.efdm_single(style_S[i][1][nb].unsqueeze(0), translate_S[i][1][b].unsqueeze(0))

            loss += poss_loss / neg_loss

        return loss
--------------------------------------------------------------------------------
/module_aesfa/blocks.py:
--------------------------------------------------------------------------------
import os
# import glob
# from path import Path
import math
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import lr_scheduler
# from blocks import *

def model_save(ckpt_dir, model, optim_E, optim_S, optim_G, epoch, itr=None):
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)

    torch.save({'netE': model.netE.state_dict(),
                'netS': model.netS.state_dict(),
                'netG': model.netG.state_dict(),
                'optim_E': optim_E.state_dict(),
                'optim_S': optim_S.state_dict(),
                'optim_G': optim_G.state_dict()},
               '%s/model_iter_%d_epoch_%d.pth' % (ckpt_dir, itr+1, epoch+1))
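
# Training-time checkpoint helpers. The directory-scanning branch in model_load appears to assume
# a path.Path-like ckpt_path (glob() returning a sortable list of str subclasses), matching the
# commented-out "from path import Path" import above.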

def model_load(checkpoint, ckpt_path, model, optim_E, optim_S, optim_G):
    # if not os.path.exists(ckpt_dir):
    #     epoch = -1
    #     return model, optim_E, optim_S, optim_G, epoch

    # ckpt_path = Path(ckpt_dir)
    if checkpoint:
        model_ckpt = ckpt_path + '/' + checkpoint
    else:
        ckpt_lst = ckpt_path.glob('model_iter_*')
        ckpt_lst.sort(key=lambda x: int(x.split('iter_')[1].split('_epoch')[0]))
        model_ckpt = ckpt_lst[-1]
    itr = int(model_ckpt.split('iter_')[1].split('_epoch_')[0])
    epoch = int(model_ckpt.split('iter_')[1].split('_epoch_')[1].split('.')[0])
    print(model_ckpt)

    dict_model = torch.load(model_ckpt)

    model.netE.load_state_dict(dict_model['netE'])
    model.netS.load_state_dict(dict_model['netS'])
    model.netG.load_state_dict(dict_model['netG'])
    optim_E.load_state_dict(dict_model['optim_E'])
    optim_S.load_state_dict(dict_model['optim_S'])
    optim_G.load_state_dict(dict_model['optim_G'])

    return model, optim_E, optim_S, optim_G, epoch, itr

def test_model_load(checkpoint, model):
    dict_model = torch.load(checkpoint)
    model.netE.load_state_dict(dict_model['netE'])
    model.netS.load_state_dict(dict_model['netS'])
    model.netG.load_state_dict(dict_model['netG'])
    return model

def get_scheduler(optimizer, config):
    if config.lr_policy == 'lambda':
        def lambda_rule(epoch):
            lr_l = 1.0 - max(0, epoch + config.n_epoch - config.n_iter) / float(config.n_iter_decay + 1)
            return lr_l

        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif config.lr_policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=config.lr_decay_iters, gamma=0.1)
    elif config.lr_policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif config.lr_policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.n_iter, eta_min=0)
    else:
        raise NotImplementedError('learning rate policy [%s] is not implemented' % config.lr_policy)
    return scheduler

def update_learning_rate(scheduler, optimizer):
    scheduler.step()
    lr = optimizer.param_groups[0]['lr']
    print('learning rate = %.7f' % lr)

class Oct_Conv_aftup(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, pad_type, alpha_in, alpha_out):
        super(Oct_Conv_aftup, self).__init__()
        lf_in = int(in_channels*alpha_in)
        lf_out = int(out_channels*alpha_out)
        hf_in = in_channels - lf_in
        hf_out = out_channels - lf_out

        self.conv_h = nn.Conv2d(in_channels=hf_in, out_channels=hf_out, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type)
        self.conv_l = nn.Conv2d(in_channels=lf_in, out_channels=lf_out, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type)

    def forward(self, x):
        hf, lf = x
        hf = self.conv_h(hf)
        lf = self.conv_l(lf)
        return hf, lf

class Oct_conv_reLU(nn.ReLU):
    def forward(self, x):
        hf, lf = x
        hf = super(Oct_conv_reLU, self).forward(hf)
        lf = super(Oct_conv_reLU, self).forward(lf)
        return hf, lf

class Oct_conv_lreLU(nn.LeakyReLU):
    def forward(self, x):
        hf, lf = x
        hf = super(Oct_conv_lreLU, self).forward(hf)
        lf = super(Oct_conv_lreLU, self).forward(lf)
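        # LeakyReLU is applied to the high- and low-frequency maps independently; the pair is returned as a tuple.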
        return hf, lf

class Oct_conv_up(nn.Upsample):
    def forward(self, x):
        hf, lf = x
        hf = super(Oct_conv_up, self).forward(hf)
        lf = super(Oct_conv_up, self).forward(lf)
        return hf, lf


############## Encoder ##############
class OctConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, groups=1, pad_type='reflect', alpha_in=0.5, alpha_out=0.5, type='normal', freq_ratio=[1, 1]):
        super(OctConv, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.type = type
        self.alpha_in = alpha_in
        self.alpha_out = alpha_out
        self.freq_ratio = freq_ratio

        hf_ch_in = int(in_channels * (1 - self.alpha_in))
        hf_ch_out = int(out_channels * (1 - self.alpha_out))
        lf_ch_in = in_channels - hf_ch_in
        lf_ch_out = out_channels - hf_ch_out

        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.upsample = nn.Upsample(scale_factor=2)

        self.is_dw = groups == in_channels

        if type == 'first':
            self.convh = nn.Conv2d(in_channels, hf_ch_out, kernel_size=kernel_size,
                                   stride=stride, padding=padding, padding_mode=pad_type, bias=False)
            self.convl = nn.Conv2d(in_channels, lf_ch_out,
                                   kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type, bias=False)
        elif type == 'last':
            self.convh = nn.Conv2d(hf_ch_in, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type, bias=False)
            self.convl = nn.Conv2d(lf_ch_in, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, padding_mode=pad_type, bias=False)
        else:
            self.L2L = nn.Conv2d(lf_ch_in, lf_ch_out,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 groups=math.ceil(alpha_in * groups), padding_mode=pad_type, bias=False)
            if self.is_dw:
                self.L2H = None
                self.H2L = None
            else:
                self.L2H = nn.Conv2d(lf_ch_in, hf_ch_out,
                                     kernel_size=kernel_size, stride=stride, padding=padding,
                                     groups=groups, padding_mode=pad_type, bias=False)
                self.H2L = nn.Conv2d(hf_ch_in, lf_ch_out,
                                     kernel_size=kernel_size, stride=stride, padding=padding,
                                     groups=groups, padding_mode=pad_type, bias=False)
            self.H2H = nn.Conv2d(hf_ch_in, hf_ch_out,
                                 kernel_size=kernel_size, stride=stride, padding=padding,
                                 groups=math.ceil(groups - alpha_in * groups), padding_mode=pad_type, bias=False)

    def forward(self, x):
        if self.type == 'first':
            hf = self.convh(x)
            lf = self.avg_pool(x)
            lf = self.convl(lf)
            return hf, lf
        elif self.type == 'last':
            hf, lf = x
            out_h = self.convh(hf)
            out_l = self.convl(self.upsample(lf))
            output = out_h * self.freq_ratio[0] + out_l * self.freq_ratio[1]
            return output, out_h, out_l
        else:
            hf, lf = x
            if self.is_dw:
                hf, lf = self.H2H(hf), self.L2L(lf)
            else:
                hf, lf = self.H2H(hf) + self.L2H(self.upsample(lf)), self.L2L(lf) + self.H2L(self.avg_pool(hf))
            return hf, lf


############## Decoder ##############
class AdaOctConv(nn.Module):
    def __init__(self, in_channels, out_channels, group_div, style_channels, kernel_size,
                 stride, padding, oct_groups, alpha_in, alpha_out, type='normal'):
        super(AdaOctConv, self).__init__()
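        # AdaOctConv predicts per-style spatial and pointwise kernels for the high- and low-frequency
        # branches, applies them to the content features, then runs a grouped octave convolution.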
        self.in_channels = in_channels
        self.alpha_in = alpha_in
        self.alpha_out = alpha_out
        self.type = type

        h_in = int(in_channels * (1 - self.alpha_in))
        l_in = in_channels - h_in

        n_groups_h = h_in // group_div
        n_groups_l = l_in // group_div

        style_channels_h = int(style_channels * (1 - self.alpha_in))
        style_channels_l = int(style_channels - style_channels_h)

        kernel_size_h = kernel_size[0]
        kernel_size_l = kernel_size[1]
        kernel_size_A = kernel_size[2]

        self.kernelPredictor_h = KernelPredictor(in_channels=h_in, out_channels=h_in, n_groups=n_groups_h,
                                                 style_channels=style_channels_h, kernel_size=kernel_size_h)
        self.kernelPredictor_l = KernelPredictor(in_channels=l_in, out_channels=l_in, n_groups=n_groups_l,
                                                 style_channels=style_channels_l, kernel_size=kernel_size_l)

        self.AdaConv_h = AdaConv2d(in_channels=h_in, out_channels=h_in, n_groups=n_groups_h)
        self.AdaConv_l = AdaConv2d(in_channels=l_in, out_channels=l_in, n_groups=n_groups_l)

        self.OctConv = OctConv(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=kernel_size_A, stride=stride, padding=padding, groups=oct_groups,
                               alpha_in=alpha_in, alpha_out=alpha_out, type=type)

        self.relu = Oct_conv_lreLU()

    def forward(self, content, style, cond='train'):
        c_hf, c_lf = content
        s_hf, s_lf = style
        h_w_spatial, h_w_pointwise, h_bias = self.kernelPredictor_h(s_hf)
        l_w_spatial, l_w_pointwise, l_bias = self.kernelPredictor_l(s_lf)

        if cond == 'train':
            output_h = self.AdaConv_h(c_hf, h_w_spatial, h_w_pointwise, h_bias)
            output_l = self.AdaConv_l(c_lf, l_w_spatial, l_w_pointwise, l_bias)
            output = output_h, output_l

            output = self.relu(output)

            output = self.OctConv(output)
            if self.type != 'last':
                output = self.relu(output)
            return output

        if cond == 'test':
            output_h = self.AdaConv_h(c_hf, h_w_spatial, h_w_pointwise, h_bias)
            output_l = self.AdaConv_l(c_lf, l_w_spatial, l_w_pointwise, l_bias)
            output = output_h, output_l
            output = self.relu(output)
            output = self.OctConv(output)
            if self.type != 'last':
                output = self.relu(output)
            return output

class KernelPredictor(nn.Module):
    def __init__(self, in_channels, out_channels, n_groups, style_channels, kernel_size):
        super(KernelPredictor, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_groups = n_groups
        self.w_channels = style_channels
        self.kernel_size = kernel_size

        padding = (kernel_size - 1) / 2
        self.spatial = nn.Conv2d(style_channels,
                                 in_channels * out_channels // n_groups,
                                 kernel_size=kernel_size,
                                 padding=(math.ceil(padding), math.ceil(padding)),
                                 padding_mode='reflect')
        self.pointwise = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(style_channels,
                      out_channels * out_channels // n_groups,
                      kernel_size=1)
        )
        self.bias = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(style_channels,
                      out_channels,
                      kernel_size=1)
        )

    def forward(self, w):
        w_spatial = self.spatial(w)
        w_spatial = w_spatial.reshape(len(w),
                                      self.out_channels,
                                      self.in_channels // self.n_groups,
                                      self.kernel_size, self.kernel_size)
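        # w_spatial has shape (batch, out_channels, in_channels // n_groups, k, k): one grouped conv kernel per sample.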

        w_pointwise = self.pointwise(w)
        w_pointwise = w_pointwise.reshape(len(w),
                                          self.out_channels,
                                          self.out_channels // self.n_groups,
                                          1, 1)
        bias = self.bias(w)
        bias = bias.reshape(len(w), self.out_channels)
        return w_spatial, w_pointwise, bias

class AdaConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, n_groups=None):
        super(AdaConv2d, self).__init__()
        self.n_groups = in_channels if n_groups is None else n_groups
        self.in_channels = in_channels
        self.out_channels = out_channels

        padding = (kernel_size - 1) / 2
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=(kernel_size, kernel_size),
                              padding=(math.ceil(padding), math.floor(padding)),
                              padding_mode='reflect')

    def forward(self, x, w_spatial, w_pointwise, bias):
        assert len(x) == len(w_spatial) == len(w_pointwise) == len(bias)
        x = F.instance_norm(x)

        ys = []
        for i in range(len(x)):
            y = self.forward_single(x[i:i+1], w_spatial[i], w_pointwise[i], bias[i])
            ys.append(y)
        ys = torch.cat(ys, dim=0)

        ys = self.conv(ys)
        return ys

    def forward_single(self, x, w_spatial, w_pointwise, bias):
        assert w_spatial.size(-1) == w_spatial.size(-2)
        padding = (w_spatial.size(-1) - 1) / 2
        pad = (math.ceil(padding), math.floor(padding), math.ceil(padding), math.floor(padding))

        x = F.pad(x, pad=pad, mode='reflect')
        x = F.conv2d(x, w_spatial, groups=self.n_groups)
        x = F.conv2d(x, w_pointwise, groups=self.n_groups, bias=bias)
        return x
--------------------------------------------------------------------------------