├── models ├── styleganv2 │ ├── __init__.py │ ├── op │ │ ├── __init__.py │ │ ├── fused_bias_act.cpp │ │ ├── upfirdn2d.cpp │ │ ├── fused_bias_act_kernel.cu │ │ ├── fused_act.py │ │ ├── upfirdn2d.py │ │ └── upfirdn2d_kernel.cu │ └── modules.py ├── common │ ├── base │ │ ├── __init__.py │ │ ├── heads.py │ │ ├── initialization.py │ │ ├── modules.py │ │ └── model.py │ ├── unet.py │ ├── encoders │ │ ├── _utils.py │ │ └── resnetv2.py │ └── decoderv2.py ├── renderer.py ├── generator.py └── efficientnet.py ├── assets ├── icons │ ├── paper.png │ ├── supmat.png │ └── project.png └── scripts_scheme.png ├── docker ├── 10_nvidia.json ├── source.sh ├── build.sh ├── run.sh └── Dockerfile ├── inference_module ├── __init__.py ├── config.yaml ├── criterion.py ├── inferer.py └── runner.py ├── utils ├── smplx_models.py ├── demo.py ├── common.py ├── uv_renderer.py └── bbox.py ├── requirements.txt ├── LICENSE ├── infer_texture.py ├── sample_new_textures.py ├── render_person.py ├── README.md └── dataloaders └── inference_loader.py /models/styleganv2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/icons/paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dolorousrtur/style-people/HEAD/assets/icons/paper.png -------------------------------------------------------------------------------- /assets/icons/supmat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dolorousrtur/style-people/HEAD/assets/icons/supmat.png -------------------------------------------------------------------------------- /assets/icons/project.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dolorousrtur/style-people/HEAD/assets/icons/project.png -------------------------------------------------------------------------------- /assets/scripts_scheme.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Dolorousrtur/style-people/HEAD/assets/scripts_scheme.png -------------------------------------------------------------------------------- /models/common/base/__init__.py: -------------------------------------------------------------------------------- 1 | from .heads import ( 2 | SegmentationHeadv2 3 | ) 4 | from .model import SegmentationModel 5 | -------------------------------------------------------------------------------- /models/styleganv2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /docker/10_nvidia.json: -------------------------------------------------------------------------------- 1 | { 2 | "file_format_version" : "1.0.0", 3 | "ICD" : { 4 | "library_path" : "libEGL_nvidia.so.0" 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /inference_module/__init__.py: -------------------------------------------------------------------------------- 1 | import inference_module.criterion 2 | import inference_module.inferer 3 | import inference_module.runner 4 | 5 | __version__ = '1.0' -------------------------------------------------------------------------------- 
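A note on the `models/styleganv2/op` package listed above: it exposes two custom CUDA ops, `FusedLeakyReLU` and `upfirdn2d`, which are JIT-compiled with ninja/nvcc on first import (see `fused_act.py` and `upfirdn2d.py` further down) and accept only CUDA tensors. A minimal usage sketch, not part of the repository, assuming a CUDA-capable machine with the CUDA toolkit and `ninja` available:
```
import torch
from models.styleganv2.op import FusedLeakyReLU, upfirdn2d  # triggers the JIT build on first import

x = torch.randn(2, 8, 64, 64, device='cuda')

# Learnable per-channel bias + scaled leaky ReLU, fused into a single CUDA kernel.
act = FusedLeakyReLU(8).cuda()
y = act(x)                                            # same shape as x

# StyleGAN2-style blur via upfirdn2d: 4-tap separable kernel, padding chosen to keep 64x64.
k = torch.tensor([1., 3., 3., 1.], device='cuda')
k = (k[None, :] * k[:, None]) / (k.sum() ** 2)
blurred = upfirdn2d(x, k, up=1, down=1, pad=(2, 1))   # -> (2, 8, 64, 64)
```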
/docker/source.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | PORT=8087 3 | 4 | PARAMS="-p ${PORT}:${PORT} --net=host --ipc=host -u $(id -u ${USER}):$(id -g ${USER})" 5 | NAME="stylepeople" 6 | VOLUMES="-v /:/mounted" -------------------------------------------------------------------------------- /docker/build.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CURRENT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 4 | source ${CURRENT_DIR}/source.sh 5 | 6 | docker build -t $NAME -f ${CURRENT_DIR}/Dockerfile ${CURRENT_DIR}/.. -------------------------------------------------------------------------------- /docker/run.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CURRENT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 4 | source ${CURRENT_DIR}/source.sh 5 | 6 | CODE_DIR=/mounted/${CURRENT_DIR}/.. 7 | NV_GPU=$(nvidia-smi --query-gpu=uuid --format=csv,noheader | tr '\n' ',') nvidia-docker run -w=${CODE_DIR} -ti $PARAMS $VOLUMES $NAME $@ 8 | -------------------------------------------------------------------------------- /utils/smplx_models.py: -------------------------------------------------------------------------------- 1 | import smplx 2 | import os 3 | 4 | def build_smplx_model_dict(smplx_model_dir, device): 5 | gender2filename = dict(neutral='SMPLX_NEUTRAL.pkl', male='SMPLX_MALE.pkl', female='SMPLX_FEMALE.pkl') 6 | gender2path = {k:os.path.join(smplx_model_dir, v) for (k, v) in gender2filename.items()} 7 | gender2model = {k:smplx.body_models.SMPLX(v).to(device) for (k, v) in gender2path.items()} 8 | 9 | return gender2model -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipykernel==5.2.0 2 | jupyter==1.0.0 3 | matplotlib==3.2.1 4 | nbconvert==5.6.1 5 | numpy==1.18.2 6 | opencv-python==4.2.0.32 7 | pandas==1.0.3 8 | Pillow==7.0.0 9 | plotly==4.9.0 10 | scikit-image==0.16.2 11 | scikit-learn==0.22.2.post1 12 | scipy==1.4.1 13 | tqdm==4.46.0 14 | trimesh==3.6.16 15 | imageio==2.9.0 16 | ninja==1.10.0 17 | imgaug==0.4.0 18 | lpips==0.1.3 19 | moviepy==1.0.3 20 | munch==2.5.0 21 | yamlenv==0.7.1 22 | omegaconf==2.0.6 23 | smplx==0.1.26 24 | efficientnet-pytorch==0.7.0 25 | kornia 26 | 27 | git+https://github.com/karfly/nvdiffrast_compute-capability_6.0 28 | git+https://github.com/nghorbani/configer 29 | -------------------------------------------------------------------------------- /models/common/base/heads.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models.styleganv2.modules import EqualConv2d, EqualConv2dSame 4 | from .modules import Flatten, Activation 5 | 6 | 7 | class SegmentationHeadv2(nn.Sequential): 8 | 9 | def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1, same=False): 10 | ConvLayer = EqualConv2d if not same else EqualConv2dSame 11 | conv2d = ConvLayer(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) 12 | upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity() 13 | activation = Activation(activation) 14 | super().__init__(conv2d, upsampling, activation) 15 | 
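`utils/smplx_models.py` above wraps the three SMPL-X body models in a single gender-keyed dict. A short sketch (not from the repository) of how it is meant to be used, assuming the `SMPLX_{NEUTRAL,MALE,FEMALE}.pkl` files have been placed under `data/smplx/` as described in the README:
```
import torch
from utils.smplx_models import build_smplx_model_dict

# One smplx.SMPLX module per gender, moved to the requested device.
smplx_models = build_smplx_model_dict('data/smplx', device='cpu')

with torch.no_grad():
    out = smplx_models['neutral']()   # default (zero) pose, shape and expression
print(out.vertices.shape)             # (1, 10475, 3): the full SMPL-X template mesh
```
The inference code never consumes the full mesh directly: it first subsamples the vertices with the indices stored in `data/v_inds.npy` (see `dataloaders/inference_loader.py` and `utils/demo.py` below).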
-------------------------------------------------------------------------------- /models/styleganv2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /models/common/base/initialization.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def initialize_decoder(module): 5 | for m in module.modules(): 6 | 7 | if isinstance(m, nn.Conv2d): 8 | nn.init.kaiming_uniform_(m.weight, mode="fan_in", nonlinearity="relu") 9 | if m.bias is not None: 10 | nn.init.constant_(m.bias, 0) 11 | 12 | elif isinstance(m, nn.BatchNorm2d): 13 | nn.init.constant_(m.weight, 1) 14 | nn.init.constant_(m.bias, 0) 15 | 16 | elif isinstance(m, nn.Linear): 17 | nn.init.xavier_uniform_(m.weight) 18 | if m.bias is not None: 19 | nn.init.constant_(m.bias, 0) 20 | 21 | 22 | def initialize_head(module): 23 | for m in module.modules(): 24 | if isinstance(m, (nn.Linear, nn.Conv2d)): 25 | nn.init.xavier_uniform_(m.weight) 26 | if m.bias is not None: 27 | nn.init.constant_(m.bias, 0) 28 | -------------------------------------------------------------------------------- /models/styleganv2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2021 Samsung Electronics Co., Ltd. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 13 | all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 21 | THE SOFTWARE. 22 | -------------------------------------------------------------------------------- /models/common/base/modules.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | try: 4 | from inplace_abn import InPlaceABN 5 | except ImportError: 6 | InPlaceABN = None 7 | 8 | class Activation(nn.Module): 9 | 10 | def __init__(self, name, **params): 11 | 12 | super().__init__() 13 | 14 | if name is None or name == 'identity': 15 | self.activation = nn.Identity(**params) 16 | elif name == 'sigmoid': 17 | self.activation = nn.Sigmoid() 18 | elif name == 'softmax2d': 19 | self.activation = nn.Softmax(dim=1, **params) 20 | elif name == 'softmax': 21 | self.activation = nn.Softmax(**params) 22 | elif name == 'logsoftmax': 23 | self.activation = nn.LogSoftmax(**params) 24 | elif callable(name): 25 | self.activation = name(**params) 26 | else: 27 | raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/None; got {}'.format(name)) 28 | 29 | def forward(self, x): 30 | return self.activation(x) 31 | 32 | class Flatten(nn.Module): 33 | def forward(self, x): 34 | return x.view(x.shape[0], -1) 35 | -------------------------------------------------------------------------------- /models/common/unet.py: -------------------------------------------------------------------------------- 1 | from .base import SegmentationHeadv2 2 | from .base import SegmentationModel 3 | from .decoderv2 import UnetDecoderv2 4 | from .encoders import resnetv2 5 | 6 | class Unetv2(SegmentationModel): 7 | def __init__( 8 | self, 9 | in_channels, 10 | classes, 11 | activation=None, 12 | ngf = 64, 13 | same=False 14 | ): 15 | super().__init__() 16 | 17 | self.encoder = resnetv2.ResNetv2Encoder( 18 | in_channels=in_channels, 19 | ngf=ngf, 20 | same=same 21 | ) 22 | 23 | decoder_channels = (256, 128, 64, 32, 16) 24 | self.decoder = UnetDecoderv2( 25 | encoder_channels=self.encoder.out_channels, 26 | decoder_channels=decoder_channels, 27 | n_blocks=5, 28 | same=same 29 | ) 30 | 31 | self.segmentation_head = SegmentationHeadv2( 32 | in_channels=decoder_channels[-1], 33 | out_channels=classes, 34 | activation=activation, 35 | kernel_size=3, 36 | same=same 37 | ) 38 | 39 | self.classification_head = None 40 | 41 | # self.name = "u-{}".format(encoder_name) 42 | self.initialize() 43 | 44 | -------------------------------------------------------------------------------- 
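For orientation, a rough sketch (not part of the repository) of how `Unetv2` is instantiated downstream: `models/renderer.py` below builds it with 18 input channels (the 16-channel neural render concatenated with the 2-channel UV map) and 16 output channels. The fused StyleGAN2 ops inside require CUDA, and the output is expected to keep the input spatial resolution:
```
import torch
from models.common.unet import Unetv2

# Mirrors the backbone built inside Renderer (models/renderer.py).
net = Unetv2(in_channels=18, classes=16, ngf=64, same=False).cuda().eval()

# 256x256 matches render_h/render_w in inference_module/config.yaml.
x = torch.randn(1, 18, 256, 256, device='cuda')
with torch.no_grad():
    out = net(x)    # expected shape: (1, 16, 256, 256)
```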
/models/common/base/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . import initialization as init 4 | 5 | 6 | class SegmentationModel(torch.nn.Module): 7 | 8 | def initialize(self): 9 | init.initialize_decoder(self.decoder) 10 | init.initialize_head(self.segmentation_head) 11 | if self.classification_head is not None: 12 | init.initialize_head(self.classification_head) 13 | 14 | def forward(self, x): 15 | """Sequentially pass `x` trough model`s encoder, decoder and heads""" 16 | features = self.encoder(x) 17 | decoder_output = self.decoder(*features) 18 | 19 | masks = self.segmentation_head(decoder_output) 20 | 21 | if self.classification_head is not None: 22 | labels = self.classification_head(features[-1]) 23 | return masks, labels 24 | 25 | return masks 26 | 27 | def predict(self, x): 28 | """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()` 29 | 30 | Args: 31 | x: 4D torch tensor with shape (batch_size, channels, height, width) 32 | 33 | Return: 34 | prediction: 4D torch tensor with shape (batch_size, classes, height, width) 35 | 36 | """ 37 | if self.training: 38 | self.eval() 39 | 40 | with torch.no_grad(): 41 | x = self.forward(x) 42 | 43 | return x -------------------------------------------------------------------------------- /models/common/encoders/_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def patch_first_conv(model, in_channels): 6 | """Change first convolution layer input channels. 7 | In case: 8 | in_channels == 1 or in_channels == 2 -> reuse original weights 9 | in_channels > 3 -> make random kaiming normal initialization 10 | """ 11 | 12 | # get first conv 13 | for module in model.modules(): 14 | if isinstance(module, nn.Conv2d): 15 | break 16 | 17 | # change input channels for first conv 18 | module.in_channels = in_channels 19 | weight = module.weight.detach() 20 | reset = False 21 | 22 | if in_channels == 1: 23 | weight = weight.sum(1, keepdim=True) 24 | elif in_channels == 2: 25 | weight = weight[:, :2] * (3.0 / 2.0) 26 | else: 27 | reset = True 28 | weight = torch.Tensor( 29 | module.out_channels, 30 | module.in_channels // module.groups, 31 | *module.kernel_size 32 | ) 33 | 34 | module.weight = nn.parameter.Parameter(weight) 35 | if reset: 36 | module.reset_parameters() 37 | 38 | 39 | def replace_strides_with_dilation(module, dilation_rate): 40 | """Patch Conv2d modules replacing strides with dilation""" 41 | for mod in module.modules(): 42 | if isinstance(mod, nn.Conv2d): 43 | mod.stride = (1, 1) 44 | mod.dilation = (dilation_rate, dilation_rate) 45 | kh, kw = mod.kernel_size 46 | mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate) 47 | 48 | # Kostyl for EfficientNet 49 | if hasattr(mod, "static_padding"): 50 | mod.static_padding = nn.Identity() 51 | -------------------------------------------------------------------------------- /infer_texture.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import cv2 5 | import numpy as np 6 | 7 | from utils.common import tti 8 | from utils.demo import DemoInferer 9 | from omegaconf import OmegaConf 10 | 11 | 12 | def main(): 13 | parser = argparse.ArgumentParser(description='') 14 | parser.add_argument('--checkpoint_path', type=str, default='data/checkpoint/generative_model.pth', help='Path to generative 
model checkpoint') 15 | parser.add_argument('--config_path', type=str, default='inference_module/config.yaml') 16 | parser.add_argument('--smplx_model_dir', type=str, default='data/smplx/', help='Path to smplx models') 17 | parser.add_argument('--input_path', type=str, default='data/inference_samples/azure_02', help='Path to a directory that contains data samples') 18 | parser.add_argument('--texture_out_dir', type=str, default='data/textures/azure_02', help='Path to a directory to save fitted texture in') 19 | parser.add_argument('--n_rotimgs', type=int, default=8, help='Number of rotation steps to render textured model in') 20 | parser.add_argument('--imsize', type=int, default=1024, help='Resolution in which to render rotation steps') 21 | parser.add_argument('--device', type=str, default='cuda:0', help='Device to run inference process on') 22 | args = parser.parse_args() 23 | 24 | texture_save_dir = args.texture_out_dir 25 | os.makedirs(texture_save_dir, exist_ok=True) 26 | 27 | 28 | inferer = DemoInferer(args.checkpoint_path, args.smplx_model_dir, imsize=args.imsize, device=args.device) 29 | config = OmegaConf.load(args.config_path) 30 | ntexture = inferer.infer(config, args.input_path) 31 | 32 | texture_out_path = os.path.join(texture_save_dir, 'texture.pth') 33 | torch.save(ntexture.cpu(), texture_out_path) 34 | if args.n_rotimgs > 0: 35 | rot_images, _ = inferer.make_rotation_images(ntexture, args.n_rotimgs) 36 | 37 | for j, rgb in enumerate(rot_images): 38 | rgb = tti(rgb) 39 | rgb = (rgb * 255).astype(np.uint8) 40 | 41 | rgb_out_path = os.path.join(texture_save_dir, 'rotation_images', f"{j:04d}.png") 42 | os.makedirs(os.path.dirname(rgb_out_path), exist_ok=True) 43 | cv2.imwrite(rgb_out_path, rgb[..., ::-1]) 44 | 45 | 46 | if __name__ == '__main__': 47 | main() 48 | -------------------------------------------------------------------------------- /sample_new_textures.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from tqdm import tqdm 8 | 9 | from utils.common import tti, set_random_seed 10 | from utils.demo import DemoInferer 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser(description='') 14 | parser.add_argument('--textures_root', type=str, default='data/textures', help='Root directory to store textures in') 15 | parser.add_argument('--checkpoint_path', type=str, default='data/checkpoint/generative_model.pth', help='Path to generative model checkpoint') 16 | parser.add_argument('--smplx_model_dir', type=str, default='data/smplx', help='Path to smplx models') 17 | parser.add_argument('--texture_batch_name', type=str, help='An identifier for the current run of this script') 18 | parser.add_argument('--n_samples', type=int, default=10, help='Number of textures to sample') 19 | parser.add_argument('--n_rotimgs', type=int, default=8, help='Number of rotation steps to render textured model in') 20 | parser.add_argument('--imsize', type=int, default=1024, help='Resolution in which to render rotation steps') 21 | parser.add_argument('--seed', type=int, help='Rendom seed') 22 | parser.add_argument('--device', type=str, default='cuda:0', help='Device to run sampling process on') 23 | args = parser.parse_args() 24 | 25 | if args.seed is not None: 26 | set_random_seed(args.seed) 27 | 28 | inferer = DemoInferer(args.checkpoint_path, args.smplx_model_dir, imsize=args.imsize, device=args.device) 29 | batch_dir = 
os.path.join(args.textures_root, args.texture_batch_name) 30 | 31 | for i in tqdm(range(args.n_samples)): 32 | with torch.no_grad(): 33 | ntexture = inferer.sample_texture() 34 | 35 | texture_save_dir = os.path.join(batch_dir, f"{i:04}") 36 | os.makedirs(texture_save_dir, exist_ok=True) 37 | 38 | texture_out_path = os.path.join(texture_save_dir, 'texture.pth') 39 | torch.save(ntexture.cpu(), texture_out_path) 40 | 41 | if args.n_rotimgs > 0: 42 | rot_images, ltrb = inferer.make_rotation_images(ntexture, args.n_rotimgs) 43 | 44 | for j, rgb in enumerate(rot_images): 45 | rgb = tti(rgb) 46 | rgb = (rgb * 255).astype(np.uint8) 47 | 48 | rgb_out_path = os.path.join(texture_save_dir, 'rotation_images', f"{j:04d}.png") 49 | os.makedirs(os.path.dirname(rgb_out_path), exist_ok=True) 50 | cv2.imwrite(rgb_out_path, rgb[..., ::-1]) 51 | 52 | print(f"Stored {args.n_samples} random textures into {batch_dir}") 53 | -------------------------------------------------------------------------------- /render_person.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | 9 | from utils.common import tti 10 | from utils.demo import DemoInferer 11 | 12 | if __name__ == '__main__': 13 | parser = argparse.ArgumentParser(description='') 14 | parser.add_argument('--checkpoint_path', type=str, default='data/checkpoint/generative_model.pth', help='Path to generative model checkpoint') 15 | parser.add_argument('--smplx_model_dir', type=str, default='data/smplx/', help='Path to smplx models') 16 | parser.add_argument('--texture_path', type=str, help='Path to a .pth neural texture file') 17 | parser.add_argument('--smplx_dict_path', type=str, help='Path to a .pkl file with smplx parameters') 18 | parser.add_argument('--save_dir', type=str, help='Path to a directory to save generated images in') 19 | parser.add_argument('--n_rotimgs', type=int, default=8, help='Number of rotation steps to render textured model in') 20 | parser.add_argument('--imsize', type=int, default=1024, help='Resolution in which to render images (1024 recommended)') 21 | parser.add_argument('--device', type=str, default='cuda:0', help='Device to run images generation process on') 22 | args = parser.parse_args() 23 | 24 | os.makedirs(args.save_dir, exist_ok=True) 25 | 26 | inferer = DemoInferer(args.checkpoint_path, args.smplx_model_dir, imsize=args.imsize, device=args.device) 27 | ntexture = torch.load(args.texture_path).to(args.device) 28 | 29 | # vertices, K = inferer.load_smplx(args.smplx_dict_path) 30 | # vertices, K, ltrb = inferer.crop_vertices(vertices, K) 31 | # rgb = inferer.make_rgb(vertices, ntexture) 32 | # rgb = (tti(rgb) * 255).astype(np.uint8) 33 | 34 | # rgb_out_path = os.path.join(args.save_dir, f"rgb.png") 35 | # cv2.imwrite(rgb_out_path, rgb[..., ::-1]) 36 | 37 | rot_images, ltrb = inferer.make_rotation_images(ntexture, args.n_rotimgs, smplx_path=args.smplx_dict_path) 38 | 39 | 40 | for j, rgb in enumerate(rot_images): 41 | rgb = tti(rgb) 42 | rgb = (rgb * 255).astype(np.uint8) 43 | 44 | if j == 0: 45 | rgb_out_path = os.path.join(args.save_dir, f"rgb.png") 46 | os.makedirs(os.path.dirname(rgb_out_path), exist_ok=True) 47 | cv2.imwrite(rgb_out_path, rgb[..., ::-1]) 48 | 49 | rgb_out_path = os.path.join(args.save_dir, 'rotation_images', f"{j:04d}.png") 50 | os.makedirs(os.path.dirname(rgb_out_path), exist_ok=True) 51 | cv2.imwrite(rgb_out_path, rgb[..., ::-1]) 52 | 53 | ltrb = 
ltrb[0].cpu().numpy().tolist() 54 | with open(os.path.join(args.save_dir, f"ltrb.json"), 'w') as f: 55 | json.dump(ltrb, f) 56 | -------------------------------------------------------------------------------- /inference_module/config.yaml: -------------------------------------------------------------------------------- 1 | log: 2 | experiment_name: "default" 3 | 4 | log_freq: 999 5 | log_val_freq: 1 6 | log_vis_freq: 20 7 | log_n_samples: -1 8 | log_n_rotation_samples: 64 9 | 10 | inferer: 11 | faces_path: data/uv_renderer/face_tex.npy 12 | uv_vert_values_path: data/uv_renderer/uv.npy 13 | 14 | # smplx 15 | body_models_path: data 16 | n_mean_latent: 10000 17 | 18 | runner: 19 | init_with_encoder: true 20 | save_checkpoint: false 21 | 22 | white: false # specify this to generate images on white background 23 | 24 | texture_segm_path: data/texture_segm_256.pkl 25 | v_inds_path: data/v_inds.npy 26 | copy_hand_texture: false 27 | 28 | stages: 29 | stage_1: 30 | input_source: latent_and_noise 31 | n_iters: 100 32 | 33 | optimization_targets: 34 | - latent 35 | 36 | lr: 37 | latent: 0.01 38 | 39 | loss_weight: 40 | lpips: 1.0 41 | mse: 0.5 42 | face_lpips: 0.5 43 | encoder_latent_deviation: 0.1 44 | generator_params_deviation: 0.0 45 | ntexture_deviation: 0.0 46 | 47 | stage_2: 48 | input_source: latent_and_noise 49 | n_iters: 70 50 | 51 | optimization_targets: 52 | - generator 53 | 54 | lr: 55 | generator: 0.01 56 | 57 | loss_weight: 58 | lpips: 1.0 59 | mse: 0.5 60 | face_lpips: 1.0 61 | encoder_latent_deviation: 0.0 62 | generator_params_deviation: 1.0 63 | ntexture_deviation: 0.0 ## TODO 64 | 65 | stage_3: 66 | input_source: latent_and_noise 67 | n_iters: 30 68 | 69 | optimization_targets: 70 | - noise 71 | 72 | lr: 73 | noise: 0.1 74 | 75 | loss_weight: 76 | lpips: 1.0 77 | mse: 0.5 78 | face_lpips: 0.0 79 | encoder_latent_deviation: 0.0 80 | generator_params_deviation: 0.0 81 | ntexture_deviation: 0.0 82 | 83 | stage_4: 84 | input_source: ntexture 85 | n_iters: 200 86 | 87 | optimization_targets: 88 | - ntexture 89 | 90 | lr: 91 | ntexture: 0.15 92 | 93 | loss_weight: 94 | lpips: 1.0 95 | mse: 0.5 96 | face_lpips: 2.0 97 | encoder_latent_deviation: 0.0 98 | generator_params_deviation: 0.0 99 | ntexture_deviation: 0.1 100 | 101 | generator: 102 | experiment_dir: data 103 | checkpoint_name: generator.pth.tar 104 | 105 | image_h: 512 106 | image_w: 512 107 | 108 | render_h: 256 109 | render_w: 256 110 | 111 | divide_n_channels: 1 112 | 113 | encoder: 114 | experiment_dir: data 115 | checkpoint_name: encoder.pth 116 | 117 | bbox_scale: 1.2 118 | image_size: 512 119 | 120 | random_seed: 0 121 | device: "cuda:0" -------------------------------------------------------------------------------- /models/styleganv2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /models/styleganv2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | module_path = os.path.dirname(__file__) 9 | fused = load( 10 | 'fused', 11 | sources=[ 12 | os.path.join(module_path, 'fused_bias_act.cpp'), 13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class FusedLeakyReLUFunctionBackward(Function): 19 | @staticmethod 20 | def forward(ctx, grad_output, out, negative_slope, scale): 21 | ctx.save_for_backward(out) 22 | ctx.negative_slope = negative_slope 23 | ctx.scale = scale 24 | 25 | empty = grad_output.new_empty(0) 26 | 27 | grad_input = fused.fused_bias_act( 28 | grad_output, empty, out, 3, 1, negative_slope, scale 29 | ) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, grad_input.ndim)) 35 | 36 | grad_bias = grad_input.sum(dim).detach() 37 | 38 | return grad_input, grad_bias 39 | 40 | 
@staticmethod 41 | def backward(ctx, gradgrad_input, gradgrad_bias): 42 | out, = ctx.saved_tensors 43 | gradgrad_out = fused.fused_bias_act( 44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 45 | ) 46 | 47 | return gradgrad_out, None, None, None 48 | 49 | 50 | class FusedLeakyReLUFunction(Function): 51 | @staticmethod 52 | def forward(ctx, input, bias, negative_slope, scale): 53 | 54 | if input.dtype != bias.dtype: 55 | input = input.type(bias.dtype) 56 | empty = input.new_empty(0) 57 | 58 | 59 | 60 | # print('FusedLeakyReLUFunction::forward') 61 | # print('input', input.dtype) 62 | # print('bias', bias.dtype) 63 | # print('empty', empty.dtype) 64 | # print('-------\n') 65 | 66 | 67 | 68 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 69 | ctx.save_for_backward(out) 70 | ctx.negative_slope = negative_slope 71 | ctx.scale = scale 72 | 73 | return out 74 | 75 | @staticmethod 76 | def backward(ctx, grad_output): 77 | out, = ctx.saved_tensors 78 | 79 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 80 | grad_output, out, ctx.negative_slope, ctx.scale 81 | ) 82 | 83 | return grad_input, grad_bias, None, None 84 | 85 | 86 | class FusedLeakyReLU(nn.Module): 87 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 88 | super().__init__() 89 | 90 | self.bias = nn.Parameter(torch.zeros(channel)) 91 | self.negative_slope = negative_slope 92 | self.scale = scale 93 | 94 | def forward(self, input): 95 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 96 | 97 | 98 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 99 | # print('fused_leaky_relu') 100 | # print('input', input.dtype) 101 | # print('bias', bias.dtype) 102 | # print('-------\n') 103 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 104 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04 2 | 3 | ENV TZ=Europe/Moscow 4 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 5 | 6 | SHELL ["/bin/bash", "--login", "-c"] 7 | 8 | RUN apt-get update && apt-get install -y \ 9 | build-essential \ 10 | rsync \ 11 | curl \ 12 | wget \ 13 | htop \ 14 | git \ 15 | openssh-server \ 16 | nano \ 17 | cmake \ 18 | unzip \ 19 | zip \ 20 | python-opencv \ 21 | vim \ 22 | # ffmpeg \ 23 | tmux \ 24 | freeglut3-dev 25 | 26 | ENV PYTHONDONTWRITEBYTECODE=1 27 | ENV PYTHONUNBUFFERED=1 28 | 29 | # nvdiffrast setup 30 | RUN apt-get update && apt-get install -y \ 31 | pkg-config \ 32 | libglvnd0 \ 33 | libgl1 \ 34 | libglx0 \ 35 | libegl1 \ 36 | libgles2 \ 37 | libglvnd-dev \ 38 | libgl1-mesa-dev \ 39 | libegl1-mesa-dev \ 40 | libgles2-mesa-dev 41 | 42 | ENV LD_LIBRARY_PATH /usr/lib64:$LD_LIBRARY_PATH 43 | 44 | ENV NVIDIA_VISIBLE_DEVICES all 45 | ENV NVIDIA_DRIVER_CAPABILITIES compute,utility,graphics 46 | 47 | ENV PYOPENGL_PLATFORM egl 48 | 49 | COPY docker/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json 50 | 51 | ## glew installation from source 52 | RUN curl -L https://downloads.sourceforge.net/project/glew/glew/2.1.0/glew-2.1.0.tgz > /tmp/glew-2.1.0.tgz 53 | RUN mkdir -p /tmp && \ 54 | cd /tmp && tar zxf /tmp/glew-2.1.0.tgz && cd glew-2.1.0 && \ 55 | SYSTEM=linux-egl make && \ 56 | SYSTEM=linux-egl make install && \ 57 | rm -rf /tmp/glew-2.1.0.zip /tmp/glew-2.1.0 58 | 59 | # fixuid 60 | ARG 
USERNAME=user 61 | RUN apt-get update && apt-get install -y sudo curl && \ 62 | addgroup --gid 1000 $USERNAME && \ 63 | adduser --uid 1000 --gid 1000 --disabled-password --gecos '' $USERNAME && \ 64 | adduser $USERNAME sudo && \ 65 | echo '%sudo ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers && \ 66 | USER=$USERNAME && \ 67 | GROUP=$USERNAME && \ 68 | curl -SsL https://github.com/boxboat/fixuid/releases/download/v0.4/fixuid-0.4-linux-amd64.tar.gz | tar -C /usr/local/bin -xzf - && \ 69 | chown root:root /usr/local/bin/fixuid && \ 70 | chmod 4755 /usr/local/bin/fixuid && \ 71 | mkdir -p /etc/fixuid && \ 72 | printf "user: $USER\ngroup: $GROUP\n" > /etc/fixuid/config.yml 73 | USER $USERNAME:$USERNAME 74 | 75 | # miniconda 76 | WORKDIR /home/user 77 | ENV CONDA_AUTO_UPDATE_CONDA=false 78 | ENV PATH=/home/user/miniconda/bin:$PATH 79 | 80 | RUN wget --quiet https://repo.continuum.io/miniconda/Miniconda3-py37_4.8.3-Linux-x86_64.sh -O ~/miniconda.sh && \ 81 | chmod +x ~/miniconda.sh && \ 82 | ~/miniconda.sh -b -p ~/miniconda && \ 83 | rm ~/miniconda.sh && \ 84 | conda clean -ya 85 | 86 | # python libs 87 | RUN pip install --upgrade pip 88 | 89 | ## requirements 90 | COPY requirements.txt requirements.txt 91 | RUN pip --no-cache-dir install -r requirements.txt 92 | RUN conda install pytorch torchvision cudatoolkit=10.1 -c pytorch 93 | 94 | ## FAILS TO RUN DURING DOCKER BUILD, INSTALL INSIDE DOCKER 95 | #RUN pip install git+https://github.com/rmbashirov/minimal_pytorch_rasterizer 96 | 97 | 98 | # docker setup 99 | WORKDIR / 100 | ENTRYPOINT ["fixuid", "-q"] 101 | CMD ["fixuid", "-q", "bash"] 102 | -------------------------------------------------------------------------------- /models/common/decoderv2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.styleganv2.modules import EqualConv2d, EqualConv2dSame 6 | from models.styleganv2.op import FusedLeakyReLU 7 | 8 | 9 | class Decoderv2Block(nn.Module): 10 | def __init__( 11 | self, 12 | in_channels, 13 | skip_channels, 14 | out_channels, 15 | norm_layer=None, 16 | same=False 17 | ): 18 | super().__init__() 19 | 20 | if norm_layer is None: 21 | norm_layer = nn.BatchNorm2d 22 | 23 | ConvLayer = EqualConv2d if not same else EqualConv2dSame 24 | 25 | conv1 = ConvLayer( 26 | in_channels + skip_channels, 27 | out_channels, 28 | kernel_size=3, 29 | padding=1, 30 | bias=False 31 | ) 32 | norm1 = norm_layer(out_channels, affine=True) 33 | relu1 = FusedLeakyReLU(out_channels) 34 | self.conv1 = nn.Sequential(conv1, norm1, relu1) 35 | 36 | conv2 = ConvLayer( 37 | out_channels, 38 | out_channels, 39 | kernel_size=3, 40 | padding=1, 41 | bias=False 42 | ) 43 | norm2 = norm_layer(out_channels, affine=True) 44 | relu2 = FusedLeakyReLU(out_channels) 45 | self.conv2 = nn.Sequential(conv2, norm2, relu2) 46 | 47 | def forward(self, x, skip=None): 48 | x = F.interpolate(x, scale_factor=2, mode="nearest") 49 | if skip is not None: 50 | x = torch.cat([x, skip], dim=1) 51 | x = self.conv1(x) 52 | x = self.conv2(x) 53 | return x 54 | 55 | 56 | 57 | class UnetDecoderv2(nn.Module): 58 | def __init__( 59 | self, 60 | encoder_channels, 61 | decoder_channels, 62 | n_blocks=5, 63 | same=False 64 | ): 65 | super().__init__() 66 | 67 | if n_blocks != len(decoder_channels): 68 | raise ValueError( 69 | "Model depth is {}, but you provide 
`decoder_channels` for {} blocks.".format( 70 | n_blocks, len(decoder_channels) 71 | ) 72 | ) 73 | 74 | encoder_channels = encoder_channels[1:] # remove first skip with same spatial resolution 75 | encoder_channels = encoder_channels[::-1] # reverse channels to start from head of encoder 76 | 77 | # computing blocks input and output channels 78 | head_channels = encoder_channels[0] 79 | in_channels = [head_channels] + list(decoder_channels[:-1]) 80 | skip_channels = list(encoder_channels[1:]) + [0] 81 | out_channels = decoder_channels 82 | 83 | # combine decoder keyword arguments 84 | blocks = [ 85 | Decoderv2Block(in_ch, skip_ch, out_ch, same=same) 86 | for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels) 87 | ] 88 | self.blocks = nn.ModuleList(blocks) 89 | 90 | def forward(self, *features): 91 | 92 | features = features[1:] # remove first skip with same spatial resolution 93 | features = features[::-1] # reverse channels to start from head of encoder 94 | 95 | head = features[0] 96 | skips = features[1:] 97 | 98 | x = head 99 | for i, decoder_block in enumerate(self.blocks): 100 | skip = skips[i] if i < len(skips) else None 101 | x = decoder_block(x, skip) 102 | 103 | return x 104 | -------------------------------------------------------------------------------- /models/renderer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import utils.common 7 | from models.styleganv2.modules import EqualConv2d 8 | from models.styleganv2.op import FusedLeakyReLU 9 | from models.common.unet import Unetv2 10 | 11 | 12 | class Renderer(nn.Module): 13 | def __init__(self, in_channels=18, segm_channels=3, ngf=64, normalization='batch'): 14 | super().__init__() 15 | 16 | n_out = 16 17 | self.model = Unetv2(in_channels=in_channels, classes=n_out, ngf=ngf, same=False) 18 | norm_layer = nn.InstanceNorm2d if normalization == 'instance' else nn.BatchNorm2d 19 | 20 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear') 21 | 22 | upsample_layers = [] 23 | upsample_layers.append(EqualConv2d(n_out, n_out, 3, 1, 1, bias=False)) 24 | upsample_layers.append(norm_layer(n_out, affine=True)) 25 | upsample_layers.append(FusedLeakyReLU(n_out)) 26 | upsample_layers.append(EqualConv2d(n_out, n_out, 3, 1, 1, bias=False)) 27 | upsample_layers.append(norm_layer(n_out, affine=True)) 28 | upsample_layers.append(FusedLeakyReLU(n_out)) 29 | upsample_layers.append(EqualConv2d(n_out, n_out, 3, 1, 1, bias=False)) 30 | upsample_layers.append(norm_layer(n_out, affine=True)) 31 | upsample_layers.append(FusedLeakyReLU(n_out)) 32 | upsample_layers.append(EqualConv2d(n_out, n_out, 3, 1, 1, bias=False)) 33 | upsample_layers.append(norm_layer(n_out, affine=True)) 34 | upsample_layers.append(FusedLeakyReLU(n_out)) 35 | self.upsample_convs = nn.Sequential(*upsample_layers) 36 | 37 | self.rgb_head = nn.Sequential( 38 | norm_layer(n_out, affine=True), 39 | FusedLeakyReLU(n_out), 40 | EqualConv2d(n_out, n_out, 3, 1, 1, bias=False), 41 | norm_layer(n_out, affine=True), 42 | FusedLeakyReLU(n_out), 43 | EqualConv2d(n_out, n_out, 3, 1, 1, bias=False), 44 | norm_layer(n_out, affine=True), 45 | FusedLeakyReLU(n_out), 46 | EqualConv2d(n_out, 3, 3, 1, 1, bias=True), 47 | nn.Tanh()) 48 | 49 | self.segm_head = nn.Sequential( 50 | norm_layer(n_out, affine=True), 51 | FusedLeakyReLU(n_out), 52 | EqualConv2d(n_out, 3, 3, 1, 1, bias=True), 53 | nn.Sigmoid()) 54 | 55 | self.segm_channels = segm_channels 56 | 57 | self.last_iter 
= None 58 | 59 | def forward(self, data_dict): 60 | assert 'uv' in data_dict 61 | assert 'nrender' in data_dict 62 | 63 | uv = data_dict['uv'] 64 | nrender = data_dict['nrender'] 65 | uvmask = ((uv > -10).sum(dim=1, keepdim=True) > 0).float() 66 | 67 | inp = torch.cat([nrender, uv], dim=1) 68 | out_lr = self.model(inp) 69 | out_lr = self.upsample(out_lr) 70 | out = self.upsample_convs(out_lr) 71 | 72 | rgb = self.rgb_head(out) 73 | segm = self.segm_head(out) 74 | 75 | segm = segm[:, :self.segm_channels] 76 | segm_fg = segm[:, :1] 77 | 78 | segm_H = segm_fg.shape[2] 79 | mask_H = uvmask.shape[2] 80 | if segm_H != mask_H: 81 | uvmask = torch.nn.functional.interpolate(uvmask, size=(segm_H, segm_H)) 82 | 83 | segm_fg = (segm_fg + uvmask).clamp(0., 1.) 84 | 85 | if 'crop_mask' in data_dict: 86 | segm_fg = segm_fg * data_dict['crop_mask'] 87 | 88 | if 'background' in data_dict: 89 | background = data_dict['background'] 90 | rgb_segm = utils.common.to_sigm(rgb) * segm_fg + background * (1. - segm_fg) 91 | else: 92 | rgb_segm = utils.common.to_sigm(rgb) * segm_fg 93 | rgb_segm = utils.common.to_tanh(rgb_segm) 94 | 95 | out_dict = dict(fake_rgb=rgb_segm, fake_segm=segm_fg, r_input=inp) 96 | 97 | return out_dict 98 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # StylePeople 3 | 4 | ### [Project](https://dolorousrtur.github.io/style-people/)   [Paper](https://arxiv.org/pdf/2104.08363.pdf)   5 | 6 | This is a repository with inference code for the paper [**"StylePeople: A Generative Model of Fullbody Human Avatars"**](https://arxiv.org/pdf/2104.08363.pdf) (CVPR21). 7 | This code covers the part of the paper describing the generative neural textures model. For inference of video-based avatars, refer to [this repository](https://github.com/dolorousrtur/neural-textures). 8 | 9 | ## Getting started 10 | ### Data 11 | To use this repository, you first need to download model checkpoints and some auxiliary files. 12 | 13 | * Download the archive with data from [Google Drive](https://drive.google.com/file/d/1xfsCIy5Xn_fS9uqC23svB_tjiJw543ZF/view?usp=sharing) and unpack it into `StylePeople/data/`. It contains: 14 | * checkpoints for the generative model and the encoder network (`data/checkpoint`) 15 | * A few samples from the *AzurePeople* dataset to run the inference script on (`data/inference_samples`) 16 | * A sample of smplx parameters (`data/smplx_sample.pkl`) 17 | * Some auxiliary data (`data/uv_render` and `data/*.{yaml,pth,npy}`) 18 | * Download SMPL-X models (`SMPLX_{MALE,FEMALE,NEUTRAL}.pkl`) from the [SMPL-X project page](https://smpl-x.is.tue.mpg.de/) and move them to `data/smplx/` 19 | 20 | ### Docker 21 | The easiest way to set up an environment for this repository is to use the docker image. To build it, follow these steps: 22 | 1. Build the image with the following command: 23 | ``` 24 | bash docker/build.sh 25 | ``` 26 | 2. Start a container: 27 | ``` 28 | bash docker/run.sh 29 | ``` 30 | It mounts the root directory of the host system to `/mounted/` inside docker and sets the cloned repository path as the starting directory. 31 | 32 | 3. **Inside the container** install `minimal_pytorch_rasterizer`. (Unfortunately, docker fails to install it during image building.) 33 | ``` 34 | pip install git+https://github.com/rmbashirov/minimal_pytorch_rasterizer 35 | ``` 36 | 4.
*(Optional)* You can then commit changes to the image so that you don't need to install `minimal_pytorch_rasterizer` for every new container. See the [docker documentation](https://docs.docker.com/engine/reference/commandline/commit/). 37 | 38 | ## Usage 39 | This repository provides three scripts covering three scenarios, as shown in the scheme below: 41 |

42 | ![drawing](assets/scripts_scheme.png) 43 |

44 | 45 | Below are brief descriptions and examples of usage for these scripts. Please see their argparse help messages for more details. 46 | 47 | ### Sample textures from the generative model 48 | `python sample_new_textures.py` samples a number of neural textures from the generative model and saves them to disk. 49 | 50 | Example: 51 | ``` 52 | python sample_new_textures.py --n_samples=10 --texture_batch_name='my_run' 53 | ``` 54 | will sample 10 neural textures and save them to `data/textures/my_run` 55 | 56 | ### Infer a neural texture for a given set of images 57 | `infer_texture.py` fits a neural texture to a given set of data samples. See `data/inference_samples/` for examples of such samples. 58 | 59 | Example: 60 | ``` 61 | python infer_texture.py --input_path=data/inference_samples/azure_04 --texture_out_dir=data/textures/azure_04 62 | ``` 63 | will load all data samples from `data/inference_samples/azure_04` and save the inferred texture to `data/textures/azure_04` 64 | 65 | 66 | ### Render an image of a person with a given neural texture and smplx parameters 67 | `render_person.py` generates an image of a person with a given shape, pose, expression and neural texture. 68 | 69 | Example: 70 | ``` 71 | python render_person.py --texture_path=data/textures/my_run/0000/texture.pth --smplx_dict_path=data/smplx_sample.pkl --save_dir=data/my_person 72 | ``` 73 | will render a person with the neural texture from `data/textures/my_run/0000/texture.pth` and smplx parameters from `data/smplx_sample.pkl` and save the generated images to `data/my_person`. 74 | -------------------------------------------------------------------------------- /dataloaders/inference_loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | import numpy as np 5 | import cv2 6 | import torch 7 | import smplx 8 | from utils.common import json2kps, itt, to_tanh 9 | from utils.bbox import crop_resize_coords, get_ltrb_bbox, crop_resize_verts, crop_resize_image 10 | from utils.smplx_models import build_smplx_model_dict 11 | import pickle 12 | 13 | 14 | class InferenceDataset(): 15 | def __init__(self, samples_dir, image_size, v_inds, smplx_model_dir): 16 | self.samples_dir = samples_dir 17 | self.frame_list = self.list_samples(self.samples_dir) 18 | 19 | 20 | self.image_size = image_size 21 | self.input_size = image_size // 2 22 | 23 | 24 | self.v_inds = v_inds 25 | 26 | self.smplx_models_dict = build_smplx_model_dict(smplx_model_dir, device='cpu') 27 | 28 | @staticmethod 29 | def list_samples(samples_dir): 30 | files = os.listdir(samples_dir) 31 | frame_ids = [x.split('_')[0] for x in files] 32 | frame_ids = sorted(list(set(frame_ids))) 33 | 34 | return frame_ids 35 | 36 | def load_rgb(self, frame_id): 37 | rgb_path = os.path.join(self.samples_dir, f"{frame_id}_rgb.jpg") 38 | rgb = cv2.imread(rgb_path)[..., ::-1] / 255. 39 | return itt(rgb) 40 | 41 | def load_segm(self, frame_id): 42 | rgb_path = os.path.join(self.samples_dir, f"{frame_id}_segm.png") 43 | rgb = cv2.imread(rgb_path)[..., ::-1] / 255.
44 | return itt(rgb) 45 | 46 | def load_landmarks(self, frame_id): 47 | landmarks_path = os.path.join(self.samples_dir, f"{frame_id}_keypoints.json") 48 | with open(landmarks_path, 'r') as f: 49 | landmarks = json.load(f) 50 | landmarks = json2kps(landmarks) 51 | landmarks = torch.FloatTensor(landmarks).unsqueeze(0) 52 | return landmarks 53 | 54 | def load_smplx(self, frame_id): 55 | smplx_path = os.path.join(self.samples_dir, f"{frame_id}_smplx.pkl") 56 | with open(smplx_path, 'rb') as f: 57 | smpl_params = pickle.load(f) 58 | 59 | gender = smpl_params['gender'] 60 | for k, v in smpl_params.items(): 61 | if type(v) == np.ndarray: 62 | smpl_params[k] = torch.FloatTensor(v) 63 | 64 | smpl_params['left_hand_pose'] = smpl_params['left_hand_pose'][:, :6] 65 | smpl_params['right_hand_pose'] = smpl_params['right_hand_pose'][:, :6] 66 | 67 | with torch.no_grad(): 68 | smpl_output = self.smplx_models_dict[gender](**smpl_params) 69 | vertices = smpl_output.vertices.detach() 70 | vertices = vertices[:, self.v_inds] 71 | K = smpl_params['camera_intrinsics'].unsqueeze(0) 72 | vertices = torch.bmm(vertices, K.transpose(1, 2)) 73 | smpl_params.pop('camera_intrinsics') 74 | smpl_params['gender'] = [smpl_params['gender']] 75 | 76 | return vertices, K, smpl_params 77 | 78 | def __getitem__(self, item): 79 | frame_id = self.frame_list[item] 80 | 81 | rgb_orig = self.load_rgb(frame_id) 82 | segm_orig = self.load_segm(frame_id) 83 | landmarks_orig = self.load_landmarks(frame_id) 84 | verts_orig, K_orig, smpl_params = self.load_smplx(frame_id) 85 | 86 | ltrb = get_ltrb_bbox(verts_orig).float() 87 | vertices_crop, K_crop = crop_resize_verts(verts_orig, K_orig, ltrb, self.input_size) 88 | 89 | landmarks_crop = crop_resize_coords(landmarks_orig, ltrb, self.image_size)[0] 90 | rgb_crop = crop_resize_image(rgb_orig.unsqueeze(0), ltrb, self.image_size)[0] 91 | segm_crop = crop_resize_image(segm_orig.unsqueeze(0), ltrb, self.image_size)[0] 92 | 93 | rgb_crop = rgb_crop * segm_crop[:1] 94 | rgb_crop = to_tanh(rgb_crop) 95 | 96 | vertices_crop = vertices_crop[0] 97 | K_crop = K_crop[0] 98 | 99 | smpl_params = {k:v[0] for (k,v) in smpl_params.items()} 100 | data_dict = dict(real_rgb=rgb_crop, real_segm=segm_crop, landmarks=landmarks_crop, verts=vertices_crop, 101 | K=K_crop) 102 | data_dict.update(smpl_params) 103 | 104 | return data_dict 105 | 106 | def __len__(self): 107 | return len(self.frame_list) 108 | -------------------------------------------------------------------------------- /inference_module/criterion.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import lpips 5 | import kornia 6 | 7 | import models 8 | 9 | 10 | class MAECriterion(torch.nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def forward(self, pred, target): 15 | return torch.mean(torch.abs(pred - target)) 16 | 17 | 18 | class MSECriterion(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | 22 | def forward(self, pred, target): 23 | return torch.mean((pred - target) ** 2) 24 | 25 | 26 | class LPIPSCriterion(torch.nn.Module): 27 | def __init__(self, net='vgg'): 28 | super().__init__() 29 | 30 | self.lpips = lpips.LPIPS(net=net) 31 | 32 | def forward(self, pred, target, valid_mask=None): 33 | loss_not_reduced = self.lpips(pred.contiguous(), target.contiguous()) 34 | if valid_mask is not None: 35 | invalid_mask = torch.logical_not(valid_mask) 36 | loss_not_reduced[invalid_mask] = 0.0 37 | 38 | return loss_not_reduced.mean() 39 | 40 | 
41 | class SSIMCriterion(torch.nn.Module): 42 | def __init__(self, window_size=5, max_val=1, reduction='mean'): 43 | super().__init__() 44 | 45 | self.window_size = window_size 46 | self.max_val = max_val 47 | self.reduction = reduction 48 | 49 | def forward(self, pred, target): 50 | loss = kornia.losses.ssim( 51 | pred, 52 | target, 53 | window_size=self.window_size, 54 | max_val=self.max_val, 55 | reduction=self.reduction 56 | ).mean() 57 | 58 | return loss 59 | 60 | 61 | class UnaryDiscriminatorNonsaturatingCriterion(torch.nn.Module): 62 | def __init__(self, discriminator, input_shape=(256, 256)): 63 | super().__init__() 64 | 65 | self.discriminator = discriminator 66 | self.input_shape = tuple(input_shape) 67 | 68 | 69 | def forward(self, pred_image, pred_segm): 70 | input_dict = { 71 | 'main': { 72 | 'rgb': kornia.resize(pred_image, self.input_shape), 73 | 'segm': kornia.resize(pred_segm, self.input_shape) 74 | } 75 | } 76 | output_dict = self.discriminator(input_dict) 77 | 78 | fake_score = output_dict['d_score'] 79 | loss = models.styleganv2.modules.g_nonsaturating_loss(fake_score) 80 | 81 | return loss 82 | 83 | 84 | class UnaryDiscriminatorFeatureMatchingCriterion(torch.nn.Module): 85 | def __init__( 86 | self, 87 | discriminator, 88 | endpoints=['reduction_1', 'reduction_2', 'reduction_3', 'reduction_4', 'reduction_5', 'reduction_6', 'reduction_7', 'reduction_8'], 89 | input_shape=(256, 256) 90 | ): 91 | super().__init__() 92 | 93 | self.discriminator = discriminator 94 | self.endpoints = endpoints 95 | self.input_shape = tuple(input_shape) if input_shape else input_shape 96 | 97 | 98 | def forward(self, pred_dict, target_dict): 99 | pred_output_dict = self.discriminator.extract_endpoints({ 100 | 'main': { 101 | 'rgb': kornia.resize(pred_dict['rgb'], self.input_shape) if self.input_shape else pred_dict['rgb'], 102 | 'segm': kornia.resize(pred_dict['segm'], self.input_shape) if self.input_shape else pred_dict['segm'], 103 | 'landmarks': pred_dict['landmarks'] 104 | } 105 | }) 106 | 107 | target_output_dict = self.discriminator.extract_endpoints({ 108 | 'main': { 109 | 'rgb': kornia.resize(target_dict['rgb'], self.input_shape) if self.input_shape else target_dict['rgb'], 110 | 'segm': kornia.resize(target_dict['segm'], self.input_shape) if self.input_shape else target_dict['segm'], 111 | 'landmarks': target_dict['landmarks'] 112 | } 113 | }) 114 | 115 | loss = 0.0 116 | for endpoint in self.endpoints: 117 | loss += torch.abs(pred_output_dict['endpoints'][endpoint] - target_output_dict['endpoints'][endpoint]).mean() 118 | 119 | loss /= len(self.endpoints) 120 | 121 | return { 122 | 'loss': loss, 123 | 'pred_output_dict': pred_output_dict, 124 | 'target_output_dict': target_output_dict 125 | } -------------------------------------------------------------------------------- /utils/demo.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import smplx 4 | import torch 5 | import importlib 6 | from torch.utils.data import DataLoader 7 | from collections import defaultdict 8 | 9 | import models 10 | from models.generator import Generator 11 | from models.renderer import Renderer 12 | from utils.bbox import get_ltrb_bbox, crop_resize_verts 13 | from utils.common import get_rotation_matrix, rotate_verts, to_sigm 14 | from utils.common import dict2device, setup_environment 15 | from utils.uv_renderer import UVRenderer 16 | from utils.smplx_models import build_smplx_model_dict 17 | 18 | import os 19 | import 
collections 20 | 21 | import inference_module 22 | from dataloaders.inference_loader import InferenceDataset 23 | 24 | from inference_module.inferer import get_config 25 | 26 | 27 | def concat_all_samples(dataloader): 28 | datadicts = defaultdict(list) 29 | 30 | for data_dict in dataloader: 31 | for k, v in data_dict.items(): 32 | datadicts[k].append(v) 33 | 34 | dict_combined = {} 35 | for k, v in datadicts.items(): 36 | if type(v[0]) == torch.Tensor: 37 | dict_combined[k] = torch.cat(v, dim=0) 38 | elif type(v[0]) == list: 39 | dict_combined[k] = [x[0] for x in v] 40 | 41 | return dict_combined 42 | 43 | def load_models(checkpoint_path='data/checkpoints/generative_model.pth', device='cuda:0'): 44 | ainp_path = 'data/spectral_texture16.pth' 45 | ainp_scales = [64, 128, 256, 512] 46 | 47 | ainp_tensor = torch.load(ainp_path) 48 | generator = Generator(ainp_tensor=ainp_tensor, ainp_scales=ainp_scales).to(device) 49 | renderer = Renderer().to(device) 50 | 51 | checkpoint = torch.load(checkpoint_path) 52 | generator.load_state_dict(checkpoint['g']) 53 | renderer.load_state_dict(checkpoint['r']) 54 | 55 | generator.eval() 56 | renderer.eval() 57 | 58 | return generator, renderer 59 | 60 | 61 | 62 | 63 | class DemoInferer(): 64 | def __init__(self, checkpoint_path, smplx_model_dir, imsize=1024, config_path='data/config.yaml', device='cuda:0'): 65 | 66 | self.smplx_model_dir = smplx_model_dir 67 | self.smplx_models_dict = build_smplx_model_dict(smplx_model_dir, device) 68 | 69 | self.generator_config = get_config(config_path) 70 | self.generator, self.renderer = load_models(checkpoint_path, device) 71 | 72 | self.v_inds = torch.LongTensor(np.load('data/v_inds.npy')).to(device) 73 | 74 | 75 | self.image_size = imsize 76 | self.input_size = imsize // 2 # input resolution is twice as small as output 77 | 78 | self.uv_renderer = UVRenderer(self.input_size, self.input_size).to(device) 79 | 80 | self.device = device 81 | self.style_dim = 512 82 | 83 | def infer(self, config, input_path): 84 | # setup environment 85 | setup_environment(config.random_seed) 86 | dataset = InferenceDataset(input_path, self.image_size, self.v_inds, self.smplx_model_dir) 87 | dataloader = DataLoader(dataset) 88 | train_dict = concat_all_samples(dataloader) 89 | train_dict = dict2device(train_dict, self.device) 90 | 91 | print("Successfully loaded data") 92 | 93 | # load inferer 94 | inferer = inference_module.inferer.Inferer(self.generator, self.renderer, self.generator_config, config) 95 | inferer = inferer.to(self.device) 96 | inferer.eval() 97 | print("Successfully loaded inferer") 98 | 99 | # load runner 100 | runner = inference_module.runner.Runner(config, inferer, self.smplx_models_dict, self.image_size) 101 | print("Successfully loaded runner") 102 | 103 | # train loop 104 | ntexture = runner.run_epoch(train_dict) 105 | 106 | return ntexture 107 | 108 | def sample_texture(self): 109 | z_val = [models.styleganv2.modules.make_noise(1, self.style_dim, 1, self.device)] 110 | ntexture = self.generator(z_val) 111 | return ntexture 112 | 113 | def load_smplx(self, sample_path): 114 | with open(sample_path, 'rb') as f: 115 | smpl_params = pickle.load(f) 116 | 117 | gender = smpl_params['gender'] 118 | 119 | for k, v in smpl_params.items(): 120 | if type(v) == np.ndarray: 121 | if 'hand_pose' in k: 122 | v = v[:, :6] 123 | smpl_params[k] = torch.FloatTensor(v).to(self.device) 124 | 125 | smpl_output = self.smplx_models_dict[gender](**smpl_params) 126 | vertices = smpl_output.vertices 127 | vertices = vertices[:, self.v_inds] 
128 | K = smpl_params['camera_intrinsics'].unsqueeze(0) 129 | vertices = torch.bmm(vertices, K.transpose(1, 2)) 130 | return vertices, K 131 | 132 | def crop_vertices(self, vertices, K): 133 | ltrb = get_ltrb_bbox(vertices) 134 | vertices, K = crop_resize_verts(vertices, K, ltrb, self.input_size) 135 | return vertices, K, ltrb 136 | 137 | def make_rgb(self, vertices, ntexture): 138 | uv = self.uv_renderer(vertices, negbg=True) 139 | 140 | nrender = torch.nn.functional.grid_sample(ntexture, uv.permute(0, 2, 3, 1), align_corners=True) 141 | renderer_input = dict(uv=uv, nrender=nrender) 142 | 143 | with torch.no_grad(): 144 | renderer_output = self.renderer(renderer_input) 145 | 146 | fake_rgb = renderer_output['fake_rgb'] 147 | fake_segm = renderer_output['fake_segm'] 148 | fake_rgb = to_sigm(fake_rgb) * (fake_segm > 0.8) 149 | 150 | return fake_rgb 151 | 152 | def make_rotation_images(self, ntexture, n_rotimgs, smplx_path='data/smplx_sample.pkl'): 153 | vertices, K = self.load_smplx(smplx_path) 154 | vertices, K, ltrb = self.crop_vertices(vertices, K) 155 | 156 | K_inv = torch.inverse(K) 157 | 158 | rgb_frames = [] 159 | for j in range(n_rotimgs): 160 | angle = np.pi * 2 * j / n_rotimgs 161 | verts_rot, mean_point = rotate_verts(vertices, angle, K, K_inv, axis='y') 162 | rgb = self.make_rgb(verts_rot, ntexture) 163 | rgb_frames.append(rgb) 164 | 165 | return rgb_frames, ltrb 166 | -------------------------------------------------------------------------------- /models/styleganv2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | module_path = os.path.dirname(__file__) 8 | upfirdn2d_op = load( 9 | 'upfirdn2d', 10 | sources=[ 11 | os.path.join(module_path, 'upfirdn2d.cpp'), 12 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 13 | ], 14 | ) 15 | 16 | 17 | class UpFirDn2dBackward(Function): 18 | @staticmethod 19 | def forward( 20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 21 | ): 22 | 23 | up_x, up_y = up 24 | down_x, down_y = down 25 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 26 | 27 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 28 | 29 | grad_input = upfirdn2d_op.upfirdn2d( 30 | grad_output, 31 | grad_kernel, 32 | down_x, 33 | down_y, 34 | up_x, 35 | up_y, 36 | g_pad_x0, 37 | g_pad_x1, 38 | g_pad_y0, 39 | g_pad_y1, 40 | ) 41 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 42 | 43 | ctx.save_for_backward(kernel) 44 | 45 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 46 | 47 | ctx.up_x = up_x 48 | ctx.up_y = up_y 49 | ctx.down_x = down_x 50 | ctx.down_y = down_y 51 | ctx.pad_x0 = pad_x0 52 | ctx.pad_x1 = pad_x1 53 | ctx.pad_y0 = pad_y0 54 | ctx.pad_y1 = pad_y1 55 | ctx.in_size = in_size 56 | ctx.out_size = out_size 57 | 58 | return grad_input 59 | 60 | @staticmethod 61 | def backward(ctx, gradgrad_input): 62 | kernel, = ctx.saved_tensors 63 | 64 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 65 | 66 | gradgrad_out = upfirdn2d_op.upfirdn2d( 67 | gradgrad_input, 68 | kernel, 69 | ctx.up_x, 70 | ctx.up_y, 71 | ctx.down_x, 72 | ctx.down_y, 73 | ctx.pad_x0, 74 | ctx.pad_x1, 75 | ctx.pad_y0, 76 | ctx.pad_y1, 77 | ) 78 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 79 | gradgrad_out = gradgrad_out.view( 80 | ctx.in_size[0], ctx.in_size[1], 
ctx.out_size[0], ctx.out_size[1] 81 | ) 82 | 83 | return gradgrad_out, None, None, None, None, None, None, None, None 84 | 85 | 86 | class UpFirDn2d(Function): 87 | @staticmethod 88 | def forward(ctx, input, kernel, up, down, pad): 89 | if input.dtype != kernel.dtype: 90 | input = input.type(kernel.dtype) 91 | 92 | up_x, up_y = up 93 | down_x, down_y = down 94 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 95 | 96 | kernel_h, kernel_w = kernel.shape 97 | batch, channel, in_h, in_w = input.shape 98 | ctx.in_size = input.shape 99 | 100 | input = input.reshape(-1, in_h, in_w, 1) 101 | 102 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 103 | 104 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 105 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 106 | ctx.out_size = (out_h, out_w) 107 | 108 | ctx.up = (up_x, up_y) 109 | ctx.down = (down_x, down_y) 110 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 111 | 112 | g_pad_x0 = kernel_w - pad_x0 - 1 113 | g_pad_y0 = kernel_h - pad_y0 - 1 114 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 115 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 116 | 117 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 118 | 119 | # print('UpFirDn2d::forward') 120 | # print('input', input.dtype) 121 | # print('kernel', kernel.dtype) 122 | # print('-------\n') 123 | 124 | out = upfirdn2d_op.upfirdn2d( 125 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 126 | ) 127 | # out = out.view(major, out_h, out_w, minor) 128 | out = out.view(-1, channel, out_h, out_w) 129 | 130 | return out 131 | 132 | @staticmethod 133 | def backward(ctx, grad_output): 134 | kernel, grad_kernel = ctx.saved_tensors 135 | 136 | grad_input = UpFirDn2dBackward.apply( 137 | grad_output, 138 | kernel, 139 | grad_kernel, 140 | ctx.up, 141 | ctx.down, 142 | ctx.pad, 143 | ctx.g_pad, 144 | ctx.in_size, 145 | ctx.out_size, 146 | ) 147 | 148 | return grad_input, None, None, None, None 149 | 150 | 151 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, in_h, in_w, minor = input.shape 163 | kernel_h, kernel_w = kernel.shape 164 | 165 | out = input.view(-1, in_h, 1, in_w, 1, minor) 166 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 167 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 168 | 169 | out = F.pad( 170 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 171 | ) 172 | out = out[ 173 | :, 174 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 175 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 176 | :, 177 | ] 178 | 179 | out = out.permute(0, 3, 1, 2) 180 | out = out.reshape( 181 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 182 | ) 183 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 184 | out = F.conv2d(out, w) 185 | out = out.reshape( 186 | -1, 187 | minor, 188 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 189 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 190 | ) 191 | out = out.permute(0, 2, 3, 1) 192 | 193 | return out[:, ::down_y, ::down_x, :] 194 | 195 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 
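For reference on the upfirdn2d op defined in models/styleganv2/op/upfirdn2d.py above: it fuses upsample -> 2-D FIR filter -> downsample on (B, C, H, W) tensors. Below is a minimal sketch of a shape-preserving blur with the [1, 3, 3, 1] binomial kernel used throughout StyleGAN2; it assumes the CUDA extension compiles on import and a GPU is available, and the tensor shapes are illustrative only.

import torch
from models.styleganv2.op import upfirdn2d

# separable [1, 3, 3, 1] binomial kernel, normalised so the blur preserves brightness
k = torch.tensor([1., 3., 3., 1.], device='cuda')
k = k[None, :] * k[:, None]
k = k / k.sum()

x = torch.randn(2, 16, 64, 64, device='cuda')

# with up=1, down=1 and pad=(2, 1):
#   out_h = (64*1 + 2 + 1 - 4) // 1 + 1 = 64, so the spatial size is unchanged
y = upfirdn2d(x, k, up=1, down=1, pad=(2, 1))
assert y.shape == x.shape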
| import numpy as np 2 | import random 3 | import os 4 | import time 5 | import importlib 6 | 7 | import cv2 8 | from PIL import Image 9 | 10 | import math 11 | import pickle 12 | 13 | import torch 14 | from torch import distributed as dist 15 | from torch.utils.data.sampler import Sampler 16 | 17 | 18 | def load_module(module_type, module_name): 19 | m = importlib.import_module(f'{module_type}.{module_name}') 20 | return m 21 | 22 | 23 | def return_empty_dict_if_none(x): 24 | return {} if x is None else x 25 | 26 | 27 | def get_data_sampler(dataset, shuffle=False, is_distributed=False): 28 | if is_distributed: 29 | return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) 30 | 31 | if shuffle: 32 | return torch.utils.data.RandomSampler(dataset) 33 | else: 34 | return torch.utils.data.SequentialSampler(dataset) 35 | 36 | 37 | def dict2device(d, device, dtype=None): 38 | if isinstance(d, np.ndarray): 39 | d = torch.from_numpy(d) 40 | 41 | if torch.is_tensor(d): 42 | d = d.to(device) 43 | if dtype is not None: 44 | d = d.type(dtype) 45 | return d 46 | 47 | if isinstance(d, dict): 48 | for k, v in d.items(): 49 | d[k] = dict2device(v, device, dtype=dtype) 50 | 51 | return d 52 | 53 | 54 | def setup_environment(seed): 55 | # random 56 | random.seed(seed) 57 | 58 | # numpy 59 | np.random.seed(seed) 60 | 61 | # cv2 62 | cv2.setNumThreads(0) 63 | cv2.ocl.setUseOpenCL(False) 64 | 65 | # pytorch 66 | os.environ['OMP_NUM_THREADS'] = '1' 67 | torch.set_num_threads(1) 68 | torch.manual_seed(seed) 69 | torch.cuda.manual_seed_all(seed) 70 | torch.backends.cudnn.benchmark = True 71 | torch.backends.cudnn.enabled = True 72 | 73 | 74 | def squeeze_metrics(d): 75 | metrics = dict() 76 | for k, v in d.items(): 77 | if torch.is_tensor(v): 78 | metrics[k] = v.mean().item() 79 | elif isinstance(v, float): 80 | metrics[k] = v 81 | else: 82 | raise NotImplementedError("Unknown datatype for metric: {}".format(type(v))) 83 | 84 | return metrics 85 | 86 | 87 | def reduce_metrics(metrics): 88 | metrics_dict = dict() 89 | for k in metrics[0].keys(): 90 | metrics_dict[k] = np.mean([item[k] for item in metrics]) 91 | 92 | return metrics_dict 93 | 94 | 95 | def get_world_size(): 96 | if not dist.is_available(): 97 | return 1 98 | 99 | if not dist.is_initialized(): 100 | return 1 101 | 102 | return dist.get_world_size() 103 | 104 | 105 | def reduce_loss_dict(loss_dict): 106 | world_size = get_world_size() 107 | 108 | if world_size < 2: 109 | return loss_dict 110 | 111 | with torch.no_grad(): 112 | keys = [] 113 | losses = [] 114 | 115 | for k in sorted(loss_dict.keys()): 116 | keys.append(k) 117 | losses.append(loss_dict[k]) 118 | 119 | losses = torch.stack(losses, 0) 120 | dist.reduce(losses, dst=0) 121 | 122 | if dist.get_rank() == 0: 123 | losses /= world_size 124 | 125 | reduced_losses = {k: v for k, v in zip(keys, losses)} 126 | 127 | return reduced_losses 128 | 129 | 130 | def flatten_parameters(parameters): 131 | list_of_flat_parameters = [torch.flatten(p) for p in parameters] 132 | flat_parameters = torch.cat(list_of_flat_parameters).view(-1, 1) 133 | return flat_parameters 134 | 135 | 136 | def set_random_seed(seed): 137 | random.seed(seed) 138 | np.random.seed(seed) 139 | torch.manual_seed(seed) 140 | 141 | 142 | def itt(img): 143 | tensor = torch.FloatTensor(img) # 144 | if len(tensor.shape) == 3: 145 | tensor = tensor.permute(2, 0, 1) 146 | else: 147 | tensor = tensor.unsqueeze(0) 148 | return tensor 149 | 150 | 151 | def tti(tensor): 152 | tensor = tensor.detach().cpu() 153 | tensor = 
tensor[0].permute(1, 2, 0) 154 | image = tensor.numpy() 155 | if image.shape[-1] == 1: 156 | image = image[..., 0] 157 | return image 158 | 159 | 160 | def to_tanh(t): 161 | return t * 2 - 1. 162 | 163 | 164 | def to_sigm(t): 165 | return (t + 1) / 2 166 | 167 | 168 | def get_rotation_matrix(angle, axis='x'): 169 | if axis == 'x': 170 | return np.array([ 171 | [1, 0, 0], 172 | [0, np.cos(angle), -np.sin(angle)], 173 | [0, np.sin(angle), np.cos(angle)] 174 | ]) 175 | elif axis == 'y': 176 | return np.array([ 177 | [np.cos(angle), 0, -np.sin(angle)], 178 | [0, 1, 0], 179 | [np.sin(angle), 0, np.cos(angle)] 180 | ]) 181 | elif axis == 'z': 182 | return np.array([ 183 | [np.cos(angle), -np.sin(angle), 0], 184 | [np.sin(angle), np.cos(angle), 0], 185 | [0, 0, 1] 186 | ]) 187 | else: 188 | raise ValueError(f"Unkown axis {axis}") 189 | 190 | 191 | def rotate_verts(vertices, angle, K, K_inv, axis='y', mean_point=None): 192 | rot_mat = get_rotation_matrix(angle, axis) 193 | rot_mat = torch.FloatTensor(rot_mat).to(vertices.device).unsqueeze(0) 194 | 195 | vertices_world = torch.bmm(vertices, K_inv.transpose(1, 2)) 196 | if mean_point is None: 197 | mean_point = vertices_world.mean(dim=1) 198 | 199 | vertices_rot = vertices_world - mean_point 200 | vertices_rot = torch.bmm(vertices_rot, rot_mat.transpose(1, 2)) 201 | vertices_rot = vertices_rot + mean_point 202 | vertices_rot_cam = torch.bmm(vertices_rot, K.transpose(1, 2)) 203 | 204 | return vertices_rot_cam, mean_point 205 | 206 | 207 | def json2kps(openpose_dict): 208 | list2kps = lambda x: np.array(x).reshape(-1, 3) 209 | keys_to_save = ['pose_keypoints_2d', 'face_keypoints_2d', 'hand_right_keypoints_2d', 'hand_left_keypoints_2d'] 210 | 211 | kps = openpose_dict['people'] 212 | if len(kps) == 0: 213 | kp_stacked = np.ones((137, 2)) * -1 214 | return kp_stacked 215 | kps = kps[0] 216 | kp_parts = [list2kps(kps[key]) for key in keys_to_save] 217 | kp_stacked = np.concatenate(kp_parts, axis=0) 218 | kp_stacked[kp_stacked[:, 2] < 0.1, :] = -1 219 | kp_stacked = kp_stacked[:, :2] 220 | 221 | return kp_stacked 222 | 223 | 224 | def segment_img(img, segm): 225 | img = to_sigm(img) * segm 226 | img = to_tanh(img) 227 | return img 228 | 229 | 230 | def segm2mask(segm): 231 | segm = torch.sum(segm, dim=1, keepdims=True) # Bx3xHxW -> Bx1xHxW 232 | segm = (segm > 0.0).type(torch.float32) 233 | return segm -------------------------------------------------------------------------------- /utils/uv_renderer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import minimal_pytorch_rasterizer 3 | import torch 4 | from torch import nn 5 | 6 | import nvdiffrast.torch as dr 7 | 8 | 9 | class UVRenderer(torch.nn.Module): 10 | def __init__(self, H, W, faces_path='data/uv_renderer/face_tex.npy', 11 | vertice_values_path='data/uv_renderer/uv.npy'): 12 | super().__init__() 13 | faces_cpu = np.load(faces_path) 14 | uv_cpu = np.load(vertice_values_path) 15 | 16 | self.faces = torch.nn.Parameter(torch.tensor(faces_cpu, dtype=torch.int32).contiguous(), requires_grad=False) 17 | self.vertice_values = torch.nn.Parameter(torch.tensor(uv_cpu, dtype=torch.float32).contiguous(), 18 | requires_grad=False) 19 | 20 | self.pinhole = minimal_pytorch_rasterizer.Pinhole2D( 21 | fx=1, fy=1, 22 | cx=0, cy=0, 23 | h=H, w=W 24 | ) 25 | 26 | def set_vertice_values(self, vertive_values): 27 | self.vertice_values = torch.nn.Parameter( 28 | torch.tensor(vertive_values, dtype=torch.float32).to(self.vertice_values.device), 
requires_grad=False) 29 | 30 | def forward(self, verts, norm=True, negbg=True, return_mask=False): 31 | N = verts.shape[0] 32 | 33 | uvs = [] 34 | for i in range(N): 35 | v = verts[i] 36 | uv = minimal_pytorch_rasterizer.project_mesh(v, self.faces, self.vertice_values, self.pinhole) 37 | uvs.append(uv) 38 | 39 | uvs = torch.stack(uvs, dim=0).permute(0, 3, 1, 2) 40 | mask = (uvs > 0).sum(dim=1, keepdim=True).float().clamp(0., 1.) 41 | 42 | if norm: 43 | uvs = (uvs * 2 - 1.) 44 | 45 | if negbg: 46 | uvs = uvs * mask - 10 * torch.logical_not(mask) 47 | 48 | if return_mask: 49 | return uvs, mask 50 | else: 51 | return uvs 52 | 53 | 54 | class NVDiffRastUVRenderer(torch.nn.Module): 55 | def __init__(self, faces_path='data/uv_renderer/face_tex.npy', 56 | uv_vert_values_path='data/uv_renderer/uv.npy'): 57 | super().__init__() 58 | 59 | self.glctx = dr.RasterizeGLContext() 60 | 61 | # load faces 62 | self.faces = nn.Parameter( 63 | torch.tensor(np.load(faces_path), dtype=torch.int32).contiguous(), 64 | requires_grad=False 65 | ) 66 | 67 | # load uv vert values 68 | self.uv_vert_values = nn.Parameter( 69 | torch.tensor(np.load(uv_vert_values_path), dtype=torch.float32).contiguous(), 70 | requires_grad=False 71 | ) 72 | 73 | def convert_to_ndc(self, verts, calibration_matrix, orig_w, orig_h, near=0.0001, far=10.0, invert_verts=True): 74 | device = verts.device 75 | 76 | # unproject verts 77 | if invert_verts: 78 | calibration_matrix_inv = torch.inverse(calibration_matrix) 79 | verts_3d = torch.bmm(verts, calibration_matrix_inv.transpose(1, 2)) 80 | else: 81 | verts_3d = verts 82 | 83 | # build ndc projection matrix 84 | matrix_ndc = [] 85 | for batch_i in range(calibration_matrix.shape[0]): 86 | fx, fy = calibration_matrix[batch_i, 0, 0], calibration_matrix[batch_i, 1, 1] 87 | cx, cy = calibration_matrix[batch_i, 0, 2], calibration_matrix[batch_i, 1, 2] 88 | 89 | matrix_ndc.append(torch.tensor([ 90 | [2*fx/orig_w, 0.0, (orig_w - 2*cx)/orig_w, 0.0], 91 | [0.0, -2*fy/orig_h, -(orig_h - 2*cy)/orig_h, 0.0], 92 | [0.0, 0.0, (-far - near) / (far - near), -2.0*far*near/(far-near)], 93 | [0.0, 0.0, -1.0, 0.0] 94 | ], device=device)) 95 | 96 | matrix_ndc = torch.stack(matrix_ndc, dim=0) 97 | 98 | # convert verts to verts ndc 99 | verts_3d_homo = torch.cat([verts_3d, torch.ones(*verts_3d.shape[:2], 1, device=device)], dim=-1) 100 | verts_3d_homo[:, :, 2] *= -1 # invert z-axis 101 | 102 | verts_ndc = torch.bmm(verts_3d_homo, matrix_ndc.transpose(1, 2)) 103 | 104 | return verts_ndc, matrix_ndc 105 | 106 | def render(self, verts_ndc, matrix_ndc, render_h=256, render_w=256): 107 | device = verts_ndc.device 108 | 109 | rast, rast_db = dr.rasterize(self.glctx, verts_ndc, self.faces, resolution=[render_h, render_w]) 110 | mask = (rast[:, :, :, 2] > 0.0).unsqueeze(-1).type(torch.float32) 111 | 112 | uv, uv_da = dr.interpolate(self.uv_vert_values, rast, self.faces, rast_db=rast_db, diff_attrs='all') 113 | 114 | # invert y-axis 115 | inv_idx = torch.arange(uv.shape[1] - 1, -1, -1).long().to(device) 116 | 117 | uv = uv.index_select(1, inv_idx) 118 | uv_da = uv_da.index_select(1, inv_idx) 119 | mask = mask.index_select(1, inv_idx) 120 | 121 | # make channel dim second 122 | uv = uv.permute(0, 3, 1, 2) 123 | uv_da = uv_da.permute(0, 3, 1, 2) 124 | mask = mask.permute(0, 3, 1, 2) 125 | rast = rast.permute(0, 3, 1, 2) 126 | 127 | # norm uv to [-1.0, 1.0] 128 | uv = 2 * uv - 1 129 | 130 | # set empty pixels to -10.0 131 | uv = uv * mask + (-10.0) * (1 - mask) 132 | 133 | return uv, uv_da, mask, rast 134 | 135 | def 
texture(self, texture, uv, mask=None, uv_da=None, mip=None, filter_mode='auto', boundary_mode='wrap', max_mip_level=None): 136 | texture = texture.permute(0, 2, 3, 1).contiguous() 137 | uv = (uv.permute(0, 2, 3, 1).contiguous() + 1) / 2 # norm to [0.0, 1.0] 138 | 139 | if uv_da is not None: 140 | uv_da = uv_da.permute(0, 2, 3, 1).contiguous() 141 | 142 | sampled_texture = dr.texture( 143 | texture, 144 | uv, 145 | uv_da=uv_da, 146 | mip=mip, 147 | filter_mode=filter_mode, 148 | boundary_mode=boundary_mode, 149 | max_mip_level=max_mip_level, 150 | ) 151 | 152 | sampled_texture = sampled_texture.permute(0, 3, 1, 2) 153 | 154 | if mask is not None: 155 | sampled_texture = sampled_texture * mask 156 | 157 | return sampled_texture 158 | 159 | def antialias(self, color, rast, verts_ndc, topology_hash=None, pos_gradient_boost=1.0): 160 | color = color.permute(0, 2, 3, 1).contiguous() 161 | rast = rast.permute(0, 2, 3, 1).contiguous() 162 | 163 | color = dr.antialias( 164 | color, 165 | rast, 166 | verts_ndc, 167 | self.faces, 168 | topology_hash=topology_hash, 169 | pos_gradient_boost=pos_gradient_boost 170 | ) 171 | 172 | color = color.permute(0, 3, 1, 2) 173 | 174 | return color -------------------------------------------------------------------------------- /models/generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | 5 | import torch 6 | from torch import nn 7 | 8 | from models.styleganv2.modules import ( 9 | PixelNorm, EqualLinear, ConstantInput, StyledConv, ToRGB, StyledConvAInp 10 | ) 11 | 12 | class Generator(nn.Module): 13 | def __init__( 14 | self, 15 | size=256, 16 | style_dim=512, 17 | n_mlp=8, 18 | channel_multiplier=4, 19 | output_channels=16, 20 | blur_kernel=[1, 3, 3, 1], 21 | lr_mlp=0.01, 22 | ainp_tensor=None, 23 | ainp_scales=None, 24 | ): 25 | super().__init__() 26 | 27 | if ainp_tensor is None or ainp_scales is None: 28 | ainp_tensor = None 29 | ainp_scales = [] 30 | 31 | self.size = size 32 | self.style_dim = style_dim 33 | self.output_channels = output_channels 34 | 35 | layers = [PixelNorm()] 36 | 37 | for i in range(n_mlp): 38 | layers.append( 39 | EqualLinear( 40 | style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu' 41 | ) 42 | ) 43 | 44 | self.style = nn.Sequential(*layers) 45 | 46 | self.channels = { 47 | 4: min(128 * channel_multiplier, 512), 48 | 8: min(128 * channel_multiplier, 512), 49 | 16: min(128 * channel_multiplier, 512), 50 | 32: min(128 * channel_multiplier, 512), 51 | 64: 128 * channel_multiplier, 52 | 128: 64 * channel_multiplier, 53 | 256: 32 * channel_multiplier, 54 | 512: 16 * channel_multiplier, 55 | 1024: 8 * channel_multiplier, 56 | } 57 | 58 | self.input = ConstantInput(self.channels[4]) 59 | self.conv1 = StyledConv( 60 | self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel 61 | ) 62 | self.to_rgb1 = ToRGB(self.channels[4], style_dim, out_channel=output_channels, upsample=False) 63 | 64 | self.log_size = int(math.log(size, 2)) 65 | self.num_layers = (self.log_size - 2) * 2 + 1 66 | 67 | self.convs = nn.ModuleList() 68 | self.upsamples = nn.ModuleList() 69 | self.to_rgbs = nn.ModuleList() 70 | self.noises = nn.Module() 71 | 72 | in_channel = self.channels[4] 73 | 74 | for layer_idx in range(self.num_layers): 75 | res = (layer_idx + 5) // 2 76 | shape = [1, 1, 2 ** res, 2 ** res] 77 | self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape)) 78 | 79 | for i in range(3, self.log_size + 1): 80 | scale = 2 ** i 81 | 
out_channel = self.channels[scale] 82 | 83 | self.convs.append( 84 | StyledConv( 85 | in_channel, 86 | out_channel, 87 | 3, 88 | style_dim, 89 | upsample=True, 90 | blur_kernel=blur_kernel, 91 | ) 92 | ) 93 | 94 | if scale in ainp_scales: 95 | ainp = torch.nn.functional.interpolate(ainp_tensor, size=(scale, scale), mode='bilinear') 96 | self.convs.append( 97 | StyledConvAInp( 98 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel, add_input=ainp 99 | ) 100 | ) 101 | else: 102 | self.convs.append( 103 | StyledConv( 104 | out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel 105 | ) 106 | ) 107 | 108 | self.to_rgbs.append(ToRGB(out_channel, style_dim, out_channel=output_channels)) 109 | 110 | in_channel = out_channel 111 | 112 | self.n_latent = self.log_size * 2 - 2 113 | 114 | def mean_latent(self, n_latent): 115 | latent_in = torch.randn( 116 | n_latent, self.style_dim, device=self.input.input.device 117 | ) 118 | latent = self.style(latent_in).mean(0, keepdim=True) 119 | 120 | return latent 121 | 122 | def make_noise(self): 123 | device = self.input.input.device 124 | 125 | noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)] 126 | 127 | for i in range(3, self.log_size + 1): 128 | for _ in range(2): 129 | noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device)) 130 | 131 | return noises 132 | 133 | def forward( 134 | self, 135 | styles, 136 | inject_index=None, 137 | truncation=1, 138 | truncation_latent=None, 139 | input_is_latent=False, 140 | noise=None, 141 | randomize_noise=True, 142 | ): 143 | if not input_is_latent: 144 | styles = [self.style(s) for s in styles] 145 | 146 | 147 | if noise is None: 148 | if randomize_noise: 149 | noise = [None] * self.num_layers 150 | else: 151 | noise = [ 152 | getattr(self.noises, f'noise_{i}') for i in range(self.num_layers) 153 | ] 154 | 155 | if truncation != 1: 156 | style_t = [] 157 | 158 | for style in styles: 159 | style_t.append( 160 | truncation_latent + truncation * (style - truncation_latent) 161 | ) 162 | 163 | styles = style_t 164 | 165 | 166 | if len(styles) < 2: 167 | inject_index = self.n_latent 168 | 169 | if styles[0].ndim < 3: 170 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 171 | 172 | else: 173 | latent = styles[0] 174 | 175 | elif len(styles) == 2: 176 | if inject_index is None: 177 | inject_index = random.randint(1, self.n_latent - 1) 178 | 179 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 180 | latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1) 181 | 182 | latent = torch.cat([latent, latent2], 1) 183 | elif len(styles) == (len(self.convs) + 2): 184 | latent = torch.stack(styles, 1) 185 | 186 | out = self.input(latent) 187 | out = self.conv1(out, latent[:, 0], noise=noise[0]) 188 | 189 | skip = self.to_rgb1(out, latent[:, 1]) 190 | 191 | i = 1 192 | for conv1, conv2, noise1, noise2, to_rgb in zip( 193 | self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs 194 | ): 195 | out = conv1(out, latent[:, i], noise=noise1) 196 | out = conv2(out, latent[:, i + 1], noise=noise2) 197 | skip = to_rgb(out, latent[:, i + 2], skip) 198 | i += 2 199 | 200 | ntexture = skip 201 | 202 | return ntexture 203 | 204 | 205 | -------------------------------------------------------------------------------- /inference_module/inferer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pydoc 4 | from omegaconf import OmegaConf 5 | import yaml 6 | import munch 7 
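A minimal sketch of sampling this Generator on its own, mirroring DemoInferer.sample_texture in utils/demo.py: with the default arguments it produces a 256x256, 16-channel neural texture, and a CUDA device is required because of the fused StyleGAN2 ops. Checkpoint loading (see load_models in utils/demo.py) is omitted here, so the weights are random and only the shapes are meaningful.

import torch
from models.generator import Generator

device = 'cuda:0'
generator = Generator().to(device).eval()   # size=256, style_dim=512, 16 output channels

# one style vector per sample; the mapping network (generator.style) is applied internally
z = [torch.randn(1, generator.style_dim, device=device)]

with torch.no_grad():
    ntexture = generator(z)                 # shape: (1, 16, 256, 256)

print(ntexture.shape)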
| 8 | import torch 9 | from torch import nn 10 | 11 | import utils 12 | 13 | from utils.common import dict2device, tti, itt, to_sigm, to_tanh 14 | from utils.common import get_rotation_matrix, segment_img, load_module 15 | from utils import uv_renderer as uvr 16 | from torch.utils.data import DataLoader 17 | from models.efficientnet import EfficientNetLevelEncoder 18 | 19 | 20 | 21 | def get_config(config_path, divide_n_channels=1): 22 | with open(config_path, 'r') as f: 23 | config_ = yaml.load(f, Loader=yaml.FullLoader) 24 | 25 | config = dict() 26 | for k, v in config_.items(): 27 | if type(v) == dict and 'value' in v: 28 | config[k] = v['value'] 29 | else: 30 | config[k] = v 31 | 32 | if 'fix_renderer' not in config: 33 | config['fix_renderer'] = False 34 | if 'renderer_checkpoint' not in config: 35 | config['renderer_checkpoint'] = "" 36 | if 'ainp_path' not in config: 37 | config['ainp_path'] = None 38 | config['ainp_scales'] = [] 39 | 40 | config = munch.munchify(config) 41 | 42 | # config hacks 43 | config.segm_channels = 1 44 | config.ainp_path='data/spectral_texture16.pth' 45 | config.checkpoint512_path = None 46 | 47 | config.bidis_channel_multiplier = int(config.bidis_channel_multiplier / divide_n_channels) 48 | config.udis_channel_multiplier = int(config.udis_channel_multiplier / divide_n_channels) 49 | 50 | if 'alternating' not in config: 51 | config.alternating = False 52 | 53 | return config 54 | 55 | class Inferer(nn.Module): 56 | def __init__(self, generator, renderer, generator_config, config): 57 | super().__init__() 58 | 59 | self.config = config 60 | self.generator_config = generator_config 61 | self.device = config.device 62 | 63 | # self.renderer = renderer.eval().to(self.device) 64 | self.renderer = renderer.eval().to(self.device) 65 | self.generator = generator.eval().to(self.device) 66 | 67 | self.uv_renderer = uvr.UVRenderer(self.generator_config.image_size, self.generator_config.image_size) 68 | self.uv_renderer.to(self.device) 69 | 70 | self.batch_size = 1 71 | self.n_rotsamples = 64 72 | # self.eval() 73 | 74 | # load encoder 75 | encoder_config_path = os.path.join(config.encoder.experiment_dir, "encoder_config.yaml") 76 | with open(encoder_config_path) as f: 77 | encoder_config = OmegaConf.load(f) 78 | self.encoder = EfficientNetLevelEncoder( 79 | **utils.common.return_empty_dict_if_none(encoder_config.model.encoder.args) 80 | ) 81 | state_dict = torch.load(os.path.join(config.encoder.experiment_dir, "checkpoint", config.encoder.checkpoint_name)) 82 | self.encoder.load_state_dict(state_dict['inferer']['encoder']) 83 | self.encoder.eval() 84 | for p in self.encoder.parameters(): 85 | p.requires_grad = False 86 | 87 | # trainable parameters 88 | ## latent 89 | latent_mean, _ = self.calc_latent_stats( 90 | n_mean_latent=config.inferer.n_mean_latent 91 | ) 92 | 93 | latent_init = latent_mean \ 94 | .detach() \ 95 | .clone() \ 96 | .view(1, 1, self.generator.style_dim) \ 97 | .repeat(1, self.generator.n_latent, 1) 98 | 99 | self.latent = nn.Parameter(latent_init, requires_grad=True) 100 | 101 | ## noise 102 | noise_init = self.generator.make_noise() 103 | self.noises = nn.ParameterList([nn.Parameter(x, requires_grad=True) for x in noise_init]) 104 | 105 | ## ntexture 106 | self.ntexture = None 107 | 108 | # diff renderer 109 | self.diff_uv_renderer = utils.uv_renderer.NVDiffRastUVRenderer(config.inferer.faces_path, config.inferer.uv_vert_values_path) 110 | 111 | 112 | def sample_ntexture(self, ntexture, uv): 113 | return torch.nn.functional.grid_sample(ntexture, 
uv.permute(0, 2, 3, 1), align_corners=True) 114 | 115 | def infer_pass(self, noise, verts, ntexture=None, uv=None, sampled_ntexture=None): 116 | 117 | if verts is not None: 118 | B = verts.shape[0] 119 | else: 120 | B = uv.shape[0] 121 | 122 | 123 | if uv is None and sampled_ntexture is None: 124 | uv = self.uv_renderer(verts) 125 | 126 | if type(noise) == torch.Tensor and noise.ndim == 3: 127 | noise = torch.split(noise, 1, dim=1) 128 | noise = [x[0] for x in noise] 129 | 130 | if ntexture is None and sampled_ntexture is None: 131 | fake_ntexture = self.generator(noise, noise=self.noises, input_is_latent=True) 132 | fake_ntexture = torch.cat([fake_ntexture]*B, dim=0) 133 | else: 134 | fake_ntexture = ntexture 135 | 136 | if sampled_ntexture is None: 137 | sampled_ntexture = self.sample_ntexture(fake_ntexture, uv) 138 | 139 | renderer_dict_out = self.renderer(dict(uv=uv, sampled_ntextures=sampled_ntexture, nrender=sampled_ntexture)) 140 | fake_img = renderer_dict_out['fake_rgb'] 141 | fake_segm = renderer_dict_out['fake_segm'] 142 | fake_img = segment_img(fake_img, fake_segm) 143 | 144 | out = dict( 145 | fake_img=fake_img, 146 | fake_segm=fake_segm, 147 | fake_ntexture=fake_ntexture, 148 | uv=uv, 149 | sampled_ntexture=sampled_ntexture 150 | ) 151 | 152 | return out 153 | 154 | def calc_latent_stats(self, n_mean_latent=10000): 155 | with torch.no_grad(): 156 | noise_sample = torch.randn(n_mean_latent, self.generator.style_dim).to(self.device) 157 | latent_out = self.generator.style(noise_sample) 158 | latent_mean = latent_out.mean(0) 159 | latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5 160 | 161 | return latent_mean, latent_std 162 | 163 | def make_noise_multiple(self, batch_size, latent_dim, n_noises): 164 | noises = torch.randn(n_noises, batch_size, latent_dim, device=self.device).unbind(0) 165 | return noises 166 | 167 | def get_state_dict(self): 168 | # collect generator parameters 169 | generator_params = [] 170 | generator_params.extend(self.generator.conv1.parameters()) 171 | for l in self.generator.convs: 172 | generator_params.extend(l.parameters()) 173 | generator_params.extend(self.generator.to_rgb1.parameters()) 174 | for l in self.generator.to_rgbs: 175 | generator_params.extend(l.parameters()) 176 | 177 | state_dict = { 178 | 'latent': self.latent, 179 | 'noise': self.noise, 180 | 'generator_params': generator_params, 181 | 'ntexture': self.ntexture 182 | } 183 | 184 | return state_dict -------------------------------------------------------------------------------- /models/efficientnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from efficientnet_pytorch import EfficientNet 5 | 6 | 7 | class EfficientNetLevelEncoder(nn.Module): 8 | def __init__( 9 | self, 10 | output_n_latents=14, 11 | output_style_dim=512, 12 | feature2latent_input_size=512, 13 | model_name='efficientnet-b7', 14 | pretrained=False 15 | ): 16 | """ 17 | Check valid model names here: https://github.com/lukemelas/EfficientNet-PyTorch/blob/761ac94cdbecffca2eecec8cc51ac99afce2025e/efficientnet_pytorch/model.py#L26 18 | """ 19 | super().__init__() 20 | 21 | self.output_n_latents = output_n_latents 22 | self.output_style_dim = output_style_dim 23 | self.feature2latent_input_size = feature2latent_input_size 24 | 25 | self.pretrained = pretrained 26 | 27 | if pretrained: 28 | self.backbone = EfficientNet.from_pretrained(model_name) 29 | else: 30 | self.backbone = 
EfficientNet.from_name(model_name) 31 | 32 | # feature2latent 33 | feature_sizes = self._model_name_to_feature_sizes(model_name) 34 | 35 | n_latents_left = output_n_latents 36 | feature2feature_layers = [] 37 | feature2latent_layers = [] 38 | for i, feature_size in enumerate(feature_sizes): 39 | if i == len(feature_sizes) - 1: 40 | output_size = n_latents_left * output_style_dim 41 | n_latents_left -= n_latents_left 42 | else: 43 | output_size = 2 * output_style_dim 44 | n_latents_left -= 2 45 | 46 | feature2feature_layers.append(nn.Sequential( 47 | nn.Conv2d(feature_size, feature2latent_input_size, 3, padding=1), 48 | nn.ReLU(inplace=True) 49 | )) 50 | 51 | feature2latent_layers.append(nn.Sequential( 52 | nn.AdaptiveAvgPool2d(1), 53 | nn.Flatten(), 54 | nn.Linear(feature2latent_input_size, output_size) 55 | )) 56 | 57 | self.feature2feature_layers = nn.ModuleList(feature2feature_layers) 58 | self.feature2latent_layers = nn.ModuleList(feature2latent_layers) 59 | 60 | # normalization 61 | self.register_buffer('imagenet_mean', torch.tensor([0.4914, 0.4822, 0.4465]).view(-1, 1, 1)) 62 | self.register_buffer('imagenet_std', torch.tensor([0.2023, 0.1994, 0.2010]).view(-1, 1, 1)) 63 | 64 | @staticmethod 65 | def _model_name_to_feature_sizes(model_name): 66 | return { 67 | 'efficientnet-b0': [16, 24, 40, 112, 1280], 68 | 'efficientnet-b1': [16, 24, 40, 112, 1280], 69 | 'efficientnet-b2': [16, 24, 48, 120, 1408], 70 | 'efficientnet-b3': [24, 32, 48, 136, 1536], 71 | 'efficientnet-b4': [24, 32, 56, 160, 1792], 72 | 'efficientnet-b5': [24, 40, 64, 176, 2048], 73 | 'efficientnet-b6': [32, 40, 72, 200, 2304], 74 | 'efficientnet-b7': [32, 48, 80, 224, 2560] 75 | }[model_name] 76 | 77 | def renormalize_image_imagenet(self, image): 78 | image = (image + 1.0) / 2 # [-1.0, 1.0] -> [0.0, 1.0] 79 | image = (image - self.imagenet_mean) / self.imagenet_std 80 | return image 81 | 82 | def forward(self, x): 83 | if self.pretrained: 84 | x = self.renormalize_image_imagenet(x) 85 | 86 | endpoints = self.backbone.extract_endpoints(x) 87 | endpoint_names = sorted(list(endpoints.keys())) 88 | 89 | cumulative_endpoint_feature = None 90 | latents = [] 91 | for i, endpoint_name in zip(reversed(range(len(endpoint_names))), endpoint_names[::-1]): 92 | endpoint_feature = endpoints[endpoint_name] 93 | feature = self.feature2feature_layers[i](endpoint_feature) 94 | 95 | if cumulative_endpoint_feature is None: 96 | cumulative_endpoint_feature = feature 97 | else: 98 | cumulative_endpoint_feature = feature + nn.functional.upsample(cumulative_endpoint_feature, 99 | scale_factor=2.0, mode='bilinear', 100 | align_corners=True) 101 | 102 | current_latent = self.feature2latent_layers[i](cumulative_endpoint_feature) 103 | 104 | latents.append(current_latent) 105 | 106 | latent = torch.cat(latents, dim=1) 107 | latent = latent.view(-1, self.output_n_latents, self.output_style_dim) 108 | 109 | return latent 110 | 111 | 112 | class EfficientNetEncoder(nn.Module): 113 | def __init__( 114 | self, 115 | output_n_latents=14, 116 | output_style_dim=512, 117 | model_name='efficientnet-b7', 118 | pretrained=False, 119 | dropout_rate=0.5 120 | ): 121 | """ 122 | Check valid model names here: https://github.com/lukemelas/EfficientNet-PyTorch/blob/761ac94cdbecffca2eecec8cc51ac99afce2025e/efficientnet_pytorch/model.py#L26 123 | """ 124 | super().__init__() 125 | 126 | self.output_n_latents = output_n_latents 127 | self.output_style_dim = output_style_dim 128 | 129 | self.pretrained = pretrained 130 | 131 | if pretrained: 132 | self.backbone = 
EfficientNet.from_pretrained(model_name) 133 | else: 134 | self.backbone = EfficientNet.from_name(model_name) 135 | 136 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 137 | 138 | feature_size = self._model_name_to_feature_size(model_name) 139 | self.head = nn.Sequential( 140 | nn.Dropout(dropout_rate), 141 | nn.Linear(feature_size, output_n_latents * output_style_dim) 142 | ) 143 | 144 | # normalization 145 | self.register_buffer('imagenet_mean', torch.tensor([0.4914, 0.4822, 0.4465]).view(-1, 1, 1)) 146 | self.register_buffer('imagenet_std', torch.tensor([0.2023, 0.1994, 0.2010]).view(-1, 1, 1)) 147 | 148 | @staticmethod 149 | def _model_name_to_feature_size(model_name): 150 | return { 151 | 'efficientnet-b0': 1280, 152 | 'efficientnet-b1': 1280, 153 | 'efficientnet-b2': 1408, 154 | 'efficientnet-b3': 1536, 155 | 'efficientnet-b4': 1792, 156 | 'efficientnet-b5': 2048, 157 | 'efficientnet-b6': 2304, 158 | 'efficientnet-b7': 2560 159 | }[model_name] 160 | 161 | def renormalize_image_imagenet(self, image): 162 | image = (image + 1.0) / 2 # [-1.0, 1.0] -> [0.0, 1.0] 163 | image = (image - self.imagenet_mean) / self.imagenet_std 164 | return image 165 | 166 | def forward(self, x): 167 | if self.pretrained: 168 | x = self.renormalize_image_imagenet(x) 169 | 170 | x = self.backbone.extract_features(x) 171 | x = self.avg_pool(x) 172 | 173 | x = x.flatten(start_dim=1) 174 | x = self.head(x) 175 | 176 | x = x.view(-1, self.output_n_latents, self.output_style_dim) 177 | 178 | return x 179 | -------------------------------------------------------------------------------- /models/common/encoders/resnetv2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models.styleganv2.modules import EqualConv2d, EqualConv2dSame 4 | from models.styleganv2.op import FusedLeakyReLU 5 | from . 
import _utils as utils 6 | 7 | 8 | class EncoderMixin: 9 | """Add encoder functionality such as: 10 | - output channels specification of feature tensors (produced by encoder) 11 | - patching first convolution for arbitrary input channels 12 | """ 13 | 14 | @property 15 | def out_channels(self): 16 | """Return channels dimensions for each tensor of forward output of encoder""" 17 | return self._out_channels[: self._depth + 1] 18 | 19 | def set_in_channels(self, in_channels): 20 | """Change first convolution chennels""" 21 | if in_channels == 3: 22 | return 23 | 24 | self._in_channels = in_channels 25 | if self._out_channels[0] == 3: 26 | self._out_channels = tuple([in_channels] + list(self._out_channels)[1:]) 27 | 28 | utils.patch_first_conv(model=self, in_channels=in_channels) 29 | 30 | def get_stages(self): 31 | """Method should be overridden in encoder""" 32 | raise NotImplementedError 33 | 34 | def make_dilated(self, stage_list, dilation_list): 35 | stages = self.get_stages() 36 | for stage_indx, dilation_rate in zip(stage_list, dilation_list): 37 | utils.replace_strides_with_dilation( 38 | module=stages[stage_indx], 39 | dilation_rate=dilation_rate, 40 | ) 41 | 42 | 43 | 44 | class BasicBlockv2(nn.Module): 45 | def __init__(self, inplanes, planes, stride=1, downsample=False, norm_layer=None, same=False): 46 | super().__init__() 47 | if norm_layer is None: 48 | norm_layer = nn.BatchNorm2d 49 | 50 | ConvLayer = EqualConv2d if not same else EqualConv2dSame 51 | 52 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 53 | self.conv1 = ConvLayer(inplanes, planes, 3, stride=stride, padding=1, bias=False) 54 | self.bn1 = norm_layer(planes, affine=True) 55 | self.relu = FusedLeakyReLU(planes) 56 | self.conv2 = ConvLayer(planes, planes, 3, padding=1, bias=False) 57 | self.bn2 = norm_layer(planes, affine=True) 58 | self.stride = stride 59 | 60 | if downsample: 61 | conv_down = ConvLayer(inplanes, planes, 1, stride=2, bias=False) 62 | norm_down = norm_layer(planes, affine=True) 63 | self.downsample = nn.Sequential(conv_down, norm_down) 64 | else: 65 | self.downsample = None 66 | 67 | def forward(self, x): 68 | identity = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | 77 | if self.downsample is not None: 78 | identity = self.downsample(x) 79 | 80 | out += identity 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | class ResNetv2Encoder(nn.Module, EncoderMixin): 86 | def __init__(self, in_channels, ngf=64, norm_layer=None, same=False): 87 | super().__init__() 88 | 89 | self._out_channels = [in_channels, ngf] 90 | self._out_channels += [ngf*2**i for i in range(4)] 91 | self._depth = 5 92 | 93 | if norm_layer is None: 94 | norm_layer = nn.BatchNorm2d 95 | 96 | ConvLayer = EqualConv2d if not same else EqualConv2dSame 97 | 98 | self.conv1 = ConvLayer(in_channels, ngf, kernel_size=7, stride=2, padding=3, bias=False) 99 | self.norm1 = norm_layer(ngf, affine=True) 100 | self.relu = FusedLeakyReLU(ngf) 101 | 102 | maxpool_layers = [] 103 | if same: 104 | maxpool_layers.append(nn.ReplicationPad2d(1)) 105 | else: 106 | maxpool_layers.append(nn.ZeroPad2d(1)) 107 | maxpool_layers.append(nn.MaxPool2d(kernel_size=3, stride=2)) 108 | self.maxpool = nn.Sequential(*maxpool_layers) 109 | 110 | block1_1 = BasicBlockv2(ngf, ngf, stride=1, norm_layer=norm_layer, same=same) 111 | block1_2 = BasicBlockv2(ngf, ngf, stride=1, norm_layer=norm_layer, same=same) 112 | self.layer1 = 
nn.Sequential(block1_1, block1_2) 113 | 114 | block2_1 = BasicBlockv2(ngf, ngf*2, stride=2, norm_layer=norm_layer, downsample=True, same=same) 115 | block2_2 = BasicBlockv2(ngf*2, ngf*2, stride=1, norm_layer=norm_layer, same=same) 116 | self.layer2 = nn.Sequential(block2_1, block2_2) 117 | 118 | block3_1 = BasicBlockv2(ngf*2, ngf*4, stride=2, norm_layer=norm_layer, downsample=True, same=same) 119 | block3_2 = BasicBlockv2(ngf*4, ngf*4, stride=1, norm_layer=norm_layer, same=same) 120 | self.layer3 = nn.Sequential(block3_1, block3_2) 121 | 122 | block4_1 = BasicBlockv2(ngf*4, ngf*8, stride=2, norm_layer=norm_layer, downsample=True, same=same) 123 | block4_2 = BasicBlockv2(ngf*8, ngf*8, stride=1, norm_layer=norm_layer, same=same) 124 | self.layer4 = nn.Sequential(block4_1, block4_2) 125 | 126 | def get_stages(self): 127 | return [ 128 | nn.Identity(), 129 | nn.Sequential(self.conv1, self.norm1, self.relu), 130 | nn.Sequential(self.maxpool, self.layer1), 131 | self.layer2, 132 | self.layer3, 133 | self.layer4, 134 | ] 135 | 136 | def forward(self, x): 137 | stages = self.get_stages() 138 | 139 | features = [] 140 | for i in range(self._depth + 1): 141 | x = stages[i](x) 142 | features.append(x) 143 | 144 | return features 145 | 146 | 147 | class ResNetv2EncoderCont(nn.Module, EncoderMixin): 148 | def __init__(self, in_channels, add_channels=1, ngf=64, norm_layer=None, same=False): 149 | super().__init__() 150 | 151 | self._out_channels = [in_channels, ngf] 152 | self._out_channels += [ngf*2**i for i in range(4)] 153 | self._depth = 5 154 | 155 | if norm_layer is None: 156 | norm_layer = nn.BatchNorm2d 157 | 158 | ConvLayer = EqualConv2d if not same else EqualConv2dSame 159 | 160 | self.in_channels = in_channels 161 | self.add_channels = add_channels 162 | 163 | self.conv1 = ConvLayer(in_channels, ngf, kernel_size=7, stride=2, padding=3, bias=False) 164 | self.conv1_add = ConvLayer(add_channels, ngf, kernel_size=7, stride=2, padding=3, bias=False) 165 | 166 | self.norm1 = norm_layer(ngf, affine=True) 167 | self.relu = FusedLeakyReLU(ngf) 168 | 169 | maxpool_layers = [] 170 | if same: 171 | maxpool_layers.append(nn.ReplicationPad2d(1)) 172 | else: 173 | maxpool_layers.append(nn.ZeroPad2d(1)) 174 | maxpool_layers.append(nn.MaxPool2d(kernel_size=3, stride=2)) 175 | self.maxpool = nn.Sequential(*maxpool_layers) 176 | 177 | block1_1 = BasicBlockv2(ngf, ngf, stride=1, norm_layer=norm_layer, same=same) 178 | block1_2 = BasicBlockv2(ngf, ngf, stride=1, norm_layer=norm_layer, same=same) 179 | self.layer1 = nn.Sequential(block1_1, block1_2) 180 | 181 | block2_1 = BasicBlockv2(ngf, ngf*2, stride=2, norm_layer=norm_layer, downsample=True, same=same) 182 | block2_2 = BasicBlockv2(ngf*2, ngf*2, stride=1, norm_layer=norm_layer, same=same) 183 | self.layer2 = nn.Sequential(block2_1, block2_2) 184 | 185 | block3_1 = BasicBlockv2(ngf*2, ngf*4, stride=2, norm_layer=norm_layer, downsample=True, same=same) 186 | block3_2 = BasicBlockv2(ngf*4, ngf*4, stride=1, norm_layer=norm_layer, same=same) 187 | self.layer3 = nn.Sequential(block3_1, block3_2) 188 | 189 | block4_1 = BasicBlockv2(ngf*4, ngf*8, stride=2, norm_layer=norm_layer, downsample=True, same=same) 190 | block4_2 = BasicBlockv2(ngf*8, ngf*8, stride=1, norm_layer=norm_layer, same=same) 191 | self.layer4 = nn.Sequential(block4_1, block4_2) 192 | 193 | def get_stages(self): 194 | return [ 195 | nn.Identity(), 196 | nn.Sequential(self.conv1, self.norm1, self.relu), 197 | nn.Sequential(self.maxpool, self.layer1), 198 | self.layer2, 199 | self.layer3, 200 | 
self.layer4, 201 | ] 202 | 203 | def forward(self, x, alfa=1.): 204 | stages = self.get_stages() 205 | 206 | features = [] 207 | 208 | x_main = x[:, :self.in_channels] 209 | x_add = x[:, self.in_channels:] 210 | assert(x_add.shape[1] == self.add_channels) 211 | 212 | features.append(x_main) 213 | out_main = self.conv1(x_main) 214 | out_add = self.conv1_add(x_add) 215 | 216 | out = out_main + alfa*out_add 217 | 218 | out = self.norm1(out) 219 | out = self.relu(out) 220 | features.append(out) 221 | 222 | 223 | out = self.maxpool(out) 224 | out = self.layer1(out) 225 | features.append(out) 226 | 227 | 228 | out = self.layer2(out) 229 | features.append(out) 230 | 231 | out = self.layer3(out) 232 | features.append(out) 233 | 234 | out = self.layer4(out) 235 | features.append(out) 236 | 237 | return features -------------------------------------------------------------------------------- /models/styleganv2/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = 
floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { 93 | int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114 | int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 
1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<<>>( 229 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 | upfirdn2d_kernel<<>>( 236 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<<>>( 243 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<<>>( 250 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<<>>( 257 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<<>>( 264 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /inference_module/runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | from tqdm import tqdm 3 | import numpy as np 4 | import pickle 5 | import cv2 6 | 7 | import torch 8 | from torch import nn 9 | 10 | import lpips 11 | import kornia 12 | 13 | import utils 14 | from utils.bbox import compute_bboxes_from_keypoints, crop_resize_image 15 | 16 | from inference_module.criterion import LPIPSCriterion, MSECriterion, MAECriterion, SSIMCriterion, \ 17 | UnaryDiscriminatorNonsaturatingCriterion, UnaryDiscriminatorFeatureMatchingCriterion 18 | 19 | 20 | class Runner: 21 | def __init__(self, config, inferer, smplx_models_dict, image_size, device='cuda:0'): 22 | self.config = config 23 | self.inferer = inferer 24 | self.smplx_models_dict = smplx_models_dict 25 | 26 | self.image_size = image_size 27 | self.input_size = image_size // 2 28 | 29 | 30 | # load criterions 31 | ## lpips 32 | self.lpips_criterion = LPIPSCriterion(net='vgg').to(device) 33 | 34 | ## mse 35 | self.mse_criterion = MSECriterion().to(device) 36 | 37 | ## encoder_latent_deviation 38 | self.encoder_latent_deviation_criterion = MAECriterion().to(device) 39 | 40 | ## generator_params_deviation 41 | self.generator_params_deviation_criterion = MAECriterion().to(device) 42 | 43 | ## ntexture_deviation 44 | self.ntexture_deviation_criterion = MAECriterion().to(device) 45 | 46 | ## face_lpips 47 | self.face_lpips_criterion = LPIPSCriterion(net='vgg').to(device) 48 | 49 | 50 | def _get_latent_optimization_group(self, lr): 51 | optimization_group = { 52 | 'params': [self.inferer.latent], 53 | 'lr': lr 54 | } 55 | return 
optimization_group 56 | 57 | def _get_generator_optimization_group(self, lr): 58 | generator_params = [] 59 | 60 | # convs 61 | generator_params.extend(self.inferer.generator.conv1.parameters()) 62 | for l in self.inferer.generator.convs: 63 | generator_params.extend(l.parameters()) 64 | 65 | # to_rgbs 66 | generator_params.extend(self.inferer.generator.to_rgb1.parameters()) 67 | for l in self.inferer.generator.to_rgbs: 68 | generator_params.extend(l.parameters()) 69 | 70 | optimization_group = { 71 | 'params': generator_params, 72 | 'lr': lr 73 | } 74 | 75 | return optimization_group 76 | 77 | def _get_noise_optimization_group(self, lr): 78 | optimization_group = { 79 | 'params': self.inferer.noises, 80 | 'lr': lr 81 | } 82 | return optimization_group 83 | 84 | def _get_ntexture_optimization_group(self, lr): 85 | optimization_group = { 86 | 'params': self.inferer.ntexture, 87 | 'lr': lr 88 | } 89 | return optimization_group 90 | 91 | def _get_beta_optimization_group(self, lr): 92 | optimization_group = { 93 | 'params': self.inferer.betas, 94 | 'lr': lr 95 | } 96 | return optimization_group 97 | 98 | 99 | def _get_optimization_group(self, name, lr): 100 | if name == 'latent': 101 | optimization_group = self._get_latent_optimization_group(lr) 102 | elif name == 'generator': 103 | optimization_group = self._get_generator_optimization_group(lr) 104 | elif name == 'noise': 105 | optimization_group = self._get_noise_optimization_group(lr) 106 | elif name == 'ntexture': 107 | optimization_group = self._get_ntexture_optimization_group(lr) 108 | elif name == 'beta': 109 | optimization_group = self._get_beta_optimization_group(lr) 110 | else: 111 | raise NotImplementedError(f"Unknown name {name}") 112 | 113 | return optimization_group 114 | 115 | 116 | def crop_face_with_landmarks(self, image, landmarks, face_size=256): 117 | face_kp = landmarks[:, 25:93].clone() 118 | 119 | valid_mask = (face_kp[..., 0] < 0).sum(dim=1) <= 0 120 | invalid_mask = torch.logical_not(valid_mask) 121 | 122 | # add dummy valid landmarks for invalid images 123 | face_kp[invalid_mask] = torch.ones( 124 | torch.sum(invalid_mask), face_kp.shape[1], face_kp.shape[2], device=face_kp.device 125 | ) 126 | 127 | bboxes_estimate = compute_bboxes_from_keypoints(face_kp) 128 | face = crop_resize_image(image, bboxes_estimate, face_size) 129 | 130 | return face, valid_mask 131 | 132 | def make_smplx(self, train_dict): 133 | gender = train_dict['gender'][0] 134 | 135 | B = train_dict['global_orient'].shape[0] 136 | verts_list = [] 137 | 138 | for i in range(B): 139 | smplx_output = self.smplx_models_dict[gender]( 140 | global_orient=train_dict['global_orient'][i:i+1], 141 | transl=train_dict['transl'][i:i+1], 142 | betas=self.inferer.betas[i:i+1], 143 | expression=train_dict['expressions'][i:i+1], 144 | body_pose=train_dict['body_pose'][i:i+1], 145 | left_hand_pose=train_dict['left_hand_pose'][i:i+1, :6], 146 | right_hand_pose=train_dict['right_hand_pose'][i:i+1, :6], 147 | jaw_pose=train_dict['jaw_pose'][i:i+1], 148 | ) 149 | verts = smplx_output.vertices 150 | verts_list.append(verts) 151 | 152 | verts = torch.cat(verts_list, dim=0) 153 | 154 | 155 | return verts 156 | 157 | def run_epoch(self, train_dict): 158 | # self.inferer.eval() 159 | 160 | # get smplx betas, hand and body pose 161 | self.inferer.betas = train_dict['betas'] 162 | self.inferer.betas.requires_grad_(True) 163 | 164 | 165 | v_inds = np.load(self.config.runner.v_inds_path) 166 | 167 | # save initial generator params 168 | generator_params_inital = 
utils.common.flatten_parameters( 169 | self.inferer.generator.parameters() 170 | ).detach().clone() 171 | 172 | # init with encoder 173 | encoder_latent = self.inferer.encoder.forward(kornia.resize(train_dict['real_rgb'], (self.config.encoder.image_size , self.config.encoder.image_size ))) 174 | encoder_latent = encoder_latent.mean(0, keepdim=True) # TODO: more smart averaging of latent vectors 175 | self.inferer.latent.data = encoder_latent.detach().clone() 176 | 177 | # save initial ntexture 178 | with torch.no_grad(): 179 | infer_output_dict = self.inferer.infer_pass( 180 | torch.cat([self.inferer.latent] * len(train_dict['real_rgb']), dim=0), 181 | train_dict['verts'], 182 | ntexture=None, 183 | uv=None 184 | ) 185 | 186 | initial_ntexture = infer_output_dict['fake_ntexture'][:1].detach().clone() 187 | 188 | stages = sorted(self.config.stages.keys()) 189 | # stages = stages[-1:] 190 | for stage in stages: 191 | stage_config = self.config.stages[stage] 192 | 193 | # maybe switch input source 194 | if stage_config.input_source == 'ntexture': 195 | with torch.no_grad(): 196 | infer_output_dict = self.inferer.infer_pass( 197 | torch.cat([self.inferer.latent] * len(train_dict['real_rgb']), dim=0), 198 | train_dict['verts'], 199 | ntexture=None, 200 | uv=None 201 | ) 202 | 203 | self.inferer.ntexture = nn.Parameter(infer_output_dict['fake_ntexture'][:1].detach().clone(), requires_grad=True) 204 | 205 | self.inferer.latent = None 206 | self.inferer.noises = None 207 | 208 | # setup optimizer 209 | optimization_groups = [] 210 | for optimization_target in stage_config.optimization_targets: 211 | optimization_groups.append( 212 | self._get_optimization_group(optimization_target, stage_config.lr[optimization_target]) 213 | ) 214 | 215 | optimizer = torch.optim.Adam(optimization_groups) 216 | 217 | # run optimization 218 | pbar = tqdm(range(stage_config.n_iters)) 219 | pbar.set_description(f"{self.config.log.experiment_name}") 220 | for i in pbar: 221 | torch.autograd.set_grad_enabled(True) 222 | 223 | verts_3d = self.make_smplx(train_dict) 224 | verts_3d = verts_3d[:, v_inds] 225 | 226 | verts_ndc, matrix_ndc = self.inferer.diff_uv_renderer.convert_to_ndc( 227 | verts_3d, 228 | train_dict['K'], 229 | self.input_size, self.input_size, 230 | near=0.01, far=20.0, 231 | invert_verts=False 232 | ) 233 | 234 | uv, uv_da, mask, rast = self.inferer.diff_uv_renderer.render(verts_ndc, matrix_ndc, render_h=self.input_size, render_w=self.input_size) 235 | 236 | 237 | 238 | infer_output_dict = self.inferer.infer_pass( 239 | None if self.inferer.latent is None else torch.cat([self.inferer.latent] * len(train_dict['real_rgb']), dim=0), 240 | None, 241 | ntexture=None if self.inferer.ntexture is None else torch.cat([self.inferer.ntexture] * len(train_dict['real_rgb']), dim=0), 242 | uv=uv 243 | ) 244 | 245 | image_pred = infer_output_dict['fake_img'] 246 | segm_pred = infer_output_dict['fake_segm'] 247 | ntexture_pred = infer_output_dict['fake_ntexture'] 248 | if i == 0: 249 | torch.save(ntexture_pred[:1], f'data/textures/step_by_step/texture_{stage}.pth') 250 | 251 | # calculate losses 252 | loss = 0.0 253 | 254 | ## lpips 255 | if stage_config.loss_weight.lpips: 256 | lpips_loss = self.lpips_criterion(image_pred, train_dict['real_rgb']) 257 | loss += stage_config.loss_weight.lpips * lpips_loss 258 | 259 | ## mse 260 | if stage_config.loss_weight.mse: 261 | mse_loss = self.mse_criterion(image_pred, train_dict['real_rgb']) 262 | loss += stage_config.loss_weight.mse * mse_loss 263 | 264 | ## 
encoder_latent_deviation 265 | if stage_config.loss_weight.encoder_latent_deviation: 266 | encoder_latent_deviation_loss = self.encoder_latent_deviation_criterion(self.inferer.latent, encoder_latent) 267 | loss += stage_config.loss_weight.encoder_latent_deviation * encoder_latent_deviation_loss 268 | 269 | ## generator_params_deviation 270 | if stage_config.loss_weight.generator_params_deviation: 271 | generator_params_deviation_loss = self.generator_params_deviation_criterion( 272 | utils.common.flatten_parameters(self.inferer.generator.parameters()), 273 | generator_params_inital 274 | ) 275 | loss += stage_config.loss_weight.generator_params_deviation * generator_params_deviation_loss 276 | 277 | ## ntexture_deviation 278 | if stage_config.loss_weight.ntexture_deviation: 279 | ntexture_deviation_loss = self.ntexture_deviation_criterion( 280 | ntexture_pred, 281 | initial_ntexture 282 | ) 283 | loss += stage_config.loss_weight.ntexture_deviation * ntexture_deviation_loss 284 | 285 | 286 | ## face_lpips 287 | image_real_face = image_pred_face = None 288 | if stage_config.loss_weight.face_lpips: 289 | ### crop faces and get valid mask (some images don't contain faces) 290 | image_real_face, valid_mask = self.crop_face_with_landmarks( 291 | train_dict['real_rgb'], train_dict['landmarks'], face_size=256 292 | ) 293 | 294 | image_pred_face, _ = self.crop_face_with_landmarks( # valid_mask is the same as for real 295 | image_pred, train_dict['landmarks'], face_size=256 296 | ) 297 | 298 | face_lpips_loss = self.face_lpips_criterion(image_pred_face, image_real_face, valid_mask=valid_mask) 299 | loss += stage_config.loss_weight.face_lpips * face_lpips_loss 300 | else: 301 | image_pred_face = None 302 | image_real_face = None 303 | 304 | 305 | optimizer.zero_grad() 306 | loss.backward() 307 | optimizer.step() 308 | 309 | 310 | torch.save(ntexture_pred[:1], f'data/textures/step_by_step/texture_final.pth') 311 | 312 | return self.inferer.ntexture -------------------------------------------------------------------------------- /utils/bbox.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def get_ltrb_from_verts(verts): 6 | verts_projected = (verts / (verts[..., 2:]))[..., :2] 7 | 8 | x = verts_projected[..., 0] 9 | y = verts_projected[..., 1] 10 | 11 | # get bbox in format (left, top, right, bottom) 12 | l = torch.min(x, dim=1)[0].long() 13 | t = torch.min(y, dim=1)[0].long() 14 | r = torch.max(x, dim=1)[0].long() 15 | b = torch.max(y, dim=1)[0].long() 16 | 17 | return torch.stack([l, t, r, b], dim=-1) 18 | 19 | 20 | def scale_bbox(ltrb, scale): 21 | width, height = ltrb[:, 2] - ltrb[:, 0], ltrb[:, 3] - ltrb[:, 1] 22 | 23 | x_center, y_center = (ltrb[:, 2] + ltrb[:, 0]) // 2, (ltrb[:, 3] + ltrb[:, 1]) // 2 24 | new_width, new_height = int(scale * width), int(scale * height) 25 | 26 | new_left = x_center - new_width // 2 27 | new_right = new_left + new_width 28 | 29 | new_top = y_center - new_height // 2 30 | new_bottom = new_top + new_height 31 | 32 | new_ltrb = torch.stack([new_left, new_top, new_right, new_bottom], dim=-1) 33 | 34 | return new_ltrb 35 | 36 | 37 | def get_square_bbox(ltrb): 38 | """Makes a square bbox from any bbox by stretching its shorter side 39 | Args: 40 | ltrb: tensor of shape B x 4: input bboxes (left, top, right, bottom) 41 | Returns: 42 | tensor of shape B x 4: resulting square bboxes (left, top, right, bottom) 43 | """ 44 | 45 | left = ltrb[:, 0] 46 | right = ltrb[:, 2] 47 | top = ltrb[:, 
1] 48 | bottom = ltrb[:, 3] 49 | 50 | width, height = right - left, bottom - top 51 | 52 | if width > height: 53 | y_center = (ltrb[:, 3] + ltrb[:, 1]) // 2 54 | top = y_center - width // 2 55 | bottom = top + width 56 | else: 57 | x_center = (ltrb[:, 2] + ltrb[:, 0]) // 2 58 | left = x_center - height // 2 59 | right = left + height 60 | 61 | new_ltrb = torch.stack([left, top, right, bottom], dim=-1) 62 | return new_ltrb 63 | 64 | 65 | def get_ltrb_bbox(verts, scale=1.2): 66 | ltrb = get_ltrb_from_verts(verts) 67 | 68 | ltrb = scale_bbox(ltrb, scale) 69 | ltrb = get_square_bbox(ltrb) 70 | 71 | return ltrb 72 | 73 | 74 | def crop_resize_image(image, ltrb, new_image_size): 75 | size = new_image_size 76 | 77 | l, t, r, b = ltrb.t().float() 78 | batch_size, num_channels, h, w = image.shape 79 | 80 | affine_matrix = torch.zeros(batch_size, 2, 3, dtype=torch.float32, device=image.device) 81 | affine_matrix[:, 0, 0] = (r - l) / w 82 | affine_matrix[:, 1, 1] = (b - t) / h 83 | affine_matrix[:, 0, 2] = (l + r) / w - 1 84 | affine_matrix[:, 1, 2] = (t + b) / h - 1 85 | 86 | output_shape = (batch_size, num_channels) + (size, size) 87 | grid = torch.affine_grid_generator(affine_matrix, output_shape, align_corners=True) 88 | grid = grid.to(image.dtype) 89 | return torch.nn.functional.grid_sample(image, grid, 'bilinear', 'reflection', align_corners=True) 90 | 91 | 92 | def crop_resize_coords(coords, ltrb, new_image_size): 93 | coords = coords.clone() 94 | 95 | width = ltrb[:, 2] - ltrb[:, 0] 96 | height = ltrb[:, 3] - ltrb[:, 1] 97 | 98 | coords[..., 0] -= ltrb[:, 0] 99 | coords[..., 1] -= ltrb[:, 1] 100 | 101 | coords[..., 0] *= new_image_size / width 102 | coords[..., 1] *= new_image_size / height 103 | 104 | return coords 105 | 106 | 107 | def crop_resize_verts(verts, K, ltrb, new_image_size): 108 | # it is assumed that smplifyx's verts are in trivial camera coordinates 109 | fx, fy, cx, cy = 1.0, 1.0, 0.0, 0.0 110 | # crop 111 | cx, cy = cx - ltrb[:, 0], cy - ltrb[:, 1] 112 | # scale 113 | width, height = ltrb[:, 2] - ltrb[:, 0], ltrb[:, 3] - ltrb[:, 1] 114 | new_h = new_w = new_image_size 115 | 116 | h_scale, w_scale = new_w / width.float(), new_h / height.float() 117 | 118 | fx, fy = fx * w_scale, fy * h_scale 119 | cx, cy = cx * w_scale, cy * h_scale 120 | 121 | # update verts 122 | B = verts.shape[0] 123 | K_upd = torch.eye(3) 124 | K_upd = torch.stack([K_upd] * B, dim=0).to(verts.device) 125 | 126 | K_upd[:, 0, 0] = fx 127 | K_upd[:, 1, 1] = fy 128 | K_upd[:, 0, 2] = cx 129 | K_upd[:, 1, 2] = cy 130 | 131 | verts_cropped = torch.bmm(verts, K_upd.transpose(1, 2)) 132 | K_cropped = torch.bmm(K_upd, K) 133 | 134 | return verts_cropped, K_cropped 135 | 136 | 137 | def compute_bboxes_from_keypoints(keypoints): 138 | """ 139 | keypoints: B x 68*2 140 | return value: B x 4 (t, b, l, r) 141 | Compute a very rough bounding box approximation from 68 facial keypoints. 
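The 68 points are assumed to follow the standard iBUG facial-landmark layout, where index 8 is the chin tip and index 27 the top of the nose bridge, so face_height = y[8] - y[27]. The box spans from 0.47 * face_height above point 27 to 0.2 * face_height below point 8, is centered horizontally on the midpoint of the keypoints' x-range, and is made square (its width equals its height). 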
142 | """ 143 | x, y = keypoints.float().view(-1, 68, 2).transpose(0, 2) 144 | 145 | face_height = y[8] - y[27] 146 | b = y[8] + face_height * 0.2 147 | t = y[27] - face_height * 0.47 148 | 149 | midpoint_x = (x.min(dim=0)[0] + x.max(dim=0)[0]) / 2 150 | half_height = (b - t) * 0.5 151 | 152 | l = midpoint_x - half_height 153 | r = midpoint_x + half_height 154 | 155 | 156 | return torch.stack([t, b, l, r], dim=1) 157 | 158 | 159 | # def crop_and_resize(images, bboxes, target_size=None): 160 | # """ 161 | # images: B x C x H x W 162 | # bboxes: B x 4; [t, b, l, r], in pixel coordinates 163 | # target_size (optional): tuple (h, w) 164 | 165 | # return value: B x C x h x w 166 | 167 | # Crop i-th image using i-th bounding box, then resize all crops to the 168 | # desired shape (default is the original images' size, H x W). 169 | # """ 170 | 171 | # t, b, l, r = bboxes.t().float() 172 | # batch_size, num_channels, h, w = images.shape 173 | 174 | # affine_matrix = torch.zeros(batch_size, 2, 3, dtype=torch.float32, device=images.device) 175 | # affine_matrix[:, 0, 0] = (r-l) / w 176 | # affine_matrix[:, 1, 1] = (b-t) / h 177 | # affine_matrix[:, 0, 2] = (l+r) / w - 1 178 | # affine_matrix[:, 1, 2] = (t+b) / h - 1 179 | 180 | # output_shape = (batch_size, num_channels) + (target_size or (h, w)) 181 | # grid = torch.affine_grid_generator(affine_matrix, output_shape, align_corners=True) 182 | # grid = grid.to(images.dtype) 183 | # return torch.nn.functional.grid_sample(images, grid, 'bilinear', 'reflection', align_corners=True) 184 | 185 | 186 | # import numpy as np 187 | # import torch 188 | 189 | # from PIL import Image 190 | 191 | 192 | # def get_ltrb_from_verts(verts): 193 | # verts_projected = (verts / (verts[..., 2:]))[..., :2] 194 | 195 | # x = verts_projected[..., 0] 196 | # y = verts_projected[..., 1] 197 | 198 | # # get bbox in format (left, top, right, bottom) 199 | # l = torch.min(x, dim=1)[0].long() 200 | # t = torch.min(y, dim=1)[0].long() 201 | # r = torch.max(x, dim=1)[0].long() 202 | # b = torch.max(y, dim=1)[0].long() 203 | 204 | # return torch.stack([l, t, r, b], dim=-1) 205 | 206 | 207 | # def get_square_bbox(ltrb): 208 | # """Makes square bbox from any bbox by stretching of minimal length side 209 | 210 | # Args: 211 | # bbox tuple of size 4: input bbox (left, upper, right, lower) 212 | 213 | # Returns: 214 | # bbox: tuple of size 4: resulting square bbox (left, upper, right, lower) 215 | # """ 216 | 217 | # left = ltrb[:, 0] 218 | # right = ltrb[:, 2] 219 | # top = ltrb[:, 1] 220 | # bottom = ltrb[:, 3] 221 | 222 | # width, height = right - left, bottom - top 223 | 224 | # if width > height: 225 | # y_center = (ltrb[:, 3] + ltrb[:, 1]) // 2 226 | # top = y_center - width // 2 227 | # bottom = top + width 228 | # else: 229 | # x_center = (ltrb[:, 2] + ltrb[:, 0]) // 2 230 | # left = x_center - height // 2 231 | # right = left + height 232 | 233 | # new_ltrb = torch.stack([left, top, right, bottom], dim=-1) 234 | # return new_ltrb 235 | 236 | 237 | # def scale_bbox(ltrb, scale): 238 | # width, height = ltrb[:, 2] - ltrb[:, 0], ltrb[:, 3] - ltrb[:, 1] 239 | 240 | # x_center, y_center = (ltrb[:, 2] + ltrb[:, 0]) // 2, (ltrb[:, 3] + ltrb[:, 1]) // 2 241 | # new_width, new_height = int(scale * width), int(scale * height) 242 | 243 | # new_left = x_center - new_width // 2 244 | # new_right = new_left + new_width 245 | 246 | # new_top = y_center - new_height // 2 247 | # new_bottom = new_top + new_height 248 | 249 | # new_ltrb = torch.stack([new_left, new_top, new_right, new_bottom], 
dim=-1) 250 | 251 | # return new_ltrb 252 | 253 | 254 | # def get_ltrb_bbox(verts, scale=1.2): 255 | # ltrb = get_ltrb_from_verts(verts) 256 | 257 | # ltrb = scale_bbox(ltrb, scale) 258 | # ltrb = get_square_bbox(ltrb) 259 | 260 | # return ltrb 261 | 262 | 263 | # def compute_bboxes_from_keypoints(keypoints): 264 | # """ 265 | # keypoints: B x 68*2 266 | 267 | # return value: B x 4 (t, b, l, r) 268 | 269 | # Compute a very rough bounding box approximate from 68 keypoints. 270 | # """ 271 | # x, y = keypoints.float().view(-1, 68, 2).transpose(0, 2) 272 | 273 | # face_height = y[8] - y[27] 274 | # b = y[8] + face_height * 0.2 275 | # t = y[27] - face_height * 0.47 276 | 277 | # midpoint_x = (x.min(dim=0)[0] + x.max(dim=0)[0]) / 2 278 | # half_height = (b - t) * 0.5 279 | 280 | # l = midpoint_x - half_height 281 | # r = midpoint_x + half_height 282 | 283 | 284 | # return torch.stack([t, b, l, r], dim=1) 285 | 286 | 287 | # def crop_image(image, bbox, return_mask=False): 288 | # """Crops area from image specified as bbox. Always returns area of size as bbox filling missing parts with zeros 289 | # Args: 290 | # image numpy array of shape (height, width, 3): input image 291 | # bbox tuple of size 4: input bbox (left, upper, right, lower) 292 | 293 | # Returns: 294 | # cropped_image numpy array of shape (height, width, 3): resulting cropped image 295 | 296 | # """ 297 | # image_pil = Image.fromarray(image, mode='RGB') 298 | # image_pil = image_pil.crop(bbox) 299 | 300 | # if return_mask: 301 | # mask = np.ones_like(image) 302 | # mask_pil = Image.fromarray(mask, mode='RGB') 303 | # mask_pil = mask_pil.crop(bbox) 304 | # return np.asarray(image_pil), np.asarray(mask_pil)[..., :1] 305 | 306 | # return np.asarray(image_pil) 307 | 308 | 309 | # def crop_resize_image(image, ltrb, new_image_size=None): 310 | # if new_image_size is None: 311 | # new_image_size = image.shape 312 | # size = new_image_size 313 | 314 | # l, t, r, b = ltrb.t().float() 315 | # batch_size, num_channels, h, w = image.shape 316 | 317 | # affine_matrix = torch.zeros(batch_size, 2, 3, dtype=torch.float32, device=image.device) 318 | # affine_matrix[:, 0, 0] = (r - l) / w 319 | # affine_matrix[:, 1, 1] = (b - t) / h 320 | # affine_matrix[:, 0, 2] = (l + r) / w - 1 321 | # affine_matrix[:, 1, 2] = (t + b) / h - 1 322 | 323 | # output_shape = (batch_size, num_channels) + (size, size) 324 | # grid = torch.affine_grid_generator(affine_matrix, output_shape, align_corners=True) 325 | # grid = grid.to(image.dtype) 326 | # return torch.nn.functional.grid_sample(image, grid, 'bilinear', 'reflection', align_corners=True) 327 | 328 | 329 | 330 | 331 | 332 | # def crop_resize_coords(coords, ltrb, imsize): 333 | # invalid_mask = coords.sum(axis=-1) <= 0 334 | # crop_j, crop_i, r, b = ltrb 335 | # crop_sz = r - crop_j 336 | 337 | # resize_ratio = imsize / crop_sz 338 | # cropped = coords.copy() 339 | # cropped[:, 1] -= crop_i 340 | # cropped[:, 0] -= crop_j 341 | # cropped[:, :2] *= resize_ratio 342 | # cropped[invalid_mask] = -1 343 | # return cropped 344 | 345 | 346 | # def crop_resize_verts(verts, K, ltrb, new_image_size): 347 | # # it's supposed that it smplifyx's verts are in trivial camera coordinates 348 | # fx, fy, cx, cy = 1.0, 1.0, 0.0, 0.0 349 | # # crop 350 | # cx, cy = cx - ltrb[:, 0], cy - ltrb[:, 1] 351 | # # scale 352 | # width, height = ltrb[:, 2] - ltrb[:, 0], ltrb[:, 3] - ltrb[:, 1] 353 | # new_h = new_w = new_image_size 354 | 355 | # h_scale, w_scale = new_w / width.float(), new_h / height.float() 356 | 357 | # fx, fy = fx * 
w_scale, fy * h_scale 358 | # cx, cy = cx * w_scale, cy * h_scale 359 | 360 | # # update verts 361 | # B = verts.shape[0] 362 | # K_upd = torch.eye(3) 363 | # K_upd = torch.stack([K_upd] * B, dim=0).to(verts.device) 364 | 365 | # K_upd[:, 0, 0] = fx 366 | # K_upd[:, 1, 1] = fy 367 | # K_upd[:, 0, 2] = cx 368 | # K_upd[:, 1, 2] = cy 369 | 370 | # verts_cropped = torch.bmm(verts, K_upd.transpose(1, 2)) 371 | # K_cropped = torch.bmm(K_upd, K) 372 | 373 | # return verts_cropped, K_cropped 374 | 375 | 376 | # def get_bbox_from_smplifyx(verts): 377 | # verts_projected = (verts / verts[:, 2:])[:, :2] 378 | 379 | # # get bbox in format (left, top, right, bottom) 380 | # l = np.min(verts_projected[:, 0]) 381 | # t = np.min(verts_projected[:, 1]) 382 | # r = np.max(verts_projected[:, 0]) 383 | # b = np.max(verts_projected[:, 1]) 384 | 385 | # return (l, t, r, b) 386 | 387 | 388 | # def update_smplifyx_after_crop_and_resize(verts, K, bbox, image_shape, new_image_shape): 389 | # # it's supposed that it smplifyx's verts are in trivial camera coordinates 390 | # fx, fy, cx, cy = 1.0, 1.0, 0.0, 0.0 391 | 392 | # # crop 393 | # cx, cy = cx - bbox[0], cy - bbox[1] 394 | 395 | # # scale 396 | # h, w = image_shape 397 | # new_h, new_w = new_image_shape 398 | 399 | # h_scale, w_scale = new_w / w, new_h / h 400 | 401 | # fx, fy = fx * w_scale, fy * h_scale 402 | # cx, cy = cx * w_scale, cy * h_scale 403 | 404 | # # update verts 405 | # new_K = np.array([ 406 | # [fx, 0.0, cx], 407 | # [0.0, fy, cy], 408 | # [0.0, 0.0, 1.0] 409 | # ]) 410 | 411 | # return verts @ new_K.T, new_K @ K -------------------------------------------------------------------------------- /models/styleganv2/modules.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn, autograd 7 | from torch.nn import functional as F 8 | 9 | from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d 10 | 11 | 12 | def d_logistic_loss(real_pred, fake_pred): 13 | real_loss = F.softplus(-real_pred) 14 | fake_loss = F.softplus(fake_pred) 15 | 16 | return real_loss.mean() + fake_loss.mean() 17 | 18 | 19 | def d_r1_loss(real_pred, real_img): 20 | grad_real, = autograd.grad( 21 | outputs=real_pred.sum(), inputs=real_img, create_graph=True 22 | ) 23 | grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() 24 | 25 | return grad_penalty 26 | 27 | 28 | def g_nonsaturating_loss(fake_pred): 29 | loss = F.softplus(-fake_pred).mean() 30 | 31 | return loss 32 | 33 | 34 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): 35 | noise = torch.randn_like(fake_img) / math.sqrt( 36 | fake_img.shape[2] * fake_img.shape[3] 37 | ) 38 | grad, = autograd.grad( 39 | outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True 40 | ) 41 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 42 | 43 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 44 | 45 | path_penalty = (path_lengths - path_mean).pow(2).mean() 46 | 47 | return path_penalty, path_mean.detach(), path_lengths 48 | 49 | 50 | def make_noise(batch, latent_dim, n_noise, device): 51 | if n_noise == 1: 52 | return torch.randn(batch, latent_dim, device=device) 53 | 54 | noises = torch.randn(n_noise, batch, latent_dim, device=device).unbind(0) 55 | 56 | return noises 57 | 58 | 59 | def mixing_noise(batch, latent_dim, prob, device): 60 | if prob > 0 and random.random() < prob: 61 | return make_noise(batch, 
latent_dim, 2, device) 62 | 63 | else: 64 | return [make_noise(batch, latent_dim, 1, device)] 65 | 66 | 67 | class PixelNorm(nn.Module): 68 | def __init__(self): 69 | super().__init__() 70 | 71 | def forward(self, input): 72 | return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8) 73 | 74 | 75 | def make_kernel(k): 76 | k = torch.tensor(k, dtype=torch.float32) 77 | 78 | if k.ndim == 1: 79 | k = k[None, :] * k[:, None] 80 | 81 | k /= k.sum() 82 | 83 | return k 84 | 85 | 86 | class Upsample(nn.Module): 87 | def __init__(self, kernel, factor=2): 88 | super().__init__() 89 | 90 | self.factor = factor 91 | kernel = make_kernel(kernel) * (factor ** 2) 92 | self.register_buffer('kernel', kernel) 93 | 94 | p = kernel.shape[0] - factor 95 | 96 | pad0 = (p + 1) // 2 + factor - 1 97 | pad1 = p // 2 98 | 99 | self.pad = (pad0, pad1) 100 | 101 | def forward(self, input): 102 | out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad) 103 | 104 | return out 105 | 106 | 107 | class Downsample(nn.Module): 108 | def __init__(self, kernel, factor=2): 109 | super().__init__() 110 | 111 | self.factor = factor 112 | kernel = make_kernel(kernel) 113 | self.register_buffer('kernel', kernel) 114 | 115 | p = kernel.shape[0] - factor 116 | 117 | pad0 = (p + 1) // 2 118 | pad1 = p // 2 119 | 120 | self.pad = (pad0, pad1) 121 | 122 | def forward(self, input): 123 | out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad) 124 | 125 | return out 126 | 127 | 128 | class Blur(nn.Module): 129 | def __init__(self, kernel, pad, upsample_factor=1): 130 | super().__init__() 131 | 132 | kernel = make_kernel(kernel) 133 | 134 | if upsample_factor > 1: 135 | kernel = kernel * (upsample_factor ** 2) 136 | 137 | self.register_buffer('kernel', kernel) 138 | 139 | self.pad = pad 140 | 141 | def forward(self, input): 142 | out = upfirdn2d(input, self.kernel, pad=self.pad) 143 | 144 | return out 145 | 146 | 147 | class EqualConv2d(nn.Module): 148 | def __init__( 149 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, replicate_pad=False, 150 | ): 151 | super().__init__() 152 | 153 | self.weight = nn.Parameter( 154 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 155 | ) 156 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 157 | 158 | self.stride = stride 159 | self.padding = padding 160 | # self.padval = padval 161 | if bias: 162 | self.bias = nn.Parameter(torch.zeros(out_channel)) 163 | else: 164 | self.bias = None 165 | 166 | self.replicate_pad = replicate_pad 167 | 168 | # self.i=0 169 | 170 | def forward(self, input): 171 | 172 | pad = self.padding 173 | if self.replicate_pad: 174 | input = F.pad(input, (pad, pad, pad, pad), mode='replicate') 175 | pad = 0 176 | 177 | out = F.conv2d( 178 | input, 179 | self.weight * self.scale, 180 | bias=self.bias, 181 | stride=self.stride, 182 | padding=pad, 183 | ) 184 | 185 | return out 186 | 187 | def __repr__(self): 188 | return ( 189 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' 190 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' 191 | ) 192 | 193 | 194 | class EqualConv2dSame(nn.Module): 195 | def __init__( 196 | self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, 197 | ): 198 | super().__init__() 199 | 200 | self.weight = nn.Parameter( 201 | torch.randn(out_channel, in_channel, kernel_size, kernel_size) 202 | ) 203 | self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2) 204 | 205 | self.stride 
= stride 206 | self.padding = padding 207 | # self.padval = padval 208 | if bias: 209 | self.bias = nn.Parameter(torch.zeros(out_channel)) 210 | else: 211 | self.bias = None 212 | 213 | # self.i=0 214 | 215 | def forward(self, input): 216 | pad = self.padding 217 | input = F.pad(input, (pad, pad, pad, pad), mode='replicate') 218 | 219 | out = F.conv2d( 220 | input, 221 | self.weight * self.scale, 222 | bias=self.bias, 223 | stride=self.stride, 224 | ) 225 | 226 | return out 227 | 228 | def __repr__(self): 229 | return ( 230 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},' 231 | f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})' 232 | ) 233 | 234 | 235 | class EqualLinear(nn.Module): 236 | def __init__( 237 | self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None 238 | ): 239 | super().__init__() 240 | 241 | self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul)) 242 | 243 | if bias: 244 | self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init)) 245 | 246 | else: 247 | self.bias = None 248 | 249 | self.activation = activation 250 | 251 | self.scale = (1 / math.sqrt(in_dim)) * lr_mul 252 | self.lr_mul = lr_mul 253 | 254 | def forward(self, input): 255 | 256 | if self.activation: 257 | out = F.linear(input, self.weight * self.scale) 258 | out = fused_leaky_relu(out, self.bias * self.lr_mul) 259 | 260 | else: 261 | out = F.linear( 262 | input, self.weight * self.scale, bias=self.bias * self.lr_mul 263 | ) 264 | 265 | return out 266 | 267 | def __repr__(self): 268 | return ( 269 | f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})' 270 | ) 271 | 272 | 273 | class ScaledLeakyReLU(nn.Module): 274 | def __init__(self, negative_slope=0.2): 275 | super().__init__() 276 | 277 | self.negative_slope = negative_slope 278 | 279 | def forward(self, input): 280 | out = F.leaky_relu(input, negative_slope=self.negative_slope) 281 | 282 | return out * math.sqrt(2) 283 | 284 | 285 | class NoiseInjection(nn.Module): 286 | def __init__(self): 287 | super().__init__() 288 | 289 | self.weight = nn.Parameter(torch.zeros(1)) 290 | 291 | def forward(self, image, noise=None): 292 | if noise is None: 293 | batch, _, height, width = image.shape 294 | noise = image.new_empty(batch, 1, height, width).normal_() 295 | 296 | return image + self.weight * noise 297 | 298 | 299 | class ConstantInput(nn.Module): 300 | def __init__(self, channel, size=4): 301 | super().__init__() 302 | 303 | self.input = nn.Parameter(torch.randn(1, channel, size, size)) 304 | 305 | def forward(self, input): 306 | batch = input.shape[0] 307 | out = self.input.repeat(batch, 1, 1, 1) 308 | 309 | return out 310 | 311 | 312 | # ============ Normalization 313 | 314 | class AdaTexSpade(nn.Module): 315 | def __init__(self, num_features, segm_tensor, style_dim, kernel_size=1, eps=1e-4): 316 | super().__init__() 317 | self.num_features = num_features 318 | self.weight = self.bias = None 319 | self.norm_layer = nn.InstanceNorm2d(num_features, eps=eps, affine=False) 320 | 321 | self.segm_tensor = nn.Parameter(segm_tensor, requires_grad=False) 322 | n_segmchannels = self.segm_tensor.shape[1] 323 | in_channel = style_dim + n_segmchannels 324 | 325 | self.style_conv = EqualConv2d( 326 | in_channel, 327 | num_features, 328 | kernel_size, 329 | padding=kernel_size // 2, 330 | ) 331 | 332 | def forward(self, input, style): 333 | B, C, H, W = input.shape 334 | out = self.norm_layer(input) 335 | 336 | sB, sC = style.shape 337 | style = style[..., 
None, None] 338 | style = style.expand(sB, sC, H, W) 339 | segm_tensor = self.segm_tensor.expand(sB, *self.segm_tensor.shape[1:]) 340 | style = torch.cat([style, segm_tensor], dim=1) 341 | gammas = self.style_conv(style) 342 | 343 | out = out * gammas 344 | return out 345 | 346 | 347 | class AdaIn(nn.Module): 348 | def __init__(self, num_features, style_dim, eps=1e-4): 349 | super().__init__() 350 | self.num_features = num_features 351 | self.weight = self.bias = None 352 | self.norm_layer = nn.InstanceNorm2d(num_features, eps=eps, affine=False) 353 | self.modulation = EqualLinear(style_dim, num_features, bias_init=1) 354 | 355 | def forward(self, input, style): 356 | B, C, H, W = input.shape 357 | out = self.norm_layer(input) 358 | gammas = self.modulation(style) 359 | gammas = gammas[..., None, None] 360 | out = out * gammas 361 | return out 362 | 363 | 364 | class ModulatedSiren2d(nn.Module): 365 | def __init__( 366 | self, 367 | in_channel, 368 | out_channel, 369 | style_dim, 370 | demodulate=True, 371 | is_first=False, 372 | omega_0=30 373 | ): 374 | super().__init__() 375 | 376 | self.eps = 1e-8 377 | self.kernel_size = 1 378 | self.in_channel = in_channel 379 | self.out_channel = out_channel 380 | 381 | fan_in = in_channel 382 | self.scale = 1 / math.sqrt(fan_in) 383 | 384 | if is_first: 385 | self.demod_scale = 3 * fan_in 386 | else: 387 | self.demod_scale = omega_0 ** 2 / 2 388 | self.omega0 = omega_0 389 | 390 | self.weight = nn.Parameter( 391 | torch.randn(1, out_channel, in_channel, 1, 1) 392 | ) 393 | 394 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 395 | 396 | self.demodulate = demodulate 397 | 398 | def __repr__(self): 399 | return ( 400 | f'{self.__class__.__name__}({self.in_channel}, {self.out_channel})' 401 | ) 402 | 403 | def forward(self, input, style): 404 | batch, in_channel, height, width = input.shape 405 | 406 | style = self.modulation(style) 407 | style = style.view(batch, 1, in_channel, 1, 1) 408 | weight = self.scale * self.weight * style 409 | 410 | if self.demodulate: 411 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) * self.demod_scale + 1e-8) 412 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 413 | 414 | weight = weight.view( 415 | batch * self.out_channel, in_channel, 1, 1 416 | ) 417 | 418 | input = input.view(1, batch * in_channel, height, width) 419 | out = F.conv2d(input, weight, groups=batch) 420 | _, _, height, width = out.shape 421 | out = out.view(batch, self.out_channel, height, width) 422 | 423 | out = out * self.omega0 424 | 425 | return out 426 | 427 | 428 | class ModulatedConv2d(nn.Module): 429 | def __init__( 430 | self, 431 | in_channel, 432 | out_channel, 433 | kernel_size, 434 | style_dim, 435 | demodulate=True, 436 | upsample=False, 437 | downsample=False, 438 | blur_kernel=[1, 3, 3, 1], 439 | ): 440 | super().__init__() 441 | 442 | self.eps = 1e-8 443 | self.kernel_size = kernel_size 444 | self.in_channel = in_channel 445 | self.out_channel = out_channel 446 | self.upsample = upsample 447 | self.downsample = downsample 448 | 449 | if upsample: 450 | factor = 2 451 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 452 | pad0 = (p + 1) // 2 + factor - 1 453 | pad1 = p // 2 + 1 454 | 455 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 456 | 457 | if downsample: 458 | factor = 2 459 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 460 | pad0 = (p + 1) // 2 461 | pad1 = p // 2 462 | 463 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 464 | 465 | fan_in = 
in_channel * kernel_size ** 2 466 | self.scale = 1 / math.sqrt(fan_in) 467 | self.padding = kernel_size // 2 468 | 469 | self.weight = nn.Parameter( 470 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 471 | ) 472 | 473 | self.modulation = EqualLinear(style_dim, in_channel, bias_init=1) 474 | 475 | self.demodulate = demodulate 476 | 477 | def __repr__(self): 478 | return ( 479 | f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' 480 | f'upsample={self.upsample}, downsample={self.downsample})' 481 | ) 482 | 483 | def forward(self, input, style): 484 | batch, in_channel, height, width = input.shape 485 | 486 | style = self.modulation(style) 487 | style = style.view(batch, 1, in_channel, 1, 1) 488 | weight = self.scale * self.weight * style 489 | 490 | if self.demodulate: 491 | demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8) 492 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 493 | 494 | weight = weight.view( 495 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 496 | ) 497 | 498 | if self.upsample: 499 | input = input.view(1, batch * in_channel, height, width) 500 | weight = weight.view( 501 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 502 | ) 503 | weight = weight.transpose(1, 2).reshape( 504 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 505 | ) 506 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 507 | _, _, height, width = out.shape 508 | out = out.view(batch, self.out_channel, height, width) 509 | out = self.blur(out) 510 | 511 | elif self.downsample: 512 | input = self.blur(input) 513 | _, _, height, width = input.shape 514 | input = input.view(1, batch * in_channel, height, width) 515 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) 516 | _, _, height, width = out.shape 517 | out = out.view(batch, self.out_channel, height, width) 518 | 519 | else: 520 | input = input.view(1, batch * in_channel, height, width) 521 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) 522 | _, _, height, width = out.shape 523 | out = out.view(batch, self.out_channel, height, width) 524 | 525 | return out 526 | 527 | 528 | class ModulatedConv2dStyleTensor(nn.Module): 529 | def __init__( 530 | self, 531 | in_channel, 532 | out_channel, 533 | kernel_size, 534 | style_dim, 535 | segm_tensor, 536 | demodulate=True, 537 | upsample=False, 538 | downsample=False, 539 | blur_kernel=[1, 3, 3, 1], 540 | style_ks=1, 541 | norm_style=True 542 | ): 543 | super().__init__() 544 | 545 | self.eps = 1e-8 546 | self.kernel_size = kernel_size 547 | self.in_channel = in_channel 548 | self.out_channel = out_channel 549 | self.upsample = upsample 550 | self.downsample = downsample 551 | 552 | self.norm_style = norm_style 553 | 554 | self.segm_tensor = nn.Parameter(segm_tensor, requires_grad=False) 555 | n_segmchannels = self.segm_tensor.shape[1] 556 | style_in_channel = style_dim + n_segmchannels 557 | 558 | self.style_conv = EqualConv2d( 559 | style_in_channel, 560 | in_channel, 561 | style_ks, 562 | padding=style_ks // 2, 563 | ) 564 | 565 | if upsample: 566 | factor = 2 567 | p = (len(blur_kernel) - factor) - (kernel_size - 1) 568 | pad0 = (p + 1) // 2 + factor - 1 569 | pad1 = p // 2 + 1 570 | 571 | self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor) 572 | 573 | if downsample: 574 | factor = 2 575 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 576 | pad0 = (p + 1) // 2 577 | 
pad1 = p // 2 578 | 579 | self.blur = Blur(blur_kernel, pad=(pad0, pad1)) 580 | 581 | fan_in = in_channel * kernel_size ** 2 582 | self.scale = 1 / math.sqrt(fan_in) 583 | self.padding = kernel_size // 2 584 | 585 | self.weight = nn.Parameter( 586 | torch.randn(1, out_channel, in_channel, kernel_size, kernel_size) 587 | ) 588 | 589 | self.demodulate = demodulate 590 | 591 | def __repr__(self): 592 | return ( 593 | f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, ' 594 | f'upsample={self.upsample}, downsample={self.downsample})' 595 | ) 596 | 597 | def forward(self, input, style): 598 | batch, in_channel, height, width = input.shape 599 | 600 | input_mean = input.mean().item() 601 | input_std = input.std().item() 602 | 603 | sB, sC = style.shape 604 | style = style[..., None, None] 605 | style = style.expand(sB, sC, height, width) 606 | segm_tensor = self.segm_tensor.expand(sB, *self.segm_tensor.shape[1:]) 607 | style = torch.cat([style, segm_tensor], dim=1) 608 | 609 | style = self.style_conv(style) 610 | 611 | style_scale = 1. / math.sqrt(height * width) 612 | 613 | style_mean = style.mean(dim=(2, 3), keepdim=True) 614 | style_mean = style_mean.view(batch, 1, in_channel, 1, 1) 615 | 616 | weight = self.scale * self.weight 617 | weight = weight.expand(batch, *weight.shape[1:]) 618 | weight_demod = weight * style_mean 619 | 620 | if self.demodulate: 621 | demod = torch.rsqrt(weight_demod.pow(2).sum([2, 3, 4]) + 1e-8) 622 | weight = weight * demod.view(batch, self.out_channel, 1, 1, 1) 623 | 624 | weight = weight.reshape( 625 | batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size 626 | ) 627 | 628 | input = input * style 629 | 630 | if self.upsample: 631 | input = input.view(1, batch * in_channel, height, width) 632 | weight = weight.view( 633 | batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size 634 | ) 635 | weight = weight.transpose(1, 2).reshape( 636 | batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size 637 | ) 638 | out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch) 639 | _, _, height, width = out.shape 640 | out = out.view(batch, self.out_channel, height, width) # * self.scale * style_scale 641 | out = self.blur(out) 642 | 643 | elif self.downsample: 644 | input = self.blur(input) 645 | _, _, height, width = input.shape 646 | input = input.view(1, batch * in_channel, height, width) 647 | out = F.conv2d(input, weight, padding=0, stride=2, groups=batch) # * self.scale * style_scale 648 | _, _, height, width = out.shape 649 | out = out.view(batch, self.out_channel, height, width) 650 | 651 | else: 652 | input = input.view(1, batch * in_channel, height, width) 653 | out = F.conv2d(input, weight, padding=self.padding, groups=batch) # * self.scale * style_scale 654 | _, _, height, width = out.shape 655 | out = out.view(batch, self.out_channel, height, width) 656 | 657 | if self.norm_style: 658 | out = out * self.scale * style_scale 659 | 660 | out_mean = out.mean().item() 661 | out_std = out.std().item() 662 | # print() 663 | # print(input_std, input_mean) 664 | # print(out_std, out_mean) 665 | # if input_std != 0: 666 | # print('std shift:', out_std / input_std) 667 | # if input_mean != 0: 668 | # print('std shift norm:', (out_std / input_std) / input_mean) 669 | 670 | return out 671 | 672 | 673 | # ============ StyledConv 674 | 675 | class Sgv1StyledConv(nn.Module): 676 | def __init__( 677 | self, 678 | in_channel, 679 | out_channel, 680 | kernel_size, 681 | style_dim, 
682 | upsample=False, 683 | blur_kernel=[1, 3, 3, 1] 684 | ): 685 | super().__init__() 686 | 687 | self.conv = EqualConv2d( 688 | in_channel, 689 | out_channel, 690 | kernel_size, 691 | padding=kernel_size // 2, 692 | ) 693 | 694 | self.norm = AdaIn(in_channel, style_dim) 695 | 696 | self.noise = NoiseInjection() 697 | self.activate = FusedLeakyReLU(in_channel) 698 | 699 | if upsample: 700 | self.up = Upsample(blur_kernel) 701 | else: 702 | self.up = None 703 | 704 | def forward(self, input, style, noise=None): 705 | out = self.noise(input, noise=noise) 706 | out = self.activate(out) 707 | out = self.norm(out, style) 708 | if self.up is not None: 709 | out = self.up(out) 710 | out = self.conv(out) 711 | return out 712 | 713 | 714 | class AdaTexStyledConv(nn.Module): 715 | def __init__( 716 | self, 717 | in_channel, 718 | out_channel, 719 | kernel_size, 720 | style_dim, 721 | segm_tensor, 722 | upsample=False, 723 | blur_kernel=[1, 3, 3, 1], 724 | add_input=None, 725 | style_ks=1, 726 | ): 727 | super().__init__() 728 | 729 | self.add_input = add_input 730 | if add_input is not None: 731 | in_channel = in_channel + add_input.shape[1] 732 | self.add_input = nn.Parameter(add_input, requires_grad=False) 733 | 734 | self.conv = EqualConv2d( 735 | in_channel, 736 | out_channel, 737 | kernel_size, 738 | padding=kernel_size // 2, 739 | ) 740 | 741 | self.norm = AdaTexSpade(in_channel, segm_tensor, style_dim, kernel_size=style_ks) 742 | 743 | self.noise = NoiseInjection() 744 | self.activate = FusedLeakyReLU(in_channel) 745 | 746 | if upsample: 747 | self.up = Upsample(blur_kernel) 748 | else: 749 | self.up = None 750 | 751 | def forward(self, input, style, noise=None): 752 | if self.add_input is not None: 753 | B = input.shape[0] 754 | ainp = self.add_input.repeat(B, 1, 1, 1) 755 | input = torch.cat([input, ainp], dim=1) 756 | 757 | out = self.noise(input, noise=noise) 758 | out = self.activate(out) 759 | out = self.norm(out, style) 760 | if self.up is not None: 761 | out = self.up(out) 762 | out = self.conv(out) 763 | return out 764 | 765 | 766 | class StyledConvStyleTensor(nn.Module): 767 | def __init__( 768 | self, 769 | in_channel, 770 | out_channel, 771 | kernel_size, 772 | style_dim, 773 | segm_tensor, 774 | upsample=False, 775 | blur_kernel=[1, 3, 3, 1], 776 | demodulate=True, 777 | style_ks=1, 778 | norm_style=True 779 | ): 780 | super().__init__() 781 | 782 | self.conv = ModulatedConv2dStyleTensor( 783 | in_channel, 784 | out_channel, 785 | kernel_size, 786 | style_dim, 787 | segm_tensor, 788 | upsample=upsample, 789 | blur_kernel=blur_kernel, 790 | demodulate=demodulate, 791 | style_ks=style_ks, 792 | norm_style=norm_style 793 | ) 794 | 795 | self.noise = NoiseInjection() 796 | # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 797 | # self.activate = ScaledLeakyReLU(0.2) 798 | self.activate = FusedLeakyReLU(out_channel) 799 | 800 | def forward(self, input, style, noise=None): 801 | out = self.conv(input, style) 802 | out = self.noise(out, noise=noise) 803 | # out = out + self.bias 804 | out = self.activate(out) 805 | 806 | return out 807 | 808 | 809 | class StyledConv(nn.Module): 810 | def __init__( 811 | self, 812 | in_channel, 813 | out_channel, 814 | kernel_size, 815 | style_dim, 816 | upsample=False, 817 | blur_kernel=[1, 3, 3, 1], 818 | demodulate=True, 819 | ): 820 | super().__init__() 821 | 822 | self.conv = ModulatedConv2d( 823 | in_channel, 824 | out_channel, 825 | kernel_size, 826 | style_dim, 827 | upsample=upsample, 828 | blur_kernel=blur_kernel, 829 | 
demodulate=demodulate, 830 | ) 831 | 832 | self.noise = NoiseInjection() 833 | self.activate = FusedLeakyReLU(out_channel) 834 | 835 | def forward(self, input, style, noise=None): 836 | out = self.conv(input, style) 837 | out = self.noise(out, noise=noise) 838 | # out = out + self.bias 839 | out = self.activate(out) 840 | 841 | return out 842 | 843 | 844 | class StyledSiren(nn.Module): 845 | def __init__( 846 | self, 847 | in_channel, 848 | out_channel, 849 | style_dim, 850 | demodulate=True, 851 | is_first=False, 852 | omega_0=30. 853 | ): 854 | super().__init__() 855 | 856 | self.conv = ModulatedSiren2d( 857 | in_channel, 858 | out_channel, 859 | style_dim, 860 | demodulate=demodulate, 861 | is_first=is_first, 862 | omega_0=omega_0 863 | ) 864 | 865 | self.noise = NoiseInjection() 866 | 867 | def forward(self, input, style, noise=None): 868 | out = self.conv(input, style) 869 | out = self.noise(out, noise=noise) 870 | # out = out + self.bias 871 | out = torch.sin(out) 872 | 873 | return out 874 | 875 | 876 | class StyledConvAInp(nn.Module): 877 | def __init__( 878 | self, 879 | in_channel, 880 | out_channel, 881 | kernel_size, 882 | style_dim, 883 | upsample=False, 884 | blur_kernel=[1, 3, 3, 1], 885 | demodulate=True, 886 | add_input=None, 887 | ainp_trainable=False 888 | ): 889 | super().__init__() 890 | 891 | self.add_input = add_input 892 | if add_input is not None: 893 | in_channel = in_channel + add_input.shape[1] 894 | self.add_input = nn.Parameter(add_input, requires_grad=ainp_trainable) 895 | 896 | self.conv = ModulatedConv2d( 897 | in_channel, 898 | out_channel, 899 | kernel_size, 900 | style_dim, 901 | upsample=upsample, 902 | blur_kernel=blur_kernel, 903 | demodulate=demodulate, 904 | ) 905 | 906 | self.noise = NoiseInjection() 907 | self.activate = FusedLeakyReLU(out_channel) 908 | 909 | def forward(self, input, style, noise=None): 910 | if self.add_input is not None: 911 | B = input.shape[0] 912 | ainp = self.add_input.repeat(B, 1, 1, 1) 913 | input = torch.cat([input, ainp], dim=1) 914 | 915 | out = self.conv(input, style) 916 | out = self.noise(out, noise=noise) 917 | out = self.activate(out) 918 | 919 | return out 920 | 921 | 922 | class SirenResBlock(nn.Module): 923 | def __init__(self, in_channel, out_channel, style_dim, demodulate=True, is_first=False, omega_0=30.): 924 | super().__init__() 925 | 926 | self.conv1 = StyledSiren(in_channel, out_channel, style_dim, demodulate=demodulate, is_first=is_first, 927 | omega_0=omega_0) 928 | self.conv2 = StyledSiren(out_channel, out_channel, style_dim, demodulate=demodulate, omega_0=omega_0) 929 | 930 | self.skip = SineConv1x1(in_channel, out_channel, omega_0=omega_0) 931 | 932 | def forward(self, input, latent): 933 | out = self.conv1(input, latent) 934 | out = self.conv2(out, latent) 935 | skip = self.skip(input) 936 | out = (out + skip) / math.sqrt(2) 937 | 938 | return out 939 | 940 | 941 | class StyledConv1x1ResBlock(nn.Module): 942 | def __init__(self, in_channel, out_channel, style_dim, 943 | demodulate=True): 944 | super().__init__() 945 | 946 | self.conv1 = StyledConv(in_channel, out_channel, 1, style_dim, demodulate=demodulate) 947 | self.conv2 = StyledConv(out_channel, out_channel, 1, style_dim, demodulate=demodulate) 948 | 949 | self.skip = ConvLayer(in_channel, out_channel, 1, bias=False) 950 | 951 | def forward(self, input, latent, noise=None): 952 | if type(noise) == list and len(noise) == 2: 953 | noise1 = noise[0] 954 | noise2 = noise[1] 955 | else: 956 | noise1 = noise 957 | noise2 = noise 958 | 959 | if 
latent.ndim == 3: 960 | latent1 = latent[:, 0] 961 | latent2 = latent[:, 1] 962 | else: 963 | latent1 = latent 964 | latent2 = latent 965 | 966 | out = self.conv1(input, latent1, noise=noise1) 967 | out = self.conv2(out, latent2, noise=noise2) 968 | skip = self.skip(input) 969 | out = (out + skip) / math.sqrt(2) 970 | 971 | return out 972 | 973 | 974 | # ============ ToRGB 975 | class Sgv1ToRGB(nn.Module): 976 | def __init__(self, in_channel, style_dim, out_channel=3, upsample=True, blur_kernel=[1, 3, 3, 1]): 977 | super().__init__() 978 | 979 | if upsample: 980 | self.upsample = Upsample(blur_kernel) 981 | 982 | self.norm = AdaIn(in_channel, style_dim) 983 | self.activate = FusedLeakyReLU(in_channel) 984 | self.conv = EqualConv2d(in_channel, out_channel, 1) 985 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 986 | 987 | def forward(self, input, style, skip=None, return_delta=False): 988 | out = self.activate(input) 989 | out = self.norm(out, style) 990 | out = self.conv(out) 991 | out = out + self.bias 992 | 993 | if skip is not None: 994 | skip = self.upsample(skip) 995 | delta = out 996 | out = out + skip 997 | 998 | if return_delta: 999 | return out, delta 1000 | else: 1001 | return out 1002 | 1003 | 1004 | class AdaTexToRGB(nn.Module): 1005 | def __init__(self, in_channel, style_dim, segm_tensor, out_channel=3, upsample=True, blur_kernel=[1, 3, 3, 1], 1006 | style_ks=1): 1007 | super().__init__() 1008 | 1009 | if upsample: 1010 | self.upsample = Upsample(blur_kernel) 1011 | 1012 | self.norm = AdaTexSpade(in_channel, segm_tensor, style_dim, kernel_size=style_ks) 1013 | self.activate = FusedLeakyReLU(in_channel) 1014 | self.conv = EqualConv2d(in_channel, out_channel, 1) 1015 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 1016 | 1017 | def forward(self, input, style, skip=None, return_delta=False): 1018 | out = self.activate(input) 1019 | out = self.norm(out, style) 1020 | out = self.conv(out) 1021 | out = out + self.bias 1022 | 1023 | if skip is not None: 1024 | skip = self.upsample(skip) 1025 | delta = out 1026 | out = out + skip 1027 | 1028 | if return_delta: 1029 | return out, delta 1030 | else: 1031 | return out 1032 | 1033 | 1034 | class ToRGBStyleTensor(nn.Module): 1035 | def __init__(self, 1036 | in_channel, 1037 | style_dim, 1038 | segm_tensor, 1039 | out_channel=3, 1040 | upsample=True, 1041 | blur_kernel=[1, 3, 3, 1], 1042 | style_ks=1, 1043 | norm_style=True): 1044 | super().__init__() 1045 | 1046 | if upsample: 1047 | self.upsample = Upsample(blur_kernel) 1048 | 1049 | self.conv = ModulatedConv2dStyleTensor(in_channel, out_channel, 1, style_dim, segm_tensor, demodulate=False, 1050 | style_ks=style_ks, norm_style=norm_style) 1051 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 1052 | 1053 | def forward(self, input, style, skip=None, return_delta=False): 1054 | out = self.conv(input, style) 1055 | out = out + self.bias 1056 | 1057 | if skip is not None: 1058 | skip = self.upsample(skip) 1059 | delta = out 1060 | out = out + skip 1061 | 1062 | if return_delta: 1063 | return out, delta 1064 | else: 1065 | return out 1066 | 1067 | 1068 | class ToRGB(nn.Module): 1069 | def __init__(self, in_channel, style_dim, out_channel=3, upsample=True, blur_kernel=[1, 3, 3, 1]): 1070 | super().__init__() 1071 | 1072 | if upsample: 1073 | self.upsample = Upsample(blur_kernel) 1074 | 1075 | self.conv = ModulatedConv2d(in_channel, out_channel, 1, style_dim, demodulate=False) 1076 | self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1)) 1077 | 1078 | 
def forward(self, input, style, skip=None, return_delta=False, alfa=1.): 1079 | out = self.conv(input, style) 1080 | out = out + self.bias 1081 | 1082 | if skip is not None: 1083 | skip = self.upsample(skip) 1084 | delta = out 1085 | out = alfa * out + skip 1086 | 1087 | if return_delta: 1088 | return out, delta 1089 | else: 1090 | return out 1091 | 1092 | 1093 | # ============ Other 1094 | 1095 | 1096 | class SineConv1x1(nn.Module): 1097 | def __init__( 1098 | self, 1099 | in_channel, 1100 | out_channel, 1101 | is_first=False, 1102 | omega_0=30, 1103 | ): 1104 | super().__init__() 1105 | self.omega_0 = omega_0 1106 | self.is_first = is_first 1107 | 1108 | self.in_channel = in_channel 1109 | self.weight = nn.Parameter( 1110 | torch.randn(out_channel, in_channel, 1, 1) 1111 | ) 1112 | 1113 | if is_first: 1114 | a = -1 / self.in_channel 1115 | else: 1116 | a = -np.sqrt(6 / self.in_channel) / self.omega_0 1117 | b = -a 1118 | 1119 | self.scale = np.abs(b - a) * math.sqrt(1 / 12) 1120 | 1121 | spectral_tex = torch.load('/Vol1/dbstore/datasets/a.grigorev/gent/smplx_spectral_texture_norm.pth').cuda() 1122 | spectral_tex_inp = spectral_tex[:, 2:100] 1123 | self.spectral_tex_mask = ((spectral_tex ** 2).sum(dim=1) > 0)[0] 1124 | self.sin_input = None 1125 | 1126 | def forward(self, x): 1127 | weight = self.weight * self.scale 1128 | out = F.conv2d(x, weight) 1129 | mask = self.spectral_tex_mask 1130 | out = torch.sin(self.omega_0 * out) 1131 | return out 1132 | 1133 | 1134 | class ConvLayer(nn.Sequential): 1135 | def __init__( 1136 | self, 1137 | in_channel, 1138 | out_channel, 1139 | kernel_size, 1140 | downsample=False, 1141 | blur_kernel=[1, 3, 3, 1], 1142 | bias=True, 1143 | activate=True, 1144 | replicate_pad=False 1145 | ): 1146 | layers = [] 1147 | 1148 | if downsample: 1149 | factor = 2 1150 | p = (len(blur_kernel) - factor) + (kernel_size - 1) 1151 | pad0 = (p + 1) // 2 1152 | pad1 = p // 2 1153 | 1154 | layers.append(Blur(blur_kernel, pad=(pad0, pad1))) 1155 | 1156 | stride = 2 1157 | self.padding = 0 1158 | 1159 | else: 1160 | stride = 1 1161 | self.padding = kernel_size // 2 1162 | 1163 | layers.append( 1164 | EqualConv2d( 1165 | in_channel, 1166 | out_channel, 1167 | kernel_size, 1168 | padding=self.padding, 1169 | stride=stride, 1170 | bias=bias and not activate, 1171 | replicate_pad=replicate_pad 1172 | ) 1173 | ) 1174 | 1175 | if activate: 1176 | if bias: 1177 | layers.append(FusedLeakyReLU(out_channel)) 1178 | 1179 | else: 1180 | layers.append(ScaledLeakyReLU(0.2)) 1181 | 1182 | super().__init__(*layers) 1183 | 1184 | 1185 | class ResBlock(nn.Module): 1186 | def __init__(self, in_channel, out_channel, blur_kernel=None, replicate_pad=False, ): 1187 | super().__init__() 1188 | 1189 | self.conv1 = ConvLayer(in_channel, in_channel, 3, replicate_pad=replicate_pad) 1190 | self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True, replicate_pad=replicate_pad) 1191 | 1192 | self.skip = ConvLayer( 1193 | in_channel, out_channel, 1, downsample=True, activate=False, bias=False, replicate_pad=replicate_pad 1194 | ) 1195 | 1196 | def forward(self, input): 1197 | out = self.conv1(input) 1198 | out = self.conv2(out) 1199 | 1200 | skip = self.skip(input) 1201 | out = (out + skip) / math.sqrt(2) 1202 | 1203 | return out 1204 | 1205 | # ==================== Old 1206 | --------------------------------------------------------------------------------
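A minimal usage sketch for the StyleGAN2 building blocks defined in models/styleganv2/modules.py above (illustrative only, not a file of the repository). It assumes the CUDA extensions under models/styleganv2/op compile on import and that a CUDA device is available; the channel counts, style_dim=512 and the 16x16 input resolution are arbitrary example values.

import torch

from models.styleganv2.modules import StyledConv, ToRGB

device = 'cuda:0'
style_dim = 512

x = torch.randn(2, 128, 16, 16, device=device)   # input feature map
w = torch.randn(2, style_dim, device=device)     # one style vector per sample

# StyledConv: modulated conv (optionally with 2x upsampling) + noise injection + fused leaky ReLU
conv = StyledConv(128, 64, 3, style_dim, upsample=True).to(device)
# ToRGB: 1x1 modulated conv without demodulation
to_rgb = ToRGB(64, style_dim).to(device)

y = conv(x, w)       # expected: torch.Size([2, 64, 32, 32])
rgb = to_rgb(y, w)   # expected: torch.Size([2, 3, 32, 32])
print(y.shape, rgb.shape)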