├── .DS_Store
├── LICENSE
├── README.assets
│   ├── .DS_Store
│   ├── 13-32-style.png
│   ├── 14-48-consistency.png
│   ├── 18-4-content.png
│   ├── 27--8.png
│   ├── 27-8-consistency-8074642.png
│   ├── 27-8-consistency.png
│   ├── 39-13-content.png
│   ├── 4-29-style.png
│   ├── 40-43-consistency.png
│   ├── 8-35-content-8074538.png
│   ├── 8-35-content.png
│   └── other.png
├── README.md
├── hist_loss.py
├── main.py
├── net.py
└── utils.py

--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/.DS_Store

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2022 computer-vision2022

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/.DS_Store

--------------------------------------------------------------------------------
/README.assets/13-32-style.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/13-32-style.png

--------------------------------------------------------------------------------
/README.assets/14-48-consistency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/14-48-consistency.png

--------------------------------------------------------------------------------
/README.assets/18-4-content.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/18-4-content.png

--------------------------------------------------------------------------------
/README.assets/27--8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/27--8.png

--------------------------------------------------------------------------------
/README.assets/27-8-consistency-8074642.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/27-8-consistency-8074642.png

--------------------------------------------------------------------------------
/README.assets/27-8-consistency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/27-8-consistency.png

--------------------------------------------------------------------------------
/README.assets/39-13-content.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/39-13-content.png

--------------------------------------------------------------------------------
/README.assets/4-29-style.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/4-29-style.png

--------------------------------------------------------------------------------
/README.assets/40-43-consistency.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/40-43-consistency.png

--------------------------------------------------------------------------------
/README.assets/8-35-content-8074538.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/8-35-content-8074538.png

--------------------------------------------------------------------------------
/README.assets/8-35-content.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/8-35-content.png

--------------------------------------------------------------------------------
/README.assets/other.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/luoxuan-cs/PAMA/a5267429a2e1ab764b6b8e89abf9c036bcb9ff23/README.assets/other.png

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
PAMA
================

This is the PyTorch implementation of Progressive Attentional Manifold Alignment (PAMA).

1/14/2022 Thanks to GitHub user [AK391](https://github.com/AK391) for making a [web demo](https://huggingface.co/spaces/akhaliq/PAMA) for PAMA.

1/18/2022 New checkpoints are available (w/o color loss, 1.5x color loss weight, 1.5x content loss weight).

11/15/2022 Our paper has been accepted by ACCV 2022. We have also updated our code; the new version of PAMA is available here: https://drive.google.com/drive/folders/1At5XYHW153Pe8A1TgmAbOfwTv3fH7YK1?usp=share_link


## Requirements

* python 3.6
* pytorch 1.2.0+
* PIL, numpy, matplotlib

## Checkpoints

Please download the pre-trained checkpoints from [Google Drive](https://drive.google.com/file/d/1rPB_qnelVVSad6CtadmhRFi0PMI_RKdy/view?usp=sharing) and put them in `./checkpoints`.
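
For a quick sanity check that everything is in place, the snippet below (a minimal sketch, not part of the original scripts) looks for the five files that `net.py` loads from `./checkpoints`:

```python
import os

# file names follow the torch.load(...) calls in net.py
for name in ["encoder.pth", "PAMA1.pth", "PAMA2.pth", "PAMA3.pth", "decoder.pth"]:
    path = os.path.join("checkpoints", name)
    print(path, "ok" if os.path.isfile(path) else "MISSING")
```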

Here we also provide some other pre-trained checkpoints trained with different loss weights:

| Type             | Loss                     | Download |
| ---------------- | ------------------------ | -------- |
| high consistency | w/o color loss           | [PAMA_without_color.zip](https://drive.google.com/file/d/1IrggOiutiZceJCrEb24cLnBjeA5I3N1D/view?usp=sharing) |
| high color       | 1.5x color loss weight   | [PAMA_1.5x_color.zip](https://drive.google.com/file/d/1HXet2u_zk2QCVM_z5Llg2bcfvvndabtt/view?usp=sharing) |
| high content     | 1.5x content loss weight | [PAMA_1.5x_content.zip](https://drive.google.com/file/d/13m7Lb9xwfG_DVOesuG9PyxDHG4SwqlNt/view?usp=sharing) |

These checkpoints will be uploaded soon.

## Training

The training set consists of two parts: content images from COCO2014 and style images from Wikiart.

```shell
python main.py train --lr 1e-4 --content_folder ./COCO2014 --style_folder ./Wikiart
```

## Testing

To test the code, specify the paths of the content image and the style image:

```shell
python main.py eval --content ./content/1.jpg --style ./style/1.jpg
```

To stylize every content/style pair under two folders in one batch, run:

```shell
python main.py eval --run_folder True --content ./content/ --style ./style/
```
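
The evaluation path can also be driven directly from Python. The sketch below mirrors what `eval()` in `main.py` does for a single image pair; the `Namespace` fields stand in for the flags that the CLI parser would supply, and `content.jpg`/`style.jpg` are placeholder paths:

```python
import argparse
import torch
from PIL import Image
from torchvision.utils import save_image

from net import Net
from utils import DEVICE, test_transform

# flags normally filled in by main.py's argument parser
args = argparse.Namespace(pretrained=True, requires_grad=False, training=False)

model = Net(args).to(DEVICE).eval()
tf = test_transform()

Ic = tf(Image.open("content.jpg").convert("RGB")).unsqueeze(0).to(DEVICE)  # 1 x 3 x H x W
Is = tf(Image.open("style.jpg").convert("RGB")).unsqueeze(0).to(DEVICE)

with torch.no_grad():
    Ics = model(Ic, Is)  # stylized output

save_image(Ics[0], "ics.jpg")
```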

## Results Presentation

The results below demonstrate the quality of PAMA along three dimensions: regional consistency, content preservation, and style quality.

#### Regional Consistency

![39-13-content](./README.assets/39-13-content.png)

![8-35-content](./README.assets/8-35-content-8074538.png)

#### Content Preservation

![18-4-content](./README.assets/18-4-content.png)

![27-8-consistency](./README.assets/27-8-consistency-8074642.png)

#### Style Quality

![4-29-style](./README.assets/4-29-style.png)

![13-32-style](./README.assets/13-32-style.png)

#### Other Results

![other](./README.assets/other.png)

--------------------------------------------------------------------------------
/hist_loss.py:
--------------------------------------------------------------------------------
"""
Copyright 2021 Mahmoud Afifi.
Mahmoud Afifi, Marcus A. Brubaker, and Michael S. Brown. "HistoGAN:
Controlling Colors of GAN-Generated and Real Images via Color Histograms."
In CVPR, 2021.

@inproceedings{afifi2021histogan,
  title={Histo{GAN}: Controlling Colors of {GAN}-Generated and Real Images via
  Color Histograms},
  author={Afifi, Mahmoud and Brubaker, Marcus A. and Brown, Michael S.},
  booktitle={CVPR},
  year={2021}
}
"""

import torch
import torch.nn as nn
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np

EPS = 1e-6

class RGBuvHistBlock(nn.Module):
  def __init__(self, h=64, insz=150, resizing='interpolation',
               method='inverse-quadratic', sigma=0.02, intensity_scale=True,
               device='cuda'):
    """ Computes the RGB-uv histogram feature of a given image.
    Args:
      h: histogram dimension size (scalar). The default value is 64.
      insz: maximum size of the input image; if it is larger than this size, the
        image will be resized (scalar). Default value is 150 (i.e., 150 x 150
        pixels).
      resizing: resizing method if applicable. Options are: 'interpolation' or
        'sampling'. Default is 'interpolation'.
      method: the method used to count the number of pixels for each bin in the
        histogram feature. Options are: 'thresholding', 'RBF' (radial basis
        function), or 'inverse-quadratic'. Default value is 'inverse-quadratic'.
      sigma: if the method value is 'RBF' or 'inverse-quadratic', then this is
        the sigma parameter of the kernel function. The default value is 0.02.
      intensity_scale: boolean variable to use the intensity scale (I_y in
        Equation 2). Default value is True.

    Methods:
      forward: accepts input image and returns its histogram feature. Note that
        unless the method is 'thresholding', this is a differentiable function
        and can be easily integrated with the loss function. As mentioned in the
        paper, the 'inverse-quadratic' was found more stable than 'RBF' in our
        training.
    """
    super(RGBuvHistBlock, self).__init__()
    self.h = h
    self.insz = insz
    self.device = device
    self.resizing = resizing
    self.method = method
    self.intensity_scale = intensity_scale
    if self.method == 'thresholding':
      self.eps = 6.0 / h
    else:
      self.sigma = sigma

  def forward(self, x):
    x = torch.clamp(x, 0, 1)
    if x.shape[2] > self.insz or x.shape[3] > self.insz:
      if self.resizing == 'interpolation':
        x_sampled = F.interpolate(x, size=(self.insz, self.insz),
                                  mode='bilinear', align_corners=False)
      elif self.resizing == 'sampling':
        inds_1 = torch.LongTensor(
          np.linspace(0, x.shape[2], self.h, endpoint=False)).to(
          device=self.device)
        inds_2 = torch.LongTensor(
          np.linspace(0, x.shape[3], self.h, endpoint=False)).to(
          device=self.device)
        x_sampled = x.index_select(2, inds_1)
        x_sampled = x_sampled.index_select(3, inds_2)
      else:
        raise Exception(
          f'Wrong resizing method. It should be: interpolation or sampling. '
          f'But the given value is {self.resizing}.')
    else:
      x_sampled = x

    L = x_sampled.shape[0]  # size of mini-batch
    if x_sampled.shape[1] > 3:
      x_sampled = x_sampled[:, :3, :, :]
    X = torch.unbind(x_sampled, dim=0)
    hists = torch.zeros((x_sampled.shape[0], 3, self.h, self.h)).to(
      device=self.device)
    for l in range(L):
      I = torch.t(torch.reshape(X[l], (3, -1)))
      II = torch.pow(I, 2)
      if self.intensity_scale:
        Iy = torch.unsqueeze(torch.sqrt(II[:, 0] + II[:, 1] + II[:, 2] + EPS),
                             dim=1)
      else:
        Iy = 1

      Iu0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 1] + EPS),
                            dim=1)
      Iv0 = torch.unsqueeze(torch.log(I[:, 0] + EPS) - torch.log(I[:, 2] + EPS),
                            dim=1)
      diff_u0 = abs(
        Iu0 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      diff_v0 = abs(
        Iv0 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      if self.method == 'thresholding':
        diff_u0 = torch.reshape(diff_u0, (-1, self.h)) <= self.eps / 2
        diff_v0 = torch.reshape(diff_v0, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u0 = torch.exp(-diff_u0)  # Radial basis function
        diff_v0 = torch.exp(-diff_v0)
      elif self.method == 'inverse-quadratic':
        diff_u0 = torch.pow(torch.reshape(diff_u0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v0 = torch.pow(torch.reshape(diff_v0, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u0 = 1 / (1 + diff_u0)  # Inverse quadratic
        diff_v0 = 1 / (1 + diff_v0)
      else:
        raise Exception(
          f'Wrong kernel method. It should be either thresholding, RBF, or'
          f' inverse-quadratic. But the given value is {self.method}.')
      diff_u0 = diff_u0.type(torch.float32)
      diff_v0 = diff_v0.type(torch.float32)
      a = torch.t(Iy * diff_u0)
      hists[l, 0, :, :] = torch.mm(a, diff_v0)

      Iu1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 0] + EPS),
                            dim=1)
      Iv1 = torch.unsqueeze(torch.log(I[:, 1] + EPS) - torch.log(I[:, 2] + EPS),
                            dim=1)
      diff_u1 = abs(
        Iu1 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      diff_v1 = abs(
        Iv1 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))

      if self.method == 'thresholding':
        diff_u1 = torch.reshape(diff_u1, (-1, self.h)) <= self.eps / 2
        diff_v1 = torch.reshape(diff_v1, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = torch.exp(-diff_u1)  # Gaussian
        diff_v1 = torch.exp(-diff_v1)
      elif self.method == 'inverse-quadratic':
        diff_u1 = torch.pow(torch.reshape(diff_u1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v1 = torch.pow(torch.reshape(diff_v1, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u1 = 1 / (1 + diff_u1)  # Inverse quadratic
        diff_v1 = 1 / (1 + diff_v1)

      diff_u1 = diff_u1.type(torch.float32)
      diff_v1 = diff_v1.type(torch.float32)
      a = torch.t(Iy * diff_u1)
      hists[l, 1, :, :] = torch.mm(a, diff_v1)

      Iu2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 0] + EPS),
                            dim=1)
      Iv2 = torch.unsqueeze(torch.log(I[:, 2] + EPS) - torch.log(I[:, 1] + EPS),
                            dim=1)
      diff_u2 = abs(
        Iu2 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      diff_v2 = abs(
        Iv2 - torch.unsqueeze(torch.tensor(np.linspace(-3, 3, num=self.h)),
                              dim=0).to(self.device))
      if self.method == 'thresholding':
        diff_u2 = torch.reshape(diff_u2, (-1, self.h)) <= self.eps / 2
        diff_v2 = torch.reshape(diff_v2, (-1, self.h)) <= self.eps / 2
      elif self.method == 'RBF':
        diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u2 = torch.exp(-diff_u2)  # Gaussian
        diff_v2 = torch.exp(-diff_v2)
      elif self.method == 'inverse-quadratic':
        diff_u2 = torch.pow(torch.reshape(diff_u2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_v2 = torch.pow(torch.reshape(diff_v2, (-1, self.h)),
                            2) / self.sigma ** 2
        diff_u2 = 1 / (1 + diff_u2)  # Inverse quadratic
        diff_v2 = 1 / (1 + diff_v2)
      diff_u2 = diff_u2.type(torch.float32)
      diff_v2 = diff_v2.type(torch.float32)
      a = torch.t(Iy * diff_u2)
      hists[l, 2, :, :] = torch.mm(a, diff_v2)

    # normalization
    hists_normalized = hists / (
        ((hists.sum(dim=1)).sum(dim=1)).sum(dim=1).view(-1, 1, 1, 1) + EPS)

    return hists_normalized
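
# --- usage sketch (added for illustration; not part of the original HistoGAN code) ---
# Builds the block on CPU and feeds a random batch to show the expected shapes:
# with h=64 the output is one 64x64 uv-histogram per RGB projection, and the
# normalization makes each sample's histogram sum to ~1.
if __name__ == '__main__':
  block = RGBuvHistBlock(h=64, insz=150, device='cpu')
  imgs = torch.rand(2, 3, 256, 256)  # batch of RGB images with values in [0, 1]
  hists = block(imgs)
  print(hists.shape)         # torch.Size([2, 3, 64, 64])
  print(float(hists.sum()))  # ~2.0, i.e. ~1 per sample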

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import os
import sys
import argparse
import logging
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision.utils import save_image
from PIL import Image, ImageFile
from net import Net
from utils import DEVICE, train_transform, test_transform, FlatFolderDataset, InfiniteSamplerWrapper, plot_grad_flow, adjust_learning_rate
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True


def train(args):
    logging.basicConfig(filename='training.log',
                        format='%(asctime)s %(levelname)s: %(message)s',
                        level=logging.INFO,
                        datefmt='%Y-%m-%d %H:%M:%S')

    mes = "current pid: " + str(os.getpid())
    print(mes)
    logging.info(mes)
    model = Net(args)
    model.train()
    device_ids = [0, 1]  # GPUs used by DataParallel; adjust to the available devices
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.to(DEVICE)

    tf = train_transform()
    content_dataset = FlatFolderDataset(args.content_folder, tf)
    style_dataset = FlatFolderDataset(args.style_folder, tf)
    content_iter = iter(data.DataLoader(
        content_dataset, batch_size=args.batch_size,
        sampler=InfiniteSamplerWrapper(content_dataset),
        num_workers=args.num_workers))
    style_iter = iter(data.DataLoader(
        style_dataset, batch_size=args.batch_size,
        sampler=InfiniteSamplerWrapper(style_dataset),
        num_workers=args.num_workers))

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for img_index in range(args.iterations):
        print("iteration :", img_index+1)
        optimizer.zero_grad()
        Ic = next(content_iter).to(DEVICE)
        Is = next(style_iter).to(DEVICE)

        loss = model(Ic, Is)
        print(loss)
        loss.sum().backward()

        #plot_grad_flow(model.named_parameters())
        optimizer.step()

        if (img_index+1) % args.log_interval == 0:
            print("saving...")
            mes = "iteration: " + str(img_index+1) + " loss: " + str(loss.sum().item())
            logging.info(mes)
            model.module.save_ckpts()
        adjust_learning_rate(optimizer, img_index, args)


def eval(args):
    mes = "current pid: " + str(os.getpid())
    print(mes)
    logging.info(mes)
    model = Net(args)
    model.eval()
    model = model.to(DEVICE)

    tf = test_transform()
    if args.run_folder:
        content_dir = args.content
        style_dir = args.style
        os.makedirs("ics", exist_ok=True)  # output folder for the stylized images
        for content in os.listdir(content_dir):
            for style in os.listdir(style_dir):
                name_c = os.path.join(content_dir, content)
                name_s = os.path.join(style_dir, style)
                Ic = tf(Image.open(name_c)).to(DEVICE)
                Is = tf(Image.open(name_s)).to(DEVICE)
                Ic = Ic.unsqueeze(dim=0)
                Is = Is.unsqueeze(dim=0)
                with torch.no_grad():
                    Ics = model(Ic, Is)

                name_cs = "ics/" + os.path.splitext(content)[0] + "--" + style
                save_image(Ics[0], name_cs)
    else:
        Ic = tf(Image.open(args.content)).to(DEVICE)
        Is = tf(Image.open(args.style)).to(DEVICE)

        Ic = Ic.unsqueeze(dim=0)
        Is = Is.unsqueeze(dim=0)

        with torch.no_grad():
            Ics = model(Ic, Is)

        name_cs = "ics.jpg"
        save_image(Ics[0], name_cs)


def main():
    main_parser = argparse.ArgumentParser(description="main parser")
    subparsers = main_parser.add_subparsers(title="subcommands", dest="subcommand")

    main_parser.add_argument("--pretrained", type=bool, default=True,
                             help="whether to use the pre-trained checkpoints")
    main_parser.add_argument("--requires_grad", type=bool, default=True,
                             help="set to True if the model requires gradients")

    train_parser = subparsers.add_parser("train", help="training mode parser")
    train_parser.add_argument("--training", type=bool, default=True)
    train_parser.add_argument("--iterations", type=int, default=160000,
                              help="total training iterations (default: 160000)")
    train_parser.add_argument("--batch_size", type=int, default=8,
                              help="training batch size (default: 8)")
    train_parser.add_argument("--num_workers", type=int, default=8,
                              help="iterator threads (default: 8)")
    train_parser.add_argument("--lr", type=float, default=1e-4,
                              help="the learning rate during training (default: 1e-4)")
    train_parser.add_argument("--content_folder", type=str, required=True,
                              help="the root of the content images; the path should point to a folder")
    train_parser.add_argument("--style_folder", type=str, required=True,
                              help="the root of the style images; the path should point to a folder")
    train_parser.add_argument("--log_interval", type=int, default=10000,
                              help="number of iterations after which the training loss is logged and checkpoints are saved (default: 10000)")

    train_parser.add_argument("--w_content1", type=float, default=12, help="the stage1 content loss weight")
    train_parser.add_argument("--w_content2", type=float, default=9, help="the stage2 content loss weight")
    train_parser.add_argument("--w_content3", type=float, default=7, help="the stage3 content loss weight")
    train_parser.add_argument("--w_remd1", type=float, default=2, help="the stage1 remd loss weight")
    train_parser.add_argument("--w_remd2", type=float, default=2, help="the stage2 remd loss weight")
    train_parser.add_argument("--w_remd3", type=float, default=2, help="the stage3 remd loss weight")
    train_parser.add_argument("--w_moment1", type=float, default=2, help="the stage1 moment loss weight")
    train_parser.add_argument("--w_moment2", type=float, default=2, help="the stage2 moment loss weight")
    train_parser.add_argument("--w_moment3", type=float, default=2, help="the stage3 moment loss weight")
    train_parser.add_argument("--color_on", type=bool, default=True, help="turn on the color loss")
    train_parser.add_argument("--w_color1", type=float, default=0.25, help="the stage1 color loss weight")
    train_parser.add_argument("--w_color2", type=float, default=0.5, help="the stage2 color loss weight")
    train_parser.add_argument("--w_color3", type=float, default=1, help="the stage3 color loss weight")


    eval_parser = subparsers.add_parser("eval", help="evaluation mode parser")
    eval_parser.add_argument("--training", type=bool, default=False)
    eval_parser.add_argument("--run_folder", type=bool, default=False)
    eval_parser.add_argument("--content", type=str, default="./content/",
                             help="content image you want to stylize")
    eval_parser.add_argument("--style", type=str, default="./style/",
                             help="style image for stylization")

    args = main_parser.parse_args()

    if args.subcommand is None:
        print("ERROR: specify either train or eval")
        sys.exit(1)
    if args.subcommand == "train":
        train(args)
    else:
        eval(args)

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/net.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from utils import mean_variance_norm, DEVICE
from utils import calc_ss_loss, calc_remd_loss, calc_moment_loss, calc_mse_loss, calc_histogram_loss
from hist_loss import RGBuvHistBlock

class Net(nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        self.args = args
        self.vgg = vgg19[:44]
        self.vgg.load_state_dict(torch.load('./checkpoints/encoder.pth', map_location='cpu'), strict=False)
        for param in self.vgg.parameters():
            param.requires_grad = False

        self.align1 = PAMA(512)
        self.align2 = PAMA(512)
        self.align3 = PAMA(512)

        self.decoder = decoder
        self.hist = RGBuvHistBlock(insz=64, h=256,
                                   intensity_scale=True,
                                   method='inverse-quadratic',
                                   device=DEVICE)

        if args.pretrained:
            self.align1.load_state_dict(torch.load('./checkpoints/PAMA1.pth', map_location='cpu'), strict=True)
            self.align2.load_state_dict(torch.load('./checkpoints/PAMA2.pth', map_location='cpu'), strict=True)
            self.align3.load_state_dict(torch.load('./checkpoints/PAMA3.pth', map_location='cpu'), strict=True)
            self.decoder.load_state_dict(torch.load('./checkpoints/decoder.pth', map_location='cpu'), strict=False)

        if not args.requires_grad:
            for param in self.parameters():
                param.requires_grad = False


    def forward(self, Ic, Is):
        feat_c = self.forward_vgg(Ic)
        feat_s = self.forward_vgg(Is)
        Fc, Fs = feat_c[3], feat_s[3]

        Fcs1 = self.align1(Fc, Fs)
        Fcs2 = self.align2(Fcs1, Fs)
        Fcs3 = self.align3(Fcs2, Fs)

        Ics3 = self.decoder(Fcs3)

        if self.args.training:
            Ics1 = self.decoder(Fcs1)
            Ics2 = self.decoder(Fcs2)
            Irc = self.decoder(Fc)
            Irs = self.decoder(Fs)
            feat_cs1 = self.forward_vgg(Ics1)
            feat_cs2 = self.forward_vgg(Ics2)
            feat_cs3 = self.forward_vgg(Ics3)
            feat_rc = self.forward_vgg(Irc)
            feat_rs = self.forward_vgg(Irs)

            content_loss1, remd_loss1, moment_loss1, color_loss1 = 0.0, 0.0, 0.0, 0.0
            content_loss2, remd_loss2, moment_loss2, color_loss2 = 0.0, 0.0, 0.0, 0.0
            content_loss3, remd_loss3, moment_loss3, color_loss3 = 0.0, 0.0, 0.0, 0.0
            loss_rec = 0.0

            for l in range(2, 5):
                content_loss1 += self.args.w_content1 * calc_ss_loss(feat_cs1[l], feat_c[l])
                remd_loss1 += self.args.w_remd1 * calc_remd_loss(feat_cs1[l], feat_s[l])
                moment_loss1 += self.args.w_moment1 * calc_moment_loss(feat_cs1[l], feat_s[l])

                content_loss2 += self.args.w_content2 * calc_ss_loss(feat_cs2[l], feat_c[l])
                remd_loss2 += self.args.w_remd2 * calc_remd_loss(feat_cs2[l], feat_s[l])
                moment_loss2 += self.args.w_moment2 * calc_moment_loss(feat_cs2[l], feat_s[l])

                content_loss3 += self.args.w_content3 * calc_ss_loss(feat_cs3[l], feat_c[l])
                remd_loss3 += self.args.w_remd3 * calc_remd_loss(feat_cs3[l], feat_s[l])
                moment_loss3 += self.args.w_moment3 * calc_moment_loss(feat_cs3[l], feat_s[l])

                loss_rec += 0.5 * calc_mse_loss(feat_rc[l], feat_c[l]) + 0.5 * calc_mse_loss(feat_rs[l], feat_s[l])
            loss_rec += 25 * calc_mse_loss(Irc, Ic)
            loss_rec += 25 * calc_mse_loss(Irs, Is)

            if self.args.color_on:
                color_loss1 += self.args.w_color1 * calc_histogram_loss(Ics1, Is, self.hist)
                color_loss2 += self.args.w_color2 * calc_histogram_loss(Ics2, Is, self.hist)
                color_loss3 += self.args.w_color3 * calc_histogram_loss(Ics3, Is, self.hist)

            loss1 = (content_loss1+remd_loss1+moment_loss1+color_loss1)/(self.args.w_content1+self.args.w_remd1+self.args.w_moment1+self.args.w_color1)
            loss2 = (content_loss2+remd_loss2+moment_loss2+color_loss2)/(self.args.w_content2+self.args.w_remd2+self.args.w_moment2+self.args.w_color2)
            loss3 = (content_loss3+remd_loss3+moment_loss3+color_loss3)/(self.args.w_content3+self.args.w_remd3+self.args.w_moment3+self.args.w_color3)
            loss = loss1 + loss2 + loss3 + loss_rec
            return loss
        else:
            return Ics3

    def forward_vgg(self, x):
        relu1_1 = self.vgg[:4](x)
        relu2_1 = self.vgg[4:11](relu1_1)
        relu3_1 = self.vgg[11:18](relu2_1)
        relu4_1 = self.vgg[18:31](relu3_1)
        relu5_1 = self.vgg[31:44](relu4_1)
        return [relu1_1, relu2_1, relu3_1, relu4_1, relu5_1]

    def save_ckpts(self):
        torch.save(self.align1.state_dict(), "./checkpoints/PAMA1.pth")
        torch.save(self.align2.state_dict(), "./checkpoints/PAMA2.pth")
        torch.save(self.align3.state_dict(), "./checkpoints/PAMA3.pth")
        torch.save(self.decoder.state_dict(), "./checkpoints/decoder.pth")

#---------------------------------------------------------------------------------------------------------------

vgg19 = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)

#---------------------------------------------------------------------------------------------------------------

decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),  # relu4_1
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),  # relu3_1
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),  # relu2_1
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1_1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)

#---------------------------------------------------------------------------------------------------------------

class AttentionUnit(nn.Module):
    def __init__(self, channels):
        super(AttentionUnit, self).__init__()
        self.relu6 = nn.ReLU6()
        self.f = nn.Conv2d(channels, channels//2, (1, 1))
        self.g = nn.Conv2d(channels, channels//2, (1, 1))
        self.h = nn.Conv2d(channels, channels//2, (1, 1))

        self.out_conv = nn.Conv2d(channels//2, channels, (1, 1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Fc, Fs):
        B, C, H, W = Fc.shape
        f_Fc = self.relu6(self.f(mean_variance_norm(Fc)))
        g_Fs = self.relu6(self.g(mean_variance_norm(Fs)))
        h_Fs = self.relu6(self.h(Fs))
        f_Fc = f_Fc.view(f_Fc.shape[0], f_Fc.shape[1], -1).permute(0, 2, 1)
        g_Fs = g_Fs.view(g_Fs.shape[0], g_Fs.shape[1], -1)

        Attention = self.softmax(torch.bmm(f_Fc, g_Fs))

        h_Fs = h_Fs.view(h_Fs.shape[0], h_Fs.shape[1], -1)

        Fcs = torch.bmm(h_Fs, Attention.permute(0, 2, 1))
        Fcs = Fcs.view(B, C//2, H, W)
        Fcs = self.relu6(self.out_conv(Fcs))

        return Fcs

class FuseUnit(nn.Module):
    def __init__(self, channels):
        super(FuseUnit, self).__init__()
        self.proj1 = nn.Conv2d(2*channels, channels, (1, 1))
        self.proj2 = nn.Conv2d(channels, channels, (1, 1))
        self.proj3 = nn.Conv2d(channels, channels, (1, 1))

        self.fuse1x = nn.Conv2d(channels, 1, (1, 1), stride=1)
        self.fuse3x = nn.Conv2d(channels, 1, (3, 3), stride=1)
        self.fuse5x = nn.Conv2d(channels, 1, (5, 5), stride=1)

        self.pad3x = nn.ReflectionPad2d((1, 1, 1, 1))
        self.pad5x = nn.ReflectionPad2d((2, 2, 2, 2))
        self.sigmoid = nn.Sigmoid()

    def forward(self, F1, F2):
        Fcat = self.proj1(torch.cat((F1, F2), dim=1))
        F1 = self.proj2(F1)
        F2 = self.proj3(F2)

        fusion1 = self.sigmoid(self.fuse1x(Fcat))
        fusion3 = self.sigmoid(self.fuse3x(self.pad3x(Fcat)))
        fusion5 = self.sigmoid(self.fuse5x(self.pad5x(Fcat)))
        fusion = (fusion1 + fusion3 + fusion5) / 3

        return torch.clamp(fusion, min=0, max=1.0)*F1 + torch.clamp(1 - fusion, min=0, max=1.0)*F2

class PAMA(nn.Module):
    def __init__(self, channels):
        super(PAMA, self).__init__()
        self.conv_in = nn.Conv2d(channels, channels, (3, 3), stride=1)
        self.attn = AttentionUnit(channels)
        self.fuse = FuseUnit(channels)
        self.conv_out = nn.Conv2d(channels, channels, (3, 3), stride=1)

        self.pad = nn.ReflectionPad2d((1, 1, 1, 1))
        self.relu6 = nn.ReLU6()

    def forward(self, Fc, Fs):
        Fc = self.relu6(self.conv_in(self.pad(Fc)))
        Fs = self.relu6(self.conv_in(self.pad(Fs)))
        Fcs = self.attn(Fc, Fs)
        Fcs = self.relu6(self.conv_out(self.pad(Fcs)))
        Fcs = self.fuse(Fc, Fcs)

        return Fcs

#---------------------------------------------------------------------------------------------------------------
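
# --- usage sketch (added for illustration; not part of the original repo) ---
# Runs a single PAMA alignment block on random relu4_1-sized features. The 512
# channels match how Net instantiates the blocks; the 32x32 spatial size is an
# arbitrary stand-in for real VGG features.
if __name__ == '__main__':
    pama = PAMA(512)
    Fc = torch.rand(1, 512, 32, 32)  # content features
    Fs = torch.rand(1, 512, 32, 32)  # style features
    print(pama(Fc, Fs).shape)        # torch.Size([1, 512, 32, 32])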

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torchvision import transforms
import PIL.Image as Image

DEVICE = 'cuda'
mse = nn.MSELoss()


def calc_histogram_loss(A, B, histogram_block):
    # Hellinger distance between the RGB-uv histograms of A and B
    input_hist = histogram_block(A)
    target_hist = histogram_block(B)
    histogram_loss = (1/np.sqrt(2.0) * (torch.sqrt(torch.sum(
        torch.pow(torch.sqrt(target_hist) - torch.sqrt(input_hist), 2)))) /
        input_hist.shape[0])

    return histogram_loss

# B, C, H, W; mean and variance computed over H and W
def calc_mean_std(feat, eps=1e-5):
    # eps is a small value added to the variance to avoid divide-by-zero.
    size = feat.size()
    assert (len(size) == 4)
    N, C = size[:2]
    feat_var = feat.view(N, C, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(N, C, 1, 1)
    feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
    return feat_mean, feat_std

def mean_variance_norm(feat):
    size = feat.size()
    mean, std = calc_mean_std(feat)
    normalized_feat = (feat - mean.expand(size)) / std.expand(size)
    return normalized_feat

def train_transform():
    transform_list = [
        transforms.Resize(size=512),
        transforms.RandomCrop(256),
        transforms.ToTensor()
    ]
    return transforms.Compose(transform_list)

def test_transform():
    transform_list = []
    transform_list.append(transforms.Resize(size=512))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

# https://discuss.pytorch.org/t/check-gradient-flow-in-network/15063/7
def plot_grad_flow(named_parameters):
    '''Prints the average and maximum gradient magnitude of each layer during
    training. Can be used for checking possible gradient vanishing / exploding
    problems.

    Usage: plug this function into the trainer after loss.backward() as
    "plot_grad_flow(self.model.named_parameters())" to inspect the gradient flow.'''
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if (p.requires_grad) and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean())
            max_grads.append(p.grad.abs().max())
            print('-'*82)
            print(n, p.grad.abs().mean(), p.grad.abs().max())
    print('-'*82)

def InfiniteSampler(n):
    # i = 0
    i = n - 1
    order = np.random.permutation(n)
    while True:
        yield order[i]
        i += 1
        if i >= n:
            np.random.seed()
            order = np.random.permutation(n)
            i = 0

class InfiniteSamplerWrapper(data.sampler.Sampler):
    def __init__(self, data_source):
        self.num_samples = len(data_source)

    def __iter__(self):
        return iter(InfiniteSampler(self.num_samples))

    def __len__(self):
        return 2 ** 31

class FlatFolderDataset(data.Dataset):
    def __init__(self, root, transform):
        super(FlatFolderDataset, self).__init__()
        self.root = root
        self.paths = os.listdir(self.root)
        self.transform = transform

    def __getitem__(self, index):
        path = self.paths[index]
        img = Image.open(os.path.join(self.root, path)).convert('RGB')
        img = self.transform(img)
        return img

    def __len__(self):
        return len(self.paths)

    def name(self):
        return 'FlatFolderDataset'

def adjust_learning_rate(optimizer, iteration_count, args):
    """Imitating the original implementation"""
    lr = args.lr / (1.0 + 5e-5 * iteration_count)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

def cosine_dismat(A, B):
    # pairwise cosine dissimilarity between the feature vectors of A and B
    A = A.view(A.shape[0], A.shape[1], -1)
    B = B.view(B.shape[0], B.shape[1], -1)

    A_norm = torch.sqrt((A**2).sum(1))
    B_norm = torch.sqrt((B**2).sum(1))

    A = (A/A_norm.unsqueeze(dim=1).expand(A.shape)).permute(0, 2, 1)
    B = (B/B_norm.unsqueeze(dim=1).expand(B.shape))
    dismat = 1. - torch.bmm(A, B)

    return dismat

def calc_remd_loss(A, B):
    # relaxed earth mover's distance: max of the two one-sided nearest-neighbour means
    C = cosine_dismat(A, B)
    m1, _ = C.min(1)
    m2, _ = C.min(2)

    remd = torch.max(m1.mean(), m2.mean())

    return remd

def calc_ss_loss(A, B):
    # self-similarity loss: compares the cosine self-similarity matrices of A and B
    MA = cosine_dismat(A, A)
    MB = cosine_dismat(B, B)
    Lself_similarity = torch.abs(MA - MB).mean()

    return Lself_similarity

def calc_moment_loss(A, B):
    # first- and second-moment (mean and covariance) matching loss
    A = A.view(A.shape[0], A.shape[1], -1)
    B = B.view(B.shape[0], B.shape[1], -1)

    mu_a = torch.mean(A, 1, keepdim=True)
    mu_b = torch.mean(B, 1, keepdim=True)
    mu_d = torch.abs(mu_a - mu_b).mean()

    A_c = A - mu_a
    B_c = B - mu_b
    cov_a = torch.bmm(A_c, A_c.permute(0, 2, 1)) / (A.shape[2]-1)
    cov_b = torch.bmm(B_c, B_c.permute(0, 2, 1)) / (B.shape[2]-1)
    cov_d = torch.abs(cov_a - cov_b).mean()
    loss = mu_d + cov_d
    return loss

def calc_mse_loss(A, B):
    return mse(A, B)
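
# --- usage sketch (added for illustration; not part of the original repo) ---
# Evaluates the style losses on random feature maps to show the expected
# B x C x H x W layout and that each loss reduces to a scalar.
if __name__ == '__main__':
    A = torch.rand(1, 512, 32, 32)
    B = torch.rand(1, 512, 32, 32)
    print("remd  :", calc_remd_loss(A, B).item())
    print("ss    :", calc_ss_loss(A, B).item())
    print("moment:", calc_moment_loss(A, B).item())

--------------------------------------------------------------------------------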