├── .gitignore
├── README.md
├── app.py
├── config.json
├── models
│   ├── StarGANv2.py
│   └── StarGANv2_model.py
├── requirements.txt
├── static
│   ├── main.js
│   ├── placeholder.png
│   └── style.css
├── templates
│   ├── 404.html
│   ├── index.html
│   ├── partials
│   │   ├── footer.html
│   │   ├── header.html
│   │   ├── message.html
│   │   └── nav.html
│   └── starganv2_afhq.html
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | .vscode
3 | *.pyc
4 | *.ckpt
5 | cache
6 | temp

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch Deployment
2 | > A template for rapid deployment of PyTorch models.
3 | 
4 | 
5 | ## Deployment
6 | 1. Install the dependencies: `pip install -r requirements.txt`
7 | 2. Start the service in one of the following ways:
8 |     1. `flask run -p 3000` (note that this option performs poorly).
9 |     2. Use Gunicorn:
10 |         1. Install Gunicorn: `pip install gunicorn`
11 |         2. Start the app: `gunicorn -b 127.0.0.1:3009 app:app`
12 | 
13 | Note that the model weights are downloaded from GitHub. If your server cannot access GitHub, download the weights manually and put them in the `data` folder.
14 | 
15 | Currently built-in models:
16 | + starganv2_afhq.ckpt: https://github.com/songquanpeng/pytorch-deployment/releases/download/v0.1.0/starganv2_afhq.ckpt
17 | 
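18 | ## API usage
19 | The web frontend drives a single endpoint, `POST /api/model`. Below is a minimal client sketch: the form fields mirror what `static/main.js` sends and `app.py` reads, while the host/port and the input filename are assumptions based on the `flask run -p 3000` setup above. `y` selects one of the model's three AFHQ domains, and `data` holds the generated image paths served by the `/cache/<filename>` route.
20 | ```python
21 | import requests
22 | 
23 | # Hypothetical local setup: server started with `flask run -p 3000`,
24 | # and cat.jpg is any RGB test image in the current directory.
25 | with open("cat.jpg", "rb") as f:
26 |     res = requests.post(
27 |         "http://127.0.0.1:3000/api/model",
28 |         data={"model": "starganv2_afhq", "mode": "latent", "y": 0, "seed": 42},
29 |         files={"src_img": f},
30 |     )
31 | data = res.json()
32 | if data["success"]:
33 |     print("Generated image:", "/" + data["data"][0])  # e.g. /cache/<uuid>.png
34 | else:
35 |     print("Error:", data["message"])
36 | ```
37 | To expose a new model, `app.py` dispatches on the `model` form field, so the extension path appears to be: add an entry to `config.json`'s `models`, create a matching `templates/<model_id>.html` page, and add a branch in `model_inference()`.
38 | 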
39 | 
40 | ## Features
41 | + [x] API.
42 | + [x] Web frontend.
43 | + [x] Support image translation models.
44 | + [ ] Support image classification models.
45 | + [ ] Support image generation models.
46 | 
47 | ## Demo
48 | ![starganv2_afhq](https://user-images.githubusercontent.com/39998050/155641683-fbef7d4a-7a44-4f60-bf96-7df79a02c0ee.gif)

--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, render_template, send_from_directory, request
2 | from munch import Munch
3 | 
4 | from models import StarGANv2
5 | from utils import load_cfg, cache_path
6 | 
7 | cfg = load_cfg()
8 | app = Flask(__name__)
9 | 
10 | 
11 | @app.context_processor
12 | def inject_config():
13 |     # Make cfg available in every rendered template.
14 |     return dict(cfg=cfg)
15 | 
16 | 
17 | @app.route('/')
18 | def index():
19 |     return render_template('index.html')
20 | 
21 | 
22 | @app.errorhandler(404)
23 | def page_not_found(e):
24 |     return render_template('404.html'), 404
25 | 
26 | 
27 | @app.route('/model/<model_id>', methods=['GET'])
28 | def model_page(model_id):
29 |     if model_id in cfg.models:
30 |         model = cfg.models[model_id]
31 |         return render_template(f'{model_id}.html', title=model['name'], description=model['description'])
32 |     else:
33 |         return render_template('index.html', message=f'No such model: {model_id}.', is_warning=True)
34 | 
35 | 
36 | @app.route('/api/model', methods=['POST'])
37 | def model_inference():
38 |     res = Munch({
39 |         "success": False,
40 |         "message": "default message",
41 |         "data": None
42 |     })
43 | 
44 |     try:
45 |         model_name = request.form['model']
46 |         if model_name == 'starganv2_afhq':
47 |             res = StarGANv2.controller(request)
48 |         else:
49 |             res.message = f"no such model: {model_name}"
50 |     except Exception as e:
51 |         res.message = str(e)
52 |         print(e)
53 |     return res
54 | 
55 | 
56 | @app.route('/cache/<filename>')
57 | def cached_image(filename):
58 |     # Serve generated images saved by utils.save_images().
59 |     return send_from_directory(cache_path, filename)
60 | 
61 | 
62 | @app.route('/api/<model_name>', methods=['POST'])
63 | def predict(model_name):
64 |     # Generic per-model endpoint; currently just echoes the model name.
65 |     return {
66 |         "success": True,
67 |         "message": model_name
68 |     }
69 | 
70 | 
71 | StarGANv2.init(cfg.device)
72 | 
73 | if __name__ == '__main__':
74 |     app.run(host='0.0.0.0', port=cfg.port)

--------------------------------------------------------------------------------
/config.json:
--------------------------------------------------------------------------------
1 | {
2 |   "port": 3000,
3 |   "device": "cpu",
4 |   "website": {
5 |     "name": "Demo",
6 |     "title": "Title",
7 |     "description": "Change this description by modifying config.json.",
8 |     "source_license": "MIT",
9 |     "license_link": "http://opensource.org/licenses/mit-license.php",
10 |     "source_link": "https://github.com/songquanpeng/pytorch-deployment",
11 |     "help_link": "https://github.com/songquanpeng/pytorch-deployment",
12 |     "owner_name": "JustSong",
13 |     "owner_link": "https://github.com/songquanpeng"
14 |   },
15 |   "models": {
16 |     "starganv2_afhq": {
17 |       "name": "Pretrained StarGANv2 model on AFHQ",
18 |       "description": "This model can be used to translate animal faces."
19 |     }
20 |   }
21 | }

--------------------------------------------------------------------------------
/models/StarGANv2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | from munch import Munch
4 | from torchvision import transforms
5 | 
6 | from models.StarGANv2_model import Generator, MappingNetwork, StyleEncoder
7 | from utils import load_weights, set_eval_mode, save_images, to_device
8 | 
9 | nets = Munch()
10 | transform_list = []
11 | args = Munch()
12 | 
13 | 
14 | def init(device='cpu'):
15 |     # Basic configuration
16 |     global args
17 |     args = Munch({
18 |         "img_size": 256,
19 |         "style_dim": 64,
20 |         "latent_dim": 16,
21 |         "num_domains": 3,
22 |         "w_hpf": 0,
23 |         "device": device
24 |     })
25 |     # Prepare image transform list
26 |     norm_mean = [0.5, 0.5, 0.5]
27 |     norm_std = [0.5, 0.5, 0.5]
28 |     global transform_list
29 |     transform_list = [
30 |         transforms.Resize([args.img_size, args.img_size]),
31 |         transforms.ToTensor(),
32 |         transforms.Normalize(mean=norm_mean, std=norm_std),
33 |     ]
34 | 
35 |     # Instantiate models
36 |     generator = Generator(args.img_size, args.style_dim, w_hpf=args.w_hpf)
37 |     mapping_network = MappingNetwork(args.latent_dim, args.style_dim, args.num_domains)
38 |     style_encoder = StyleEncoder(args.img_size, args.style_dim, args.num_domains)
39 |     global nets
40 |     nets = Munch(generator=generator,
41 |                  mapping_network=mapping_network,
42 |                  style_encoder=style_encoder)
43 | 
44 |     # Load parameters
45 |     weight_path = "starganv2_afhq.ckpt"
46 |     download_path = "https://github.com/songquanpeng/pytorch-deployment/releases/download/v0.1.0/starganv2_afhq.ckpt"
47 |     weight_dict = load_weights(weight_path, download_path, args.device)
48 | 
49 |     # Apply parameters to models
50 |     for name, module in nets.items():
51 |         module.load_state_dict(weight_dict[name])
52 | 
53 |     # Set models to eval mode
54 |     set_eval_mode(nets)
55 | 
56 |     # Move models to the target device
57 |     to_device(nets, device)
58 |     return nets
59 | 
60 | 
61 | def preprocess(img):
62 |     if img is None:
63 |         return None
64 |     assert len(transform_list) != 0, "call init() before preprocess()"
65 |     # TODO: let the frontend do image preprocessing
66 |     for transform in transform_list:
67 |         img = transform(img)
68 |     img = img.to(args.device)
69 |     img = torch.unsqueeze(img, dim=0)
70 |     return img
71 | 
72 | 
73 | @torch.no_grad()
74 | def inference(x_src, x_ref, y=0, seed=0, mode='latent'):
75 |     batch_size = x_src.shape[0]
76 |     y = torch.LongTensor(batch_size).to(args.device).fill_(y)
77 |     if mode == 'latent':
78 |         torch.manual_seed(seed)
79 |         z = torch.randn(batch_size, args.latent_dim).to(args.device)
80 |         s = nets.mapping_network(z, y)
81 |     elif mode == "reference":
82 |         s = nets.style_encoder(x_ref, y)
83 |     else:
84 |         assert False, f"No such mode: {mode}"
85 |     x_fake = nets.generator(x_src, s)
86 |     return x_fake
87 | 
88 | 
89 | def stargan_v2(x_src, x_ref, y=0, seed=0, mode='latent'):
90 |     """
91 |     StarGANv2 model
92 |     :param x_src: source image
93 |     :param x_ref: reference image (only used in reference mode)
94 |     :param y: target domain label (in reference mode, the reference image's label)
95 |     :param seed: random seed for latent mode
96 |     :param mode: available options: reference, latent
97 |     :return: a dict with keys: success, message, data
98 |     """
99 |     res = Munch({
100 |         "success": False,
101 |         "message": "default message",
102 |         "data": None
103 |     })
104 | 
105 |     if not nets:
106 |         res.message = "model not initialized"
107 |         return res.__dict__
108 |     if mode not in ['latent', 'reference']:
109 |         res.message = f"no such mode: {mode}"
110 |         return res.__dict__
111 |     x_src = preprocess(x_src)
112 |     x_ref = preprocess(x_ref)
113 |     fake_images = inference(x_src, x_ref, y, seed, mode)
114 |     filenames = save_images(fake_images)
115 |     res.data = filenames
116 |     res.success = True
117 |     return res.__dict__
118 | 
119 | 
120 | def controller(request):
121 |     mode = request.form['mode']
122 |     y = int(request.form['y'])
123 |     # Convert uploads to RGB so RGBA or grayscale images do not break the 3-channel normalization.
124 |     src_img = Image.open(request.files['src_img']).convert('RGB')
125 |     if mode == 'reference':
126 |         ref_img = Image.open(request.files['ref_img']).convert('RGB')
127 |         res = stargan_v2(src_img, ref_img, y=y, mode=mode)
128 |     else:
129 |         # The form field arrives as a string; torch.manual_seed() needs an int.
130 |         seed = int(request.form['seed'])
131 |         res = stargan_v2(src_img, x_ref=None, y=y, seed=seed, mode=mode)
132 |     return res
133 | 
134 | 
135 | if __name__ == '__main__':
136 |     def main():
137 |         init('cuda')
138 |         src_img_path = "./temp/cat.jpg"
139 |         ref_img_path = "./temp/dog.jpg"
140 |         y = 0
141 |         src_img = Image.open(src_img_path).convert('RGB')
142 |         ref_img = Image.open(ref_img_path).convert('RGB')
143 |         res = stargan_v2(src_img, ref_img, y, mode='reference')
144 |         print(res)
145 |         res = stargan_v2(src_img, ref_img, y, mode='latent')
146 |         print(res)
147 | 
148 | 
149 |     main()

--------------------------------------------------------------------------------
/models/StarGANv2_model.py:
--------------------------------------------------------------------------------
1 | """
2 | StarGAN v2
3 | Copyright (c) 2020-present NAVER Corp.
4 | 
5 | This work is licensed under the Creative Commons Attribution-NonCommercial
6 | 4.0 International License. To view a copy of this license, visit
7 | http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
8 | Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
9 | """
10 | 
11 | import math
12 | 
13 | import numpy as np
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | 
18 | 
19 | class ResBlk(nn.Module):
20 |     def __init__(self, dim_in, dim_out, actv=nn.LeakyReLU(0.2),
21 |                  normalize=False, downsample=False):
22 |         super().__init__()
23 |         self.actv = actv
24 |         self.normalize = normalize
25 |         self.downsample = downsample
26 |         self.learned_sc = dim_in != dim_out
27 |         self._build_weights(dim_in, dim_out)
28 | 
29 |     def _build_weights(self, dim_in, dim_out):
30 |         self.conv1 = nn.Conv2d(dim_in, dim_in, 3, 1, 1)
31 |         self.conv2 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
32 |         if self.normalize:
33 |             self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
34 |             self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
35 |         if self.learned_sc:
36 |             self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
37 | 
38 |     def _shortcut(self, x):
39 |         if self.learned_sc:
40 |             x = self.conv1x1(x)
41 |         if self.downsample:
42 |             x = F.avg_pool2d(x, 2)
43 |         return x
44 | 
45 |     def _residual(self, x):
46 |         if self.normalize:
47 |             x = self.norm1(x)
48 |         x = self.actv(x)
49 |         x = self.conv1(x)
50 |         if self.downsample:
51 |             x = F.avg_pool2d(x, 2)
52 |         if self.normalize:
53 |             x = self.norm2(x)
54 |         x = self.actv(x)
55 |         x = self.conv2(x)
56 |         return x
57 | 
58 |     def forward(self, x):
59 |         x = self._shortcut(x) + self._residual(x)
60 |         return x / math.sqrt(2)  # unit variance
61 | 
62 | 
63 | class AdaIN(nn.Module):
64 |     def __init__(self, style_dim, num_features):
65 |         super().__init__()
66 |         self.norm = nn.InstanceNorm2d(num_features, affine=False)
67 |         self.fc = nn.Linear(style_dim, num_features * 2)
68 | 
69 |     def forward(self, x, s):
70 |         h = self.fc(s)
71 |         h = h.view(h.size(0), h.size(1), 1, 1)
72 |         gamma, beta = torch.chunk(h, chunks=2, dim=1)
73 |         return (1 + gamma) * self.norm(x) + beta
74 | 
75 | 
76 | class AdainResBlk(nn.Module):
77 |     def __init__(self, dim_in, dim_out, style_dim=64, w_hpf=0,
78 |                  actv=nn.LeakyReLU(0.2), upsample=False):
79 |         super().__init__()
80 |         self.w_hpf = w_hpf
81 |         self.actv = actv
82 |         self.upsample = upsample
83 |         self.learned_sc = dim_in != dim_out
84 |         self._build_weights(dim_in, dim_out, style_dim)
85 | 
86 |     def _build_weights(self, dim_in, dim_out, style_dim=64):
87 |         self.conv1 = nn.Conv2d(dim_in, dim_out, 3, 1, 1)
88 |         self.conv2 = nn.Conv2d(dim_out, dim_out, 3, 1, 1)
89 |         self.norm1 = AdaIN(style_dim, dim_in)
90 |         self.norm2 = AdaIN(style_dim, dim_out)
91 |         if self.learned_sc:
92 |             self.conv1x1 = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
93 | 
94 |     def _shortcut(self, x):
95 |         if self.upsample:
96 |             x = F.interpolate(x, scale_factor=2, mode='nearest')
97 |         if self.learned_sc:
98 |             x = self.conv1x1(x)
99 |         return x
100 | 
101 |     def _residual(self, x, s):
102 |         x = self.norm1(x, s)
103 |         x = self.actv(x)
104 |         if self.upsample:
105 |             x = F.interpolate(x, scale_factor=2, mode='nearest')
106 |         x = self.conv1(x)
107 |         x = self.norm2(x, s)
108 |         x = self.actv(x)
109 |         x = self.conv2(x)
110 |         return x
111 | 
112 |     def forward(self, x, s):
113 |         out = self._residual(x, s)
114 |         if self.w_hpf == 0:
115 |             out = (out + self._shortcut(x)) / math.sqrt(2)
116 |         return out
117 | 
118 | 
119 | class HighPass(nn.Module):
120 |     def __init__(self, w_hpf, device):
121 |         super(HighPass, self).__init__()
122 |         self.filter = torch.tensor([[-1, -1, -1],
123 |                                     [-1, 8., -1],
124 |                                     [-1, -1, -1]]).to(device) / w_hpf
125 | 
126 |     def forward(self, x):
127 |         filter = self.filter.unsqueeze(0).unsqueeze(1).repeat(x.size(1), 1, 1, 1)
128 |         return F.conv2d(x, filter, padding=1, groups=x.size(1))
129 | 
130 | 
131 | class Generator(nn.Module):
132 |     def __init__(self, img_size=256, style_dim=64, max_conv_dim=512, w_hpf=1):
133 |         super().__init__()
134 |         dim_in = 2 ** 14 // img_size
135 |         self.img_size = img_size
136 |         self.from_rgb = nn.Conv2d(3, dim_in, 3, 1, 1)
137 |         self.encode = nn.ModuleList()
138 |         self.decode = nn.ModuleList()
139 |         self.to_rgb = nn.Sequential(
140 |             nn.InstanceNorm2d(dim_in, affine=True),
141 |             nn.LeakyReLU(0.2),
142 |             nn.Conv2d(dim_in, 3, 1, 1, 0))
143 | 
144 |         # down/up-sampling blocks
145 |         repeat_num = int(np.log2(img_size)) - 4
146 |         if w_hpf > 0:
147 |             repeat_num += 1
148 |         for _ in range(repeat_num):
149 |             dim_out = min(dim_in * 2, max_conv_dim)
150 |             self.encode.append(
151 |                 ResBlk(dim_in, dim_out, normalize=True, downsample=True))
152 |             self.decode.insert(
153 |                 0, AdainResBlk(dim_out, dim_in, style_dim,
154 |                                w_hpf=w_hpf, upsample=True))  # stack-like
155 |             dim_in = dim_out
156 | 
157 |         # bottleneck blocks
158 |         for _ in range(2):
159 |             self.encode.append(
160 |                 ResBlk(dim_out, dim_out, normalize=True))
161 |             self.decode.insert(
162 |                 0, AdainResBlk(dim_out, dim_out, style_dim, w_hpf=w_hpf))
163 | 
164 |         if w_hpf > 0:
165 |             device = torch.device(
166 |                 'cuda' if torch.cuda.is_available() else 'cpu')
167 |             self.hpf = HighPass(w_hpf, device)
168 | 
169 |     def forward(self, x, s, masks=None):
170 |         x = self.from_rgb(x)
171 |         cache = {}
172 |         for block in self.encode:
173 |             if (masks is not None) and (x.size(2) in [32, 64, 128]):
174 |                 cache[x.size(2)] = x
175 |             x = block(x)
176 |         for block in self.decode:
177 |             x = block(x, s)
178 |             if (masks is not None) and (x.size(2) in [32, 64, 128]):
179 |                 mask = masks[0] if x.size(2) in [32] else masks[1]
180 |                 mask = F.interpolate(mask, size=x.size(2), mode='bilinear')
181 |                 x = x + self.hpf(mask * cache[x.size(2)])
182 |         return self.to_rgb(x)
183 | 
184 | 
185 | class MappingNetwork(nn.Module):
186 |     def __init__(self, latent_dim=16, style_dim=64, num_domains=2):
187 |         super().__init__()
188 |         layers = []
189 |         layers += [nn.Linear(latent_dim, 512)]
190 |         layers += [nn.ReLU()]
191 |         for _ in range(3):
192 |             layers += [nn.Linear(512, 512)]
193 |             layers += [nn.ReLU()]
194 |         self.shared = nn.Sequential(*layers)
195 | 
196 |         self.unshared = nn.ModuleList()
197 |         for _ in range(num_domains):
198 |             self.unshared += [nn.Sequential(nn.Linear(512, 512),
199 |                                             nn.ReLU(),
200 |                                             nn.Linear(512, 512),
201 |                                             nn.ReLU(),
202 |                                             nn.Linear(512, 512),
203 |                                             nn.ReLU(),
204 |                                             nn.Linear(512, style_dim))]
205 | 
206 |     def forward(self, z, y):
207 |         h = self.shared(z)
208 |         out = []
209 |         for layer in self.unshared:
210 |             out += [layer(h)]
211 |         out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
212 |         idx = torch.LongTensor(range(y.size(0))).to(y.device)
213 |         s = out[idx, y]  # (batch, style_dim)
214 |         return s
215 | 
216 | 
217 | class StyleEncoder(nn.Module):
218 |     def __init__(self, img_size=256, style_dim=64, num_domains=2, max_conv_dim=512):
219 |         super().__init__()
220 |         dim_in = 2 ** 14 // img_size
221 |         blocks = []
222 |         blocks += [nn.Conv2d(3, dim_in, 3, 1, 1)]
223 | 
224 |         repeat_num = int(np.log2(img_size)) - 2
225 |         for _ in range(repeat_num):
226 |             dim_out = min(dim_in * 2, max_conv_dim)
227 |             blocks += [ResBlk(dim_in, dim_out, downsample=True)]
228 |             dim_in = dim_out
229 | 
230 |         blocks += [nn.LeakyReLU(0.2)]
231 |         blocks += [nn.Conv2d(dim_out, dim_out, 4, 1, 0)]
232 |         blocks += [nn.LeakyReLU(0.2)]
233 |         self.shared = nn.Sequential(*blocks)
234 | 
235 |         self.unshared = nn.ModuleList()
236 |         for _ in range(num_domains):
237 |             self.unshared += [nn.Linear(dim_out, style_dim)]
238 | 
239 |     def forward(self, x, y):
240 |         h = self.shared(x)
241 |         h = h.view(h.size(0), -1)
242 |         out = []
243 |         for layer in self.unshared:
244 |             out += [layer(h)]
245 |         out = torch.stack(out, dim=1)  # (batch, num_domains, style_dim)
246 |         idx = torch.LongTensor(range(y.size(0))).to(y.device)
247 |         s = out[idx, y]  # (batch, style_dim)
248 |         return s
249 | 

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | flask~=1.1.2
2 | jinja2~=3.0.1
3 | munch~=2.5.0
4 | torchvision~=0.8.1
5 | numpy~=1.18.5
6 | requests~=2.24.0
7 | pillow~=8.3.1
8 | werkzeug~=1.0.1

--------------------------------------------------------------------------------
/static/main.js:
--------------------------------------------------------------------------------
1 | // Credit: https://codepen.io/t7team/pen/ZowdRN
2 | function openTab(e, tabName) {
3 |     let i, x, tabLinks;
4 |     x = document.getElementsByClassName('content-tab');
5 |     for (i = 0; i < x.length; i++) {
6 |         x[i].style.display = 'none';
7 |     }
8 |     tabLinks = document.getElementsByClassName('tab');
9 |     for (i = 0; i < tabLinks.length; i++) {
10 |         tabLinks[i].className = tabLinks[i].className.replace(' is-active', '');
11 |     }
12 |     document.getElementById(tabName).style.display = 'block';
13 |     e.className += ' is-active';
14 | }
15 | 
16 | function submitSrcImage() {
17 |     let srcImage = document.getElementById("srcImage");
18 |     srcImage.click();
19 |     srcImage.onchange = function (e) {
20 |         let image = srcImage.files[0];
21 |         if (image) {
22 |             document.getElementById('showSrcImage').src = URL.createObjectURL(image);
23 |         }
24 |     }
25 | }
26 | 
27 | function submitRefImage() {
28 |     let refImage = document.getElementById("refImage");
29 |     refImage.click();
30 |     refImage.onchange = function (e) {
31 |         let image = refImage.files[0];
32 |         if (image) {
33 |             document.getElementById('showRefImage').src = URL.createObjectURL(image);
34 |         }
35 |     }
36 | }
37 | 
38 | async function StarGANv2Generate() {
39 |     let form = new FormData();
40 |     let y = document.getElementById('y').value;
41 |     let seed = document.getElementById('seed').value;
42 |     let isRef = document.getElementById("refRadio").checked;
43 |     let mode = "latent";
44 |     if (isRef) mode = "reference";
45 |     let src_img = document.getElementById("srcImage").files[0];
46 |     if (isRef) {
47 |         let ref_img = document.getElementById("refImage").files[0];
48 |         form.append("ref_img", ref_img, "ref_img");
49 |     }
50 |     form.append("model", "starganv2_afhq");
51 |     form.append("y", y);
52 |     form.append("seed", seed);
53 |     form.append("mode", mode);
54 |     form.append("src_img", src_img, "src_img");
55 |     let res = await fetch("/api/model", {
56 |         method: 'post',
57 |         body: form
58 |     });
59 |     let data = await res.json();
60 |     if (data.success) {
61 |         document.getElementById('showResImage').src = "/" + data.data[0];
62 |     } else {
63 |         showErrorMessage(data.message);
64 |     }
65 | }
66 | 
67 | function showErrorMessage(message, duration = 5) {
68 |     console.error(message);
69 |     let e = document.getElementById('errorMessage');
70 |     e.children[0].textContent = `Error: ${message}`;
71 |     e.style.display = 'block';
72 |     setTimeout(() => {
73 |         e.style.display = 'none';
74 |     }, duration * 1000);
75 | }

--------------------------------------------------------------------------------
/static/placeholder.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/songquanpeng/pytorch-deployment/cb55ed8de9ac68835dd638d01c1fa74d5d335c58/static/placeholder.png

--------------------------------------------------------------------------------
/static/style.css:
--------------------------------------------------------------------------------
1 | body {
2 |     line-height: 1.6;
3 |     margin: 0;
4 |     font-family: Verdana, Candara, Arial, Helvetica, Microsoft YaHei, sans-serif;
5 | }
6 | 
7 | nav {
8 |     margin-bottom: 16px;
9 | }
10 | 
11 | a {
12 |     text-decoration: none;
13 |     color: #368CCB;
14 | }
15 | 
16 | .wrapper {
17 |     max-width: 960px;
18 |     margin: 0 auto;
19 | }
20 | 
21 | #page-container {
22 |     position: relative;
23 |     min-height: 97vh;
24 | }
25 | 
26 | #content-wrap {
27 |     padding-bottom: 4rem;
28 | }
29 | 
30 | #footer {
31 |     height: 4rem;
32 | }
33 | 
34 | #footer a {
35 |     color: black;
36 | }
37 | 
38 | code {
39 |     font-family: Consolas, 'Courier New', monospace;
40 | }
41 | 
42 | .page-card-list {
43 |     margin: 8px 8px;
44 | }
45 | 
46 | .page-card-title {
47 |     font-size: x-large;
48 |     font-weight: 500;
49 |     color: #000000;
50 |     text-decoration: none;
51 | }
52 | 
53 | .page-card-text {
54 |     margin-top: 8px;
55 | }
56 | 
57 | .pagination {
58 |     margin: 16px 4px;
59 | }
60 | 
61 | .pagination a {
62 |     border: none;
63 |     overflow: hidden;
64 | }
65 | 
66 | .shadow {
67 |     box-shadow: 0 0.5em 1em -0.125em rgba(10,10,10,.1), 0 0 0 1px rgba(10,10,10,.02);
68 | }
69 | 
70 | .nav-shadow {
71 |     box-shadow: 0 2px 3px rgba(26,26,26,.1);
72 | }
73 | 
74 | .box .article {
75 |     overflow-wrap: break-word;
76 |     font-size: larger;
77 |     word-break: break-word;
78 |     line-height: 1.6;
79 |     padding: 16px;
80 |     margin-bottom: 16px;
81 |     background-color: #ffffff;
82 | }
83 | 
84 | img {
85 |     max-width: 100%;
86 |     max-height: 100%;
87 | }
88 | 
89 | .normal-container {
90 |     margin: auto;
91 |     max-width: 960px;
92 |     padding: 16px 16px;
93 |     overflow-wrap: break-word;
94 |     word-break: break-word;
95 |     line-height: 1.6;
96 |     font-size: larger;
97 | }
98 | 
99 | .narrow-container {
100 |     margin: auto;
101 |     max-width: 560px;
102 |     padding: 16px 16px;
103 |     overflow-wrap: break-word;
104 |     word-break: break-word;
105 |     line-height: 1.6;
106 |     font-size: larger;
107 | }
108 | 
109 | .article a {
110 |     color: #368CCB;
111 |     text-decoration: none;
112 | }
113 | 
114 | .article a:hover {
115 |     color: #368CCB;
116 |     text-decoration: none;
117 | }
118 | 
119 | .article h2,
120 | .article h3,
121 | .article h4,
122 | .article h5,
123 | .article h6 {
124 |     font-weight: 700;
125 |     line-height: 1.5;
126 |     margin: 20px 0 15px;
127 |     margin-block-start: 1em;
128 |     margin-block-end: 0.2em;
129 | }
130 | 
131 | .article h1 {
132 |     font-size: 1.7em
133 | }
134 | 
135 | .article h2 {
136 |     font-size: 1.6em
137 | }
138 | 
139 | .article h3 {
140 |     font-size: 1.45em
141 | }
142 | 
143 | .article h4 {
144 |     font-size: 1.25em;
145 | }
146 | 
147 | .article h5 {
148 |     font-size: 1.1em;
149 | }
150 | 
151 | .article h6 {
152 |     font-size: 1em;
153 |     font-weight: bold
154 | }
155 | 
156 | @media screen and (max-width: 960px) {
157 |     .article h1 {
158 |         font-size: 1.5em
159 |     }
160 | 
161 |     .article h2 {
162 |         font-size: 1.35em
163 |     }
164 | 
165 |     .article h3 {
166 |         font-size: 1.3em
167 |     }
168 | 
169 |     .article h4 {
170 |         font-size: 1.2em;
171 |     }
172 | }
173 | 
174 | .article p {
175 |     margin-top: 0;
176 |     margin-bottom: 1.25rem;
177 | }
178 | 
179 | .article table {
180 |     margin: auto;
181 |     border-collapse: collapse;
182 |     border-spacing: 0;
183 |     vertical-align: middle;
184 |     text-align: left;
185 |     min-width: 66%;
186 | }
187 | 
188 | .article table td,
189 | .article table th {
190 |     padding: 5px 8px;
191 |     border: 1px solid #bbb;
192 | }
193 | 
194 | .article blockquote {
195 |     margin-left: 0;
196 |     padding: 0 1em;
197 |     font-size: smaller;
198 |     border-left: 5px solid #ddd;
199 | }
200 | 
201 | .article pre {
202 |     overflow-x: auto;
203 |     padding: 0;
204 |     font-size: 16px;
205 |     margin-top: 12px;
206 |     margin-bottom: 12px;
207 | }
208 | 
209 | .article ol {
210 |     text-decoration: none;
211 |     padding-inline-start: 40px;
212 |     margin-bottom: 1.25rem;
213 | }
214 | 
215 | .article code {
216 |     color: #bc9458;
217 |     padding: .065em .4em;
218 | }
219 | 
220 | .article .copyright {
221 |     display: none;
222 | }
223 | 
224 | .info {
225 |     font-size: 14px;
226 |     line-height: 28px;
227 |     text-align: left;
228 |     color: #738292;
229 |     margin-bottom: 3em
230 | }
231 | 
232 | .info a {
233 |     text-decoration: none;
234 |     color: inherit;
235 | }
236 | 
237 | .links {
238 |     margin: 16px;
239 | }
240 | 
241 | span.line {
242 |     display: inline-block;
243 | }
244 | 
245 | .toc {
246 |     position: sticky;
247 |     top: 24px;
248 | }
249 | 
250 | .image.is-256x256 {
251 |     height: 256px;
252 |     width: 256px;
253 | }
254 | 
255 | .footer {
256 |     padding: 3rem 1.5rem;
257 | }

--------------------------------------------------------------------------------
/templates/404.html:
--------------------------------------------------------------------------------
1 | {% include './partials/header.html' %}
2 | 
3 | 
4 |
5 |
6 | Page Not Found 7 |
8 |
9 |
10 | 11 | {% include './partials/footer.html' %} -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | {% include './partials/header.html' %} 2 | 3 |
4 | {% include './partials/message.html' %} 5 |
6 | {% for mode_id, model in cfg.models.items() %} 7 |
8 |
9 |
10 |
11 | 13 |
14 | {{ model['description'] }} 15 |
16 |
17 |
18 |
19 |
20 | {% endfor %} 21 |
22 | 23 |
24 | 25 | {% include './partials/footer.html' %} -------------------------------------------------------------------------------- /templates/partials/footer.html: -------------------------------------------------------------------------------- 1 | 2 | 11 | 12 | 13 | 14 | -------------------------------------------------------------------------------- /templates/partials/header.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | {% if title %} 7 | {{ title }} 8 | {% else %} 9 | {{ cfg.website.title }} 10 | {% endif %} 11 | {% if description %} 12 | 13 | {% else %} 14 | 15 | {% endif %} 16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | {% include './partials/nav.html' %} 24 |
-------------------------------------------------------------------------------- /templates/partials/message.html: -------------------------------------------------------------------------------- 1 | {% if message %} 2 | {% if is_warning %} 3 |
4 | {% else %} 5 |
6 | {% endif %} 7 |
8 | {{ message }} 9 |
10 | 15 |
16 | {% endif %} -------------------------------------------------------------------------------- /templates/partials/nav.html: -------------------------------------------------------------------------------- 1 | 31 | -------------------------------------------------------------------------------- /templates/starganv2_afhq.html: -------------------------------------------------------------------------------- 1 | {% include './partials/header.html' %} 2 | 3 |
4 | {% include './partials/message.html' %} 5 |
6 |
{{ title }}
7 |
8 |
9 | {{ description }} 10 |
11 |
12 | 17 |
18 |
19 |
20 |
21 | 22 |
23 |
24 | 29 |
30 |
31 |
32 |
33 |
34 |
35 | 36 |
37 | 38 |
39 |
40 |
41 |
42 |
43 | 44 |
45 | 49 | 53 | 54 |
55 |
56 |
57 |
58 | 59 |
60 |
61 |
62 | 63 |
64 |
65 | src 66 |
67 | 68 | 70 |
71 |
72 |
73 |
74 |
75 | 76 |
77 |
78 | ref 79 |
80 | 81 | 83 |
84 |
85 |
86 |
87 |
88 | 89 |
90 |
91 | fake 92 |
93 | 95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 | 
103 | {% include './partials/footer.html' %}

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import uuid
4 | 
5 | import numpy as np
6 | import requests
7 | import torch
8 | from PIL import Image
9 | from munch import Munch
10 | 
11 | 
12 | def load_cfg(cfg_path="config.json"):
13 |     assert os.path.exists(cfg_path), "config.json is missing!"
14 |     with open(cfg_path, 'rb') as f:
15 |         cfg = json.load(f)
16 |     cfg = Munch(cfg)
17 |     return cfg
18 | 
19 | 
20 | cache_path = 'cache'
21 | data_path = './data'
22 | 
23 | 
24 | def load_weights(file_name, download_url, device):
25 |     # Return the checkpoint's state dicts, downloading the file on first use.
26 |     os.makedirs(data_path, exist_ok=True)
27 |     weight_path = os.path.join(data_path, file_name)
28 |     if not os.path.exists(weight_path):
29 |         print(f"Downloading from: {download_url}...")
30 |         res = requests.get(download_url)
31 |         with open(weight_path, 'wb') as f:
32 |             f.write(res.content)
33 |         print(f'File saved at: {weight_path}')
34 |     return torch.load(weight_path, map_location=torch.device(device))
35 | 
36 | 
37 | def set_eval_mode(nets):
38 |     for net in nets.values():
39 |         net.eval()
40 | 
41 | 
42 | def to_device(nets, device):
43 |     for net in nets.values():
44 |         net.to(device)
45 | 
46 | 
47 | def denormalize(x):
48 |     # Map generator outputs from [-1, 1] back to [0, 1].
49 |     out = (x + 1) / 2
50 |     return out.clamp_(0, 1)
51 | 
52 | 
53 | def get_image_name():
54 |     return f"{uuid.uuid4().hex}.png"
55 | 
56 | 
57 | def save_images(imgs):
58 |     # Expects a float tensor of shape (N, C, H, W) with values in [-1, 1].
59 |     imgs = denormalize(imgs)
60 |     imgs = imgs * 255
61 |     imgs = imgs.cpu().numpy().astype(np.uint8)
62 |     imgs = imgs.transpose((0, 2, 3, 1))
63 |     os.makedirs(cache_path, exist_ok=True)
64 |     filenames = []
65 |     for img in imgs:
66 |         img = Image.fromarray(img)
67 |         filename = os.path.join(cache_path, get_image_name()).replace("\\", "/")
68 |         img.save(filename)
69 |         filenames.append(filename)
70 |     return filenames
71 | 
--------------------------------------------------------------------------------
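
A quick sanity check of the `save_images` contract (a sketch; run it from the repo root so the relative `cache` directory is created alongside `app.py`, and note the random tensor merely stands in for real generator output):

```python
import torch

from utils import save_images

# save_images expects a float tensor of shape (N, C, H, W) with values in
# [-1, 1]; it denormalizes to [0, 1], converts to uint8 HWC, and writes one
# PNG per batch element into the cache directory.
fake = torch.rand(2, 3, 256, 256) * 2 - 1  # stand-in for generator output
paths = save_images(fake)
print(paths)  # e.g. ['cache/<uuid>.png', 'cache/<uuid>.png']
```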