├── README.md
├── code
│   ├── config
│   │   ├── __init__.py
│   │   └── defaults.py
│   ├── configs
│   │   └── configs.yml
│   ├── modeling
│   │   ├── UNet
│   │   │   ├── __init__.py
│   │   │   ├── unet_model.py
│   │   │   └── unet_parts.py
│   │   ├── __init__.py
│   │   ├── model.py
│   │   └── tree_render.py
│   ├── models_lpf
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── densenet.py
│   │   ├── downsample.py
│   │   ├── mobilenet.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── render.py
│   └── utils
│       ├── __init__.py
│       ├── bone_parsing.py
│       ├── build_octree.py
│       ├── llinear_transform.py
│       ├── logger.py
│       ├── ray_sampling.py
│       ├── rendering.py
│       ├── spherical_harmonics.py
│       └── utils.py
├── medias
│   ├── featured1.png
│   └── overview_v3-1.png
└── requirement.txt

/README.md:
--------------------------------------------------------------------------------
# Artemis: Articulated Neural Pets with Appearance and Motion Synthesis

SIGGRAPH, 2022

Haimin Luo · Teng Xu · Yuheng Jiang · Chenglin Zhou · Qiwei Qiu · Yingliang Zhang · Wei Yang · Lan Xu · Jingyi Yu

Paper PDF · Project Page · Youtube Video
This repository contains a PyTorch implementation and the **D**ynamic **F**urry **A**nimal (DFA) dataset for the paper: [Artemis: Articulated Neural Pets with Appearance and Motion Synthesis](https://arxiv.org/abs/2202.05628). In this paper, we present **ARTEMIS**, a novel neural modeling and rendering pipeline for generating **ART**iculated neural pets with app**E**arance and **M**otion synthes**IS**.

## Overview
![Artemis_overview](medias/overview_v3-1.png "Magic Gardens")

## Installation
Create a virtual environment and install the requirements as follows:
```
conda create -n artemis python=3.7
conda activate artemis
pip install -r requirement.txt
```
- Install pytorch3d following the official [installation steps](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md).
- Install the CUDA-based PyTorch extension for animatable volumes: [svox_t](https://github.com/HaiminLuo/svox_t).

The code is tested on Ubuntu 18.04 + PyTorch 1.12.1.

## Dynamic Furry Animals (DFA) Dataset
The DFA dataset can be found at [DFA datasets](https://shanghaitecheducn-my.sharepoint.com/:f:/g/personal/luohm_shanghaitech_edu_cn/Et60lJpJdp5DoyQF7uzP6jgB_JEW4LIHixAyXEiVhHT3Vw?e=d09jtz). It contains multi-view renderings and skeletal motions of 9 high-quality CGI furry animals. As described in Artemis, each dataset contains $1920\times1080$ RGBA images rendered from 36 cameras around the dynamic animal.

The datasets (e.g., cat) are organized as follows:

```
cat
├── img
│   └── run                        - Motion name.
│       └── %d                     - The frame number, starting from 0.
│           ├── img_%04d.jpg       - RGB image for each view; the view number starts from 0.
│           ├── img_%04d_alpha.png - Alpha matte for the corresponding RGB image.
│           └── ...
│
├── volumes
│   ├── coords_init.pth        - Voxel coordinates representing the animal in rest pose.
│   ├── volume_indices.pth     - The indices of the bones to which the voxels are bound.
│   ├── volume_weights.pth     - The skinning weights of the voxels.
│   └── radius.txt             - The radius of the volume.
│
├── bones
│   ├── run                    - Motion name.
│   │   ├── Bones_%04d.inf     - The skeletal pose for each frame, starting from 0. In each row, the 3x4 [R T] matrix is stored column-wise, with the third column followed by columns 1, 2, and 4.
│   │   └── ...
│   ├── bones_parents.npy      - The parent bone index of each bone.
│   └── Bones_0000.inf         - The rest pose.
│
├── CamPose.inf                - Camera extrinsics. In each row, the 3x4 [R T] matrix is stored column-wise, with the third column followed by columns 1, 2, and 4, where R*X^{camera}+T=X^{world}.
│
├── Intrinsic.inf              - Camera intrinsics. Each intrinsic is formatted as: "idx \n fx 0 cx \n 0 fy cy \n 0 0 1 \n \n" (idx starts from 0).
│
└── sequences                  - Motion sequences. Each motion is formatted as: "motion_name frames\n".
```

## Pre-trained model
Our pre-trained NGI models (e.g., `cat`, `wolf`, `panda`, ...) described in our paper can be found at [pre-trained models](https://shanghaitecheducn-my.sharepoint.com/:f:/g/personal/luohm_shanghaitech_edu_cn/EgJC3IrKA1pPjYTFiO2rVswBW4QVHkV54TqGxcxMIr-Bvw?e=H3xhTG).

## Rendering
We provide a rendering script to illustrate how to load motions, cameras, and models, and how to animate the NGI animals and render them.
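For reference, the minimal sketch below shows one way to parse `CamPose.inf` and `Intrinsic.inf` according to the format described above; the helper names are illustrative and not part of this repository, and `Bones_%04d.inf` stores its 3x4 matrices with the same column ordering.
```
import numpy as np

def load_campose(path):
    # One camera per row: 12 values forming a 3x4 [R T] matrix stored column-wise
    # in the order (column 3, column 1, column 2, column 4), with R*X_camera + T = X_world.
    poses = []
    for line in open(path):
        vals = [float(v) for v in line.split()]
        if len(vals) != 12:
            continue
        cols = np.array(vals).reshape(4, 3).T   # 3x4, columns still in stored order
        RT = np.zeros((3, 4))
        RT[:, [2, 0, 1, 3]] = cols              # undo the (3, 1, 2, 4) column ordering
        poses.append(RT)
    return np.stack(poses)                      # (num_cameras, 3, 4)

def load_intrinsics(path):
    # Blocks of "idx" followed by a row-major 3x3 K matrix: fx 0 cx / 0 fy cy / 0 0 1.
    tokens = open(path).read().split()
    Ks = [np.array(tokens[i + 1:i + 10], dtype=np.float64).reshape(3, 3)
          for i in range(0, len(tokens) - 9, 10)]
    return np.stack(Ks)                         # (num_cameras, 3, 3)
```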
To render an RGBA image with a specified camera view and motion pose:
```
cd code

python render.py --config ../model/cat/configs.yml --model ../model/cat/model.pt --dataset ../dataset/cat --output_path ./output
```

To render an around-view video with dynamic motion and a spiral camera trajectory:
```
cd code

python render.py --config ../model/cat/configs.yml --model ../model/cat/model.pt --dataset ../dataset/cat --output_path ./output --render_video --camera_path ../model/cat/
```

## Citation
If you find our code or paper helpful, please consider citing:
```
@article{10.1145/3528223.3530086,
author = {Luo, Haimin and Xu, Teng and Jiang, Yuheng and Zhou, Chenglin and Qiu, Qiwei and Zhang, Yingliang and Yang, Wei and Xu, Lan and Yu, Jingyi},
title = {Artemis: Articulated Neural Pets with Appearance and Motion Synthesis},
year = {2022},
issue_date = {July 2022},
publisher = {Association for Computing Machinery},
address = {New York, NY, USA},
volume = {41},
number = {4},
issn = {0730-0301},
url = {https://doi.org/10.1145/3528223.3530086},
doi = {10.1145/3528223.3530086},
journal = {ACM Trans. Graph.},
month = {jul},
articleno = {164},
numpages = {19},
keywords = {novel view syntheis, neural rendering, dynamic scene modeling, neural volumetric animal, motion synthesis, neural representation}
}
```
Also consider citing another related and interesting work on high-quality photo-realistic rendering of real fuzzy objects: [Convolutional Neural Opacity Radiance Fields](https://www.computer.org/csdl/proceedings-article/iccp/2021/09466273/1uSSXDOinlu):
```
@INPROCEEDINGS {9466273,
author = {H. Luo and A. Chen and Q. Zhang and B. Pang and M. Wu and L. Xu and J. Yu},
booktitle = {2021 IEEE International Conference on Computational Photography (ICCP)},
title = {Convolutional Neural Opacity Radiance Fields},
year = {2021},
volume = {},
issn = {},
pages = {1-12},
keywords = {training;photography;telepresence;image color analysis;computational modeling;entertainment industry;image capture},
doi = {10.1109/ICCP51581.2021.9466273},
url = {https://doi.ieeecomputersociety.org/10.1109/ICCP51581.2021.9466273},
publisher = {IEEE Computer Society},
address = {Los Alamitos, CA, USA},
month = {may}
}
```
--------------------------------------------------------------------------------
/code/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | from .defaults import _C as cfg
8 | 
--------------------------------------------------------------------------------
/code/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 | 
3 | # -----------------------------------------------------------------------------
4 | # Convention about Training / Test specific parameters
5 | # -----------------------------------------------------------------------------
6 | # Whenever an argument can be either used for training or for testing, the
7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter,
8 | # or _TEST for a test-specific parameter.
9 | # For example, the number of images during training will be 10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be 11 | # IMAGES_PER_BATCH_TEST 12 | 13 | # ----------------------------------------------------------------------------- 14 | # Config definition 15 | # ----------------------------------------------------------------------------- 16 | 17 | _C = CN() 18 | 19 | _C.MODEL = CN() 20 | _C.MODEL.DEVICE = "cuda" 21 | _C.MODEL.COARSE_RAY_SAMPLING = 64 22 | _C.MODEL.FINE_RAY_SAMPLING = 80 23 | _C.MODEL.INPORTANCE_RAY_SAMPLE = 0 24 | _C.MODEL.SAMPLE_METHOD = "NEAR_FAR" 25 | _C.MODEL.BOARDER_WEIGHT = 1e10 26 | _C.MODEL.SAME_SPACENET = False 27 | _C.MODEL.DEPTH_FIELD = 1e-1 28 | _C.MODEL.DEPTH_RATIO = 1e-1 29 | _C.MODEL.SAMPLE_INF = False 30 | _C.MODEL.USE_MOTION = False 31 | 32 | _C.MODEL.UNET_DIR = False 33 | _C.MODEL.USE_SH = True 34 | _C.MODEL.SH_DIM = 9 35 | _C.MODEL.SH_FEAT_DIM = 3 36 | _C.MODEL.BONE_FEAT_DIM = 0 37 | _C.MODEL.UNET_FEAT = False 38 | _C.MODEL.USE_RENDER_NET = True 39 | _C.MODEL.USE_LIGHT_RENDERER = False 40 | _C.MODEL.USE_BONE_NET = False 41 | _C.MODEL.USE_WARP_NET = False 42 | 43 | _C.MODEL.SAMPLE_NET = False 44 | _C.MODEL.OPACITY_NET = False 45 | _C.MODEL.NOISE_STD = 0.0 46 | 47 | _C.MODEL.TKERNEL_INC_RAW = True 48 | _C.MODEL.ENCODE_POS_DIM = 10 49 | _C.MODEL.ENCODE_DIR_DIM = 4 50 | _C.MODEL.GAUSSIAN_SIGMA = 10. 51 | _C.MODEL.KERNEL_TYPE = "POS" 52 | _C.MODEL.PAIR_TRAINING = False 53 | _C.MODEL.TREE_DEPTH = 8 54 | _C.MODEL.RANDOM_INI = True 55 | 56 | # ----------------------------------------------------------------------------- 57 | # INPUT 58 | # ----------------------------------------------------------------------------- 59 | _C.INPUT = CN() 60 | # Size of the image during training 61 | _C.INPUT.SIZE_TRAIN = [400, 250] 62 | # Size of the image during test 63 | _C.INPUT.SIZE_TEST = [400, 250] 64 | # Minimum scale for the image during training 65 | _C.INPUT.MIN_SCALE_TRAIN = 0.5 66 | # Maximum scale for the image during test 67 | _C.INPUT.MAX_SCALE_TRAIN = 1.2 68 | # Random probability for image horizontal flip 69 | _C.INPUT.PROB = 0.5 70 | # Values to be used for image normalization 71 | _C.INPUT.PIXEL_MEAN = [0.1307, ] 72 | # Values to be used for image normalization 73 | _C.INPUT.PIXEL_STD = [0.3081, ] 74 | 75 | # ----------------------------------------------------------------------------- 76 | # Dataset 77 | # ----------------------------------------------------------------------------- 78 | _C.DATASETS = CN() 79 | # List of the dataset names for training, as present in paths_catalog.py 80 | _C.DATASETS.TRAIN = "" 81 | # List of the dataset names for testing, as present in paths_catalog.py 82 | _C.DATASETS.TEST = "" 83 | _C.DATASETS.SHIFT = 0.0 84 | _C.DATASETS.MAXRATION = 0.0 85 | _C.DATASETS.ROTATION = 0.0 86 | _C.DATASETS.USE_MASK = False 87 | _C.DATASETS.USE_DEPTH = False 88 | _C.DATASETS.USE_ALPHA = False 89 | _C.DATASETS.USE_BG = False 90 | _C.DATASETS.NUM_FRAME = 1 91 | _C.DATASETS.NUM_CAMERA = 1000 92 | _C.DATASETS.TYPE = "NR" 93 | _C.DATASETS.SYNTHESIS = True 94 | _C.DATASETS.NO_BOUNDARY = False 95 | _C.DATASETS.BOUNDARY_WIDTH = 3 96 | _C.DATASETS.PATCH_SIZE = 16 97 | _C.DATASETS.KEEP_BG = False 98 | _C.DATASETS.PAIR_SAMPLE = False 99 | 100 | # ----------------------------------------------------------------------------- 101 | # DataLoader 102 | # ----------------------------------------------------------------------------- 103 | _C.DATALOADER = CN() 104 | # Number of data loading threads 105 | _C.DATALOADER.NUM_WORKERS = 8 
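# Usage sketch (an assumed, typical yacs workflow; the YAML path below is illustrative):
# the defaults in this file are exported as `cfg` in config/__init__.py and then
# overridden by an experiment config such as configs/configs.yml, e.g.
#
#   from config import cfg
#   cfg.merge_from_file("configs/configs.yml")
#   cfg.freeze()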
106 | 107 | # ---------------------------------------------------------------------------- # 108 | # Solver 109 | # ---------------------------------------------------------------------------- # 110 | _C.SOLVER = CN() 111 | _C.SOLVER.OPTIMIZER_NAME = "SGD" 112 | 113 | _C.SOLVER.MAX_EPOCHS = 50 114 | 115 | _C.SOLVER.BASE_LR = 0.001 116 | _C.SOLVER.BIAS_LR_FACTOR = 2 117 | 118 | _C.SOLVER.MOMENTUM = 0.9 119 | 120 | _C.SOLVER.WEIGHT_DECAY = 0.0005 121 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0 122 | 123 | _C.SOLVER.GAMMA = 0.1 124 | _C.SOLVER.STEPS = (30000,) 125 | 126 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3 127 | _C.SOLVER.WARMUP_ITERS = 500 128 | _C.SOLVER.WARMUP_METHOD = "linear" 129 | 130 | _C.SOLVER.LOSS_FN = "L1" 131 | 132 | _C.SOLVER.CHECKPOINT_PERIOD = 10 133 | _C.SOLVER.LOG_PERIOD = 100 134 | _C.SOLVER.BUNCH = 4096 135 | _C.SOLVER.BATCH_SIZE = 64 136 | _C.SOLVER.START_ITERS = 50 137 | _C.SOLVER.END_ITERS = 200 138 | _C.SOLVER.LR_SCALE = 0.1 139 | _C.SOLVER.COARSE_STAGE = 10 140 | 141 | # used for update 3d geometry 142 | _C.SOLVER.START_EPOCHS = 10 143 | _C.SOLVER.EPOCH_STEP = 1 144 | _C.SOLVER.CUBE_RESOLUTION = 512 145 | _C.SOLVER.CURVE_BUNCH = 10000 146 | _C.SOLVER.THRESHOLD = 5. 147 | _C.SOLVER.CURVE_SAMPLE_NUM = 64 148 | _C.SOLVER.CURVE_SAMPLE_SCALE = 1 149 | _C.SOLVER.UPDATE_GEOMETRY = False 150 | _C.SOLVER.UPDATE_RANGE = False 151 | _C.SOLVER.USE_BOARDER_COLOR = False 152 | _C.SOLVER.USE_AMP = False 153 | _C.SOLVER.SEED = 2021 154 | 155 | 156 | # Number of images per batch 157 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will 158 | # see 2 images per batch 159 | _C.SOLVER.IMS_PER_BATCH = 16 160 | 161 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will 162 | # see 2 images per batch 163 | _C.TEST = CN() 164 | _C.TEST.IMS_PER_BATCH = 8 165 | _C.TEST.WEIGHT = "" 166 | 167 | # ---------------------------------------------------------------------------- # 168 | # Misc options 169 | # ---------------------------------------------------------------------------- # 170 | _C.OUTPUT_DIR = "" 171 | -------------------------------------------------------------------------------- /code/configs/configs.yml: -------------------------------------------------------------------------------- 1 | SOLVER: 2 | OPTIMIZER_NAME: "Adam" 3 | BASE_LR: 0.0012 # 12 4 | WEIGHT_DECAY: 0. 
# 0.0000001 5 | IMS_PER_BATCH: 1 6 | START_ITERS: 2000 7 | END_ITERS: 50000 8 | LR_SCALE: 0.08 9 | WARMUP_ITERS: 1000 10 | 11 | LOSS_FN: "L1" # L1 or L2 12 | 13 | MAX_EPOCHS: 200 14 | CHECKPOINT_PERIOD: 5000 15 | LOG_PERIOD: 30 16 | BUNCH: 60000 17 | BATCH_SIZE: 8 # 12 18 | COARSE_STAGE: 0 19 | 20 | USE_AMP: True 21 | SEED: 2021 22 | 23 | INPUT: 24 | SIZE_TRAIN: [960,540] 25 | SIZE_TEST: [960,540] 26 | 27 | DATASETS: 28 | TYPE: "NR" # "NR" "LLFF" 29 | SYNTHESIS: True 30 | 31 | TRAIN: "wolf" 32 | TEST: "wolf" 33 | 34 | NUM_FRAME: 48 35 | NUM_CAMERA: 24 # 24 36 | SHIFT: 0.0 37 | MAXRATION: 0.0 38 | ROTATION: 0.0 39 | 40 | USE_MASK: False 41 | USE_DEPTH: True 42 | USE_ALPHA: True 43 | 44 | PAIR_SAMPLE: False 45 | 46 | DATALOADER: 47 | NUM_WORKERS: 20 48 | 49 | MODEL: 50 | USE_MOTION: True 51 | USE_RENDER_NET: True 52 | USE_LIGHT_RENDERER: False 53 | 54 | USE_SH: True 55 | SH_DIM: 9 56 | SH_FEAT_DIM: 10 57 | BONE_FEAT_DIM: 10 58 | TREE_DEPTH: 8 59 | RANDOM_INI: True 60 | 61 | PAIR_TRAINING: False 62 | 63 | TEST: 64 | IMS_PER_BATCH: 1 65 | 66 | OUTPUT_DIR: "" 67 | 68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /code/modeling/UNet/__init__.py: -------------------------------------------------------------------------------- 1 | from .unet_model import UNet, LightUNet 2 | -------------------------------------------------------------------------------- /code/modeling/UNet/unet_model.py: -------------------------------------------------------------------------------- 1 | # full assembly of the sub-parts to form the complete net 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from .unet_parts import * 6 | 7 | class LightUNet(nn.Module): 8 | def __init__(self, rgb_feat_channels, alpha_feat_channels, n_classes1, n_classes2): 9 | super(LightUNet, self).__init__() 10 | self.inc = inconv(rgb_feat_channels, 24) 11 | self.down1 = down(24, 48) 12 | self.down2 = down(48, 192) 13 | self.down3 = down(192, 192) 14 | self.up1 = up(192+192, 96) 15 | self.up2 = up(96+48, 48) 16 | self.up3 = up(48+24, 30) 17 | self.outc = outconv(30, n_classes1 + n_classes2) 18 | 19 | # self.inc2 = inconv(alpha_feat_channels + n_classes1, 16) 20 | # self.down5 = down(16, 32) 21 | # self.down6 = down(32, 64) 22 | # self.up5 = up(64 + 48 + 32, 32) 23 | # self.up6 = up(32 + 24 + 16, 8) 24 | # self.outc2 = outconv(8, n_classes2) 25 | 26 | def forward(self, rgb_feature, alpha_feature, motion_feature=None): 27 | # if motion_feature is not None: 28 | # rgb_feature = torch.cat([rgb_feature, motion_feature], dim=1) 29 | 30 | x1 = self.inc(rgb_feature) # 24 31 | x2 = self.down1(x1) # 48 32 | x3 = self.down2(x2) # 192 33 | x4 = self.down3(x3) # 192 34 | 35 | x5 = self.up1(x4, x3) # 96 36 | x5 = self.up2(x5, x2) # 48 37 | x5 = self.up3(x5, x1) # 30 38 | x_rgb = self.outc(x5) # n_classes1 39 | 40 | # x = torch.cat([alpha_feature, x_rgb], dim=1) 41 | # x1_2 = self.inc2(x) 42 | # x2_2 = self.down5(x1_2) 43 | # x3_2 = self.down6(x2_2) 44 | # 45 | # x6 = self.up5(x3_2, torch.cat([x2, x2_2], dim=1)) 46 | # x6 = self.up6(x6, torch.cat([x1, x1_2], dim=1)) 47 | # x_alpha = self.outc2(x6) 48 | # 49 | # x_rgba = torch.cat([x_rgb, x_alpha], dim=1) 50 | 51 | x_rgba = x_rgb 52 | return x_rgba 53 | 54 | 55 | class UNet(nn.Module): 56 | def __init__(self, rgb_feat_channels, alpha_feat_channels, n_classes1, n_classes2): 57 | super(UNet, self).__init__() 58 | self.inc = inconv(rgb_feat_channels, 24) 59 | self.down1 = down(24, 48) 60 | self.down2 = down(48, 96) 61 | self.down3 = down(96, 384) 62 | 
self.down4 = down(384, 384) 63 | self.up1 = up(768, 192) 64 | self.up2 = up(288, 48) 65 | self.up3 = up(96, 48) 66 | self.up4 = up(72, 30) 67 | self.outc = outconv(30, n_classes1) 68 | 69 | self.inc2 = inconv(alpha_feat_channels + n_classes1, 16) 70 | self.down5 = down(16, 32) 71 | self.down6 = down(32, 64) 72 | self.up5 = up(144, 32) 73 | self.up6 = up(72, 8) 74 | self.outc2 = outconv(8, n_classes2) 75 | 76 | def forward(self, rgb_feature, alpha_feature): 77 | x1 = self.inc(rgb_feature) 78 | x2 = self.down1(x1) 79 | x3 = self.down2(x2) 80 | x4 = self.down3(x3) 81 | x5 = self.down4(x4) 82 | 83 | x6 = self.up1(x5, x4) 84 | x6 = self.up2(x6, x3) 85 | x6 = self.up3(x6, x2) 86 | x6 = self.up4(x6, x1) 87 | x_rgb = self.outc(x6) 88 | 89 | x = torch.cat([alpha_feature, x_rgb], dim=1) 90 | x1_2 = self.inc2(x) 91 | x2_2 = self.down5(x1_2) 92 | x3_2 = self.down6(x2_2) 93 | 94 | x6 = self.up5(x3_2, torch.cat([x2, x2_2], dim=1)) 95 | x6 = self.up6(x6, torch.cat([x1, x1_2], dim=1)) 96 | x_alpha = self.outc2(x6) 97 | 98 | x = torch.cat([x_rgb, x_alpha], dim=1) 99 | return x 100 | -------------------------------------------------------------------------------- /code/modeling/UNet/unet_parts.py: -------------------------------------------------------------------------------- 1 | # sub-parts of the U-Net model 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import models_lpf 7 | 8 | 9 | class gated_conv(nn.Module): 10 | def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0): 11 | super(gated_conv, self).__init__() 12 | self.conv2 = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding) 13 | self.conv2_gate = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding) 14 | 15 | def forward(self, x): 16 | feat = self.conv2(x) 17 | mask = self.conv2_gate(x) 18 | 19 | return torch.sigmoid(mask) * feat 20 | 21 | 22 | class double_conv(nn.Module): 23 | '''(conv => BN => ReLU) * 2''' 24 | 25 | def __init__(self, in_ch, out_ch): 26 | super(double_conv, self).__init__() 27 | self.conv = nn.Sequential( 28 | gated_conv(in_ch, out_ch, 3, padding=1), 29 | nn.BatchNorm2d(out_ch), 30 | nn.ReLU(inplace=True), 31 | gated_conv(out_ch, out_ch, 3, padding=1), 32 | nn.BatchNorm2d(out_ch), 33 | nn.ReLU(inplace=True) 34 | ) 35 | 36 | def forward(self, x): 37 | x = self.conv(x) 38 | return x 39 | 40 | 41 | class inconv(nn.Module): 42 | def __init__(self, in_ch, out_ch): 43 | super(inconv, self).__init__() 44 | self.conv = double_conv(in_ch, out_ch) 45 | 46 | def forward(self, x): 47 | x = self.conv(x) 48 | return x 49 | 50 | 51 | class down(nn.Module): 52 | def __init__(self, in_ch, out_ch): 53 | super(down, self).__init__() 54 | self.mpconv = nn.Sequential( 55 | nn.MaxPool2d(2, stride=1), 56 | models_lpf.Downsample(channels=in_ch, filt_size=3, stride=2), 57 | double_conv(in_ch, out_ch) 58 | ) 59 | 60 | def forward(self, x): 61 | x = self.mpconv(x) 62 | return x 63 | 64 | 65 | class up(nn.Module): 66 | def __init__(self, in_ch, out_ch, bilinear=True): 67 | super(up, self).__init__() 68 | 69 | # would be a nice idea if the upsampling could be learned too, 70 | # but my machine do not have enough memory to handle all those weights 71 | if bilinear: 72 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 73 | else: 74 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2) 75 | 76 | self.conv = double_conv(in_ch, out_ch) 77 | 78 | def forward(self, x1, x2): 79 | x1 = self.up(x1) 80 | 81 | # input is CHW 82 | diffY = x2.size()[2] - x1.size()[2] 83 | 
diffX = x2.size()[3] - x1.size()[3] 84 | 85 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2, 86 | diffY // 2, diffY - diffY // 2)) 87 | 88 | # for padding issues, see 89 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a 90 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd 91 | 92 | x = torch.cat([x2, x1], dim=1) 93 | x = self.conv(x) 94 | return x 95 | 96 | 97 | class outconv(nn.Module): 98 | def __init__(self, in_ch, out_ch): 99 | super(outconv, self).__init__() 100 | self.conv = nn.Conv2d(in_ch, out_ch, 1) 101 | self.conv2 = nn.Conv2d(in_ch, out_ch, 3, padding=1) 102 | 103 | def forward(self, x): 104 | x1 = self.conv(x) 105 | x2 = self.conv2(x) 106 | return x1 + x2 107 | -------------------------------------------------------------------------------- /code/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import GeneralModel, NLayerDiscriminator 2 | from .tree_render import TreeRenderer 3 | import torch 4 | import os 5 | 6 | def build_model(cfg, flut_size=0, bone_feature_dim=0): 7 | if not cfg.MODEL.USE_MOTION: 8 | bone_feature_dim = 0 9 | return GeneralModel(cfg, flut_size=flut_size, use_render_net=cfg.MODEL.USE_RENDER_NET, bone_feature_dim=bone_feature_dim, texture_feature_dim=cfg.MODEL.SH_FEAT_DIM) 10 | 11 | 12 | def build_discriminator(cfg): 13 | return NLayerDiscriminator(input_nc=input_nc, ndf=32, n_layers=5, norm_layer='spectral') -------------------------------------------------------------------------------- /code/modeling/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .UNet import UNet, LightUNet 4 | from .tree_render import TreeRenderer 5 | 6 | import functools 7 | import torch.nn.utils.spectral_norm as spectral_norm 8 | from torch.cuda.amp import autocast 9 | from utils import generate_transformation_matrices 10 | 11 | import time 12 | 13 | class GeneralModel(nn.Module): 14 | def __init__(self, cfg, flut_size=0, use_render_net=True, bone_feature_dim=0, texture_feature_dim=0, sh_dim=9): 15 | super(GeneralModel, self).__init__() 16 | self.use_render_net = use_render_net 17 | self.bone_feature_dim = bone_feature_dim 18 | 19 | if not cfg.MODEL.USE_MOTION: 20 | self.bone_feature_dim = 0 21 | 22 | if flut_size != 0: 23 | self.tree_renderer = TreeRenderer(flut_size, background_brightness=0., random_init=cfg.MODEL.RANDOM_INI, feature_dim=texture_feature_dim * sh_dim + 1) 24 | 25 | 26 | if self.use_render_net: 27 | if not cfg.MODEL.USE_LIGHT_RENDERER: 28 | self.render_net = UNet(rgb_feat_channels=texture_feature_dim + self.bone_feature_dim, 29 | alpha_feat_channels=1, n_classes1=3, n_classes2=1) 30 | else: 31 | print('Light') 32 | self.render_net = LightUNet(rgb_feat_channels=texture_feature_dim + self.bone_feature_dim, 33 | alpha_feat_channels=1, n_classes1=3, n_classes2=1) 34 | 35 | def forward(self, rays, tree=None, joint_features=None, skinning_weights=None, joint_index=None, coords=None, transformation_matrices=None): 36 | batch_size, h, w = rays.size(0), rays.size(1), rays.size(2) 37 | joint_num, joint_feature_dim = (joint_features.size(1), joint_features.size(2)) if joint_features is not None else (0, 0) 38 | if isinstance(tree, list): 39 | assert len(tree) == batch_size 40 | else: 41 | tree = [tree] 42 | 43 | if isinstance(transformation_matrices, list): 44 | assert len(transformation_matrices) == 
batch_size 45 | else: 46 | transformation_matrices = [transformation_matrices] 47 | 48 | # print(rays.shape) 49 | 50 | rgbas, results = [], [] 51 | rgb_features, alpha_features, motion_features = [], [], [] 52 | features_in = [] 53 | vrts = [] 54 | 55 | joint_features = joint_features.reshape(-1, joint_feature_dim).type_as(rays) if joint_features is not None else None 56 | # print(joint_features.max()) 57 | if self.use_bone_net and joint_features is not None: 58 | joint_features = self.bone_net(joint_features) 59 | 60 | if joint_features is not None: 61 | joint_features = joint_features.reshape(batch_size, joint_num, -1) 62 | 63 | s = time.time() 64 | for i in range(batch_size): 65 | ray = rays[i:i+1, ...].reshape(-1, rays.size(3)) 66 | 67 | torch.cuda.synchronize() 68 | s = time.time() 69 | matrices = None 70 | # print(batch_size, i, transformation_matrices[i].max()) 71 | if transformation_matrices[i] is not None: 72 | matrices = generate_transformation_matrices(matrices=transformation_matrices[i], skinning_weights=skinning_weights, joint_index=joint_index) 73 | features = self.tree_renderer(tree[i], ray, matrices) 74 | features = features.reshape(1, h, w, -1).permute(0, 3, 1, 2) 75 | 76 | motion_feature = None 77 | if joint_features is not None: 78 | with autocast(enabled=False): 79 | motion_feature = self.tree_renderer.motion_feature_render(tree[i], joint_features[i].float(), skinning_weights, 80 | joint_index, 81 | ray) 82 | motion_feature = motion_feature.reshape(1, h, w, -1).permute(0, 3, 1, 2) 83 | torch.cuda.synchronize() 84 | features_in.append(features) 85 | motion_features.append(motion_feature) 86 | 87 | if coords is not None: 88 | vrts.append(self.tree_renderer.voxel_regularization(tree[i], coords[i])) 89 | 90 | features_in = torch.cat(features_in, dim=0) 91 | motion_features = None if motion_features[0] is None else torch.cat(motion_features, dim=0) 92 | 93 | rgbas = torch.cat([features_in[:, :3, ...], features_in[:, -1:, ...]], dim=1) 94 | torch.cuda.synchronize() 95 | s = time.time() 96 | if self.use_render_net: 97 | results = self.render_net(torch.cat([features[:, :-1, ...], motion_feature], dim=1), features_in[:, -1:, ...]) 98 | else: 99 | results = rgbas 100 | 101 | if coords is None: 102 | return rgbas, results, motion_features 103 | else: 104 | return rgbas, results, sum(vrts) / float(len(vrts)), motion_features 105 | 106 | def render_volume_feature(self, rays, tree=None, joint_features=None, skinning_weights=None, joint_index=None): 107 | batch_size, h, w = rays.size(0), rays.size(1), rays.size(2) 108 | rays = rays.reshape(-1, rays.size(3)) 109 | 110 | features = self.tree_renderer(tree, rays) 111 | features = features.reshape(batch_size, h, w, -1).permute(0, 3, 1, 2) 112 | rgb_feature_map = features[:, :-1, ...] 113 | alpha = features[:, -1:, ...] 114 | rgb = features[:, :3, ...] 
115 | rgba = torch.cat([rgb, alpha], dim=1) 116 | 117 | motion_feature = None 118 | if joint_features is not None: 119 | motion_feature = self.tree_renderer.motion_feature_render(tree, joint_features, skinning_weights, joint_index, rays) # joint_features, skinning_weights, joint_index, rays 120 | 121 | return rgba, features, motion_feature 122 | 123 | 124 | 125 | class ConvBlock(nn.Module): 126 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1): 127 | super(ConvBlock, self).__init__() 128 | 129 | self.conv = nn.Sequential( 130 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 131 | kernel_size=kernel_size, padding=padding, stride=stride), 132 | # nn.BatchNorm2d(out_channels), 133 | nn.InstanceNorm2d(out_channels), 134 | nn.LeakyReLU(0.2, inplace=True) 135 | # nn.ReLU(inplace=True) 136 | ) 137 | 138 | def forward(self, x): 139 | return self.conv(x) 140 | 141 | 142 | class NLayerDiscriminator(nn.Module): 143 | """Defines a PatchGAN discriminator""" 144 | 145 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 146 | """Construct a PatchGAN discriminator 147 | Parameters: 148 | input_nc (int) -- the number of channels in input images 149 | ndf (int) -- the number of filters in the last conv layer 150 | n_layers (int) -- the number of conv layers in the discriminator 151 | norm_layer -- normalization layer 152 | """ 153 | super(NLayerDiscriminator, self).__init__() 154 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 155 | use_bias = norm_layer.func == nn.InstanceNorm2d 156 | else: 157 | use_bias = norm_layer == nn.InstanceNorm2d 158 | 159 | kw = 4 160 | padw = 1 161 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 162 | nf_mult = 1 163 | nf_mult_prev = 1 164 | for n in range(1, n_layers): # gradually increase the number of filters 165 | nf_mult_prev = nf_mult 166 | nf_mult = min(2 ** n, 8) 167 | sequence += [ 168 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 169 | norm_layer(ndf * nf_mult), 170 | nn.LeakyReLU(0.2, True) 171 | ] if norm_layer is not 'spectral' else [ 172 | spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias)), 173 | nn.LeakyReLU(0.2, True) 174 | ] 175 | 176 | nf_mult_prev = nf_mult 177 | nf_mult = min(2 ** n_layers, 8) 178 | sequence += [ 179 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 180 | norm_layer(ndf * nf_mult), 181 | nn.LeakyReLU(0.2, True) 182 | ] if norm_layer is not 'spectral' else [ 183 | spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias)), 184 | nn.LeakyReLU(0.2, True) 185 | ] 186 | 187 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 188 | self.model = nn.Sequential(*sequence) 189 | 190 | def forward(self, input): 191 | """Standard forward.""" 192 | return self.model(input) 193 | 194 | 195 | class PixelDiscriminator(nn.Module): 196 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" 197 | 198 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): 199 | """Construct a 1x1 PatchGAN discriminator 200 | Parameters: 201 | input_nc (int) -- the number of channels in input images 202 | ndf (int) -- the number of filters in the last conv layer 203 | norm_layer -- 
normalization layer 204 | """ 205 | super(PixelDiscriminator, self).__init__() 206 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 207 | use_bias = norm_layer.func == nn.InstanceNorm2d 208 | else: 209 | use_bias = norm_layer == nn.InstanceNorm2d 210 | 211 | self.net = [ 212 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 213 | nn.LeakyReLU(0.2, True), 214 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 215 | norm_layer(ndf * 2), 216 | nn.LeakyReLU(0.2, True), 217 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 218 | 219 | self.net = nn.Sequential(*self.net) 220 | 221 | def forward(self, input): 222 | """Standard forward.""" 223 | return self.net(input) -------------------------------------------------------------------------------- /code/modeling/tree_render.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | import svox_t 7 | import collections 8 | Sh_Rays = collections.namedtuple('Sh_Rays', 'origins dirs viewdirs') 9 | 10 | 11 | class TreeRenderer(nn.Module): 12 | 13 | def __init__(self, flut_size, background_brightness=0., step_size=1e-3, random_init=False, feature_dim=0): 14 | super(TreeRenderer, self).__init__() 15 | self.background_brightness = background_brightness 16 | self.step_size = step_size 17 | print(flut_size) 18 | 19 | features = torch.empty((flut_size, feature_dim)) 20 | nn.init.normal_(features) 21 | self.register_parameter("features", nn.Parameter(features)) 22 | 23 | def forward(self, tree, rays, transformation_matrices=None, fast_rendering=False): 24 | t = tree.to(rays.device) 25 | 26 | r = svox_t.VolumeRenderer(t, background_brightness=self.background_brightness, step_size=self.step_size) 27 | dirs = rays[..., :3].contiguous() 28 | origins = rays[..., 3:].contiguous() 29 | 30 | sh_rays = Sh_Rays(origins, dirs, dirs) 31 | 32 | res = r(self.features, sh_rays, transformation_matrices=transformation_matrices, fast=fast_rendering) 33 | 34 | return res 35 | 36 | def motion_render(self, tree, rays): 37 | t = tree.to(rays.device) 38 | dirs = rays[..., :3].contiguous() 39 | origins = rays[..., 3:].contiguous() 40 | sh_rays = Sh_Rays(origins, dirs, dirs) 41 | r = svox_t.VolumeRenderer(t, background_brightness=self.background_brightness, step_size=self.step_size) 42 | motion_feature, depth, hit_point, data_idx = r.motion_render(self.features, sh_rays) 43 | 44 | return motion_feature, depth, hit_point, data_idx 45 | 46 | def motion_feature_render(self, tree, joint_features, skinning_weights, joint_index, rays): 47 | t = tree.to(rays.device) 48 | dirs = rays[..., :3].contiguous() 49 | origins = rays[..., 3:].contiguous() 50 | sh_rays = Sh_Rays(origins, dirs, dirs) 51 | r = svox_t.VolumeRenderer(t, background_brightness=self.background_brightness, step_size=self.step_size) 52 | motion_feature= r.motion_feature_render(self.features, joint_features, skinning_weights, joint_index, sh_rays) 53 | 54 | return motion_feature 55 | 56 | def voxel_regularization(self, tree, coordinate): 57 | tree = tree.to(coordinate.device) 58 | query_features = tree(self.features, coordinate, want_node_ids=False) 59 | vrt = nn.L1Loss()(query_features[..., :-1].detach(), self.features[..., :-1]) 60 | # vrt = nn.MSELoss()(query_features.detach(), self.features) 61 | return vrt 62 | 63 | @staticmethod 64 | def MotionRender(tree, features, 
rays): 65 | t = tree.to(rays.device) 66 | dirs = rays[..., :3].contiguous() 67 | origins = rays[..., 3:].contiguous() 68 | sh_rays = Sh_Rays(origins, dirs, dirs) 69 | r = svox_t.VolumeRenderer(t, background_brightness=0., step_size=1e-3) 70 | motion_feature, depth, hit_point, data_idx = r.motion_render(features, sh_rays) 71 | 72 | return motion_feature, depth, hit_point, data_idx 73 | 74 | @staticmethod 75 | def MotionFeatureRender(tree, features, joint_features, skinning_weights, joint_index, rays): 76 | t = tree.to(rays.device) 77 | dirs = rays[..., :3].contiguous() 78 | origins = rays[..., 3:].contiguous() 79 | sh_rays = Sh_Rays(origins, dirs, dirs) 80 | r = svox_t.VolumeRenderer(t, background_brightness=0., step_size=1e-3) 81 | motion_feature = r.motion_feature_render(features, joint_features, skinning_weights, joint_index, sh_rays) 82 | 83 | return motion_feature -------------------------------------------------------------------------------- /code/models_lpf/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, Adobe Inc. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4 | # 4.0 International Public License. To view a copy of this license, visit 5 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 6 | 7 | from .downsample import * 8 | from .alexnet import * 9 | from .densenet import * 10 | from .mobilenet import * 11 | from .resnet import * 12 | from .vgg import * 13 | -------------------------------------------------------------------------------- /code/models_lpf/alexnet.py: -------------------------------------------------------------------------------- 1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models. 2 | # Copyright (c) 2017 Torch Contributors. 3 | # The Pytorch examples are available under the BSD 3-Clause License. 4 | # 5 | # ========================================================================================== 6 | # 7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved. 8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit 10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 11 | # 12 | # ========================================================================================== 13 | # 14 | # BSD-3 License 15 | # 16 | # Redistribution and use in source and binary forms, with or without 17 | # modification, are permitted provided that the following conditions are met: 18 | # 19 | # * Redistributions of source code must retain the above copyright notice, this 20 | # list of conditions and the following disclaimer. 21 | # 22 | # * Redistributions in binary form must reproduce the above copyright notice, 23 | # this list of conditions and the following disclaimer in the documentation 24 | # and/or other materials provided with the distribution. 25 | # 26 | # * Neither the name of the copyright holder nor the names of its 27 | # contributors may be used to endorse or promote products derived from 28 | # this software without specific prior written permission. 
29 | # 30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | 40 | import torch 41 | import torch.nn as nn 42 | import torch.utils.model_zoo as model_zoo 43 | import numpy as np 44 | from models_lpf import * 45 | from IPython import embed 46 | 47 | __all__ = ['AlexNet', 'alexnet'] 48 | 49 | 50 | # model_urls = { 51 | # 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth', 52 | # } 53 | 54 | 55 | class AlexNet(nn.Module): 56 | 57 | def __init__(self, num_classes=1000, filter_size=1, pool_only=False, relu_first=True): 58 | super(AlexNet, self).__init__() 59 | 60 | if(pool_only): # only apply LPF to pooling layers, so run conv1 at stride 4 as before 61 | first_ds = [nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),] 62 | else: 63 | if(relu_first): # this is the right order 64 | first_ds = [nn.Conv2d(3, 64, kernel_size=11, stride=2, padding=2), 65 | nn.ReLU(inplace=True), 66 | Downsample(filt_size=filter_size, stride=2, channels=64),] 67 | else: # this is the wrong order, since it's equivalent to downsampling the image first 68 | first_ds = [nn.Conv2d(3, 64, kernel_size=11, stride=2, padding=2), 69 | Downsample(filt_size=filter_size, stride=2, channels=64), 70 | nn.ReLU(inplace=True),] 71 | 72 | first_ds += [nn.MaxPool2d(kernel_size=3, stride=1), 73 | Downsample(filt_size=filter_size, stride=2, channels=64), 74 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 75 | nn.ReLU(inplace=True), 76 | nn.MaxPool2d(kernel_size=3, stride=1), 77 | Downsample(filt_size=filter_size, stride=2, channels=192), 78 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 79 | nn.ReLU(inplace=True), 80 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 81 | nn.ReLU(inplace=True), 82 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 83 | nn.ReLU(inplace=True), 84 | nn.MaxPool2d(kernel_size=3, stride=1), 85 | Downsample(filt_size=filter_size, stride=2, channels=256)] 86 | self.features = nn.Sequential(*first_ds) 87 | 88 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 89 | self.classifier = nn.Sequential( 90 | nn.Dropout(), 91 | nn.Linear(256 * 6 * 6, 4096), 92 | nn.ReLU(inplace=True), 93 | nn.Dropout(), 94 | nn.Linear(4096, 4096), 95 | nn.ReLU(inplace=True), 96 | nn.Linear(4096, num_classes), 97 | ) 98 | 99 | def forward(self, x): 100 | x = self.features(x) 101 | x = self.avgpool(x) 102 | x = x.view(x.size(0), 256 * 6 * 6) 103 | x = self.classifier(x) 104 | return x 105 | 106 | 107 | def alexnet(pretrained=False, **kwargs): 108 | """AlexNet model architecture from the 109 | `"One weird trick..." `_ paper. 
110 | 111 | Args: 112 | pretrained (bool): If True, returns a model pre-trained on ImageNet 113 | """ 114 | model = AlexNet(**kwargs) 115 | if pretrained: 116 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) 117 | return model 118 | 119 | 120 | # replacing MaxPool with BlurPool layers 121 | class AlexNetNMP(nn.Module): 122 | 123 | def __init__(self, num_classes=1000, filter_size=1): 124 | super(AlexNetNMP, self).__init__() 125 | self.features = nn.Sequential( 126 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 127 | nn.ReLU(inplace=True), 128 | Downsample(filt_size=filter_size, stride=2, channels=64, pad_off=-1, hidden=True), 129 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 130 | nn.ReLU(inplace=True), 131 | Downsample(filt_size=filter_size, stride=2, channels=192, pad_off=-1, hidden=True), 132 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 133 | nn.ReLU(inplace=True), 134 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 135 | nn.ReLU(inplace=True), 136 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 137 | nn.ReLU(inplace=True), 138 | Downsample(filt_size=filter_size, stride=2, channels=256, pad_off=-1, hidden=True), 139 | ) 140 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 141 | self.classifier = nn.Sequential( 142 | nn.Dropout(), 143 | nn.Linear(256 * 6 * 6, 4096), 144 | nn.ReLU(inplace=True), 145 | nn.Dropout(), 146 | nn.Linear(4096, 4096), 147 | nn.ReLU(inplace=True), 148 | nn.Linear(4096, num_classes), 149 | ) 150 | 151 | def forward(self, x): 152 | # embed() 153 | x = self.features(x) 154 | x = self.avgpool(x) 155 | x = x.view(x.size(0), 256 * 6 * 6) 156 | x = self.classifier(x) 157 | return x 158 | 159 | 160 | def alexnetnmp(pretrained=False, **kwargs): 161 | """AlexNet model architecture from the 162 | `"One weird trick..." `_ paper. 163 | 164 | Args: 165 | pretrained (bool): If True, returns a model pre-trained on ImageNet 166 | """ 167 | model = AlexNetNMP(**kwargs) 168 | if pretrained: 169 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet'])) 170 | return model 171 | 172 | 173 | 174 | 175 | # def __init__(self, num_classes=1000): 176 | # super(AlexNet, self).__init__() 177 | # self.features = nn.Sequential( 178 | # nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 179 | # nn.ReLU(inplace=True), 180 | # nn.MaxPool2d(kernel_size=3, stride=2), 181 | # nn.Conv2d(64, 192, kernel_size=5, padding=2), 182 | # nn.ReLU(inplace=True), 183 | # nn.MaxPool2d(kernel_size=3, stride=2), 184 | # nn.Conv2d(192, 384, kernel_size=3, padding=1), 185 | # nn.ReLU(inplace=True), 186 | # nn.Conv2d(384, 256, kernel_size=3, padding=1), 187 | # nn.ReLU(inplace=True), 188 | # nn.Conv2d(256, 256, kernel_size=3, padding=1), 189 | # nn.ReLU(inplace=True), 190 | # nn.MaxPool2d(kernel_size=3, stride=2), 191 | # ) 192 | # self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) 193 | # self.classifier = nn.Sequential( 194 | # nn.Dropout(), 195 | # nn.Linear(256 * 6 * 6, 4096), 196 | # nn.ReLU(inplace=True), 197 | # nn.Dropout(), 198 | # nn.Linear(4096, 4096), 199 | # nn.ReLU(inplace=True), 200 | # nn.Linear(4096, num_classes), 201 | # ) 202 | 203 | 204 | 205 | -------------------------------------------------------------------------------- /code/models_lpf/densenet.py: -------------------------------------------------------------------------------- 1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models. 2 | # Copyright (c) 2017 Torch Contributors. 
3 | # The Pytorch examples are available under the BSD 3-Clause License. 4 | # 5 | # ========================================================================================== 6 | # 7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved. 8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit 10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 11 | # 12 | # ========================================================================================== 13 | # 14 | # BSD-3 License 15 | # 16 | # Redistribution and use in source and binary forms, with or without 17 | # modification, are permitted provided that the following conditions are met: 18 | # 19 | # * Redistributions of source code must retain the above copyright notice, this 20 | # list of conditions and the following disclaimer. 21 | # 22 | # * Redistributions in binary form must reproduce the above copyright notice, 23 | # this list of conditions and the following disclaimer in the documentation 24 | # and/or other materials provided with the distribution. 25 | # 26 | # * Neither the name of the copyright holder nor the names of its 27 | # contributors may be used to endorse or promote products derived from 28 | # this software without specific prior written permission. 29 | # 30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | 40 | import re 41 | import torch 42 | import torch.nn as nn 43 | import torch.nn.functional as F 44 | import torch.utils.model_zoo as model_zoo 45 | from collections import OrderedDict 46 | from models_lpf import * 47 | 48 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] 49 | 50 | 51 | # model_urls = { 52 | # 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 53 | # 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 54 | # 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 55 | # 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 56 | # } 57 | 58 | 59 | class _DenseLayer(nn.Sequential): 60 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 61 | super(_DenseLayer, self).__init__() 62 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 63 | self.add_module('relu1', nn.ReLU(inplace=True)), 64 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 65 | growth_rate, kernel_size=1, stride=1, bias=False)), 66 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 67 | self.add_module('relu2', nn.ReLU(inplace=True)), 68 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 69 | kernel_size=3, stride=1, padding=1, bias=False)), 70 | 
self.drop_rate = drop_rate 71 | 72 | def forward(self, x): 73 | new_features = super(_DenseLayer, self).forward(x) 74 | if self.drop_rate > 0: 75 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 76 | return torch.cat([x, new_features], 1) 77 | 78 | 79 | class _DenseBlock(nn.Sequential): 80 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 81 | super(_DenseBlock, self).__init__() 82 | for i in range(num_layers): 83 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 84 | self.add_module('denselayer%d' % (i + 1), layer) 85 | 86 | 87 | class _Transition(nn.Sequential): 88 | def __init__(self, num_input_features, num_output_features, filter_size=1): 89 | super(_Transition, self).__init__() 90 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 91 | self.add_module('relu', nn.ReLU(inplace=True)) 92 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 93 | kernel_size=1, stride=1, bias=False)) 94 | # self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) 95 | self.add_module('pool', Downsample(filt_size=filter_size, stride=2, channels=num_output_features)) 96 | 97 | 98 | class DenseNet(nn.Module): 99 | r"""Densenet-BC model class, based on 100 | `"Densely Connected Convolutional Networks" `_ 101 | Args: 102 | growth_rate (int) - how many filters to add each layer (`k` in paper) 103 | block_config (list of 4 ints) - how many layers in each pooling block 104 | num_init_features (int) - the number of filters to learn in the first convolution layer 105 | bn_size (int) - multiplicative factor for number of bottle neck layers 106 | (i.e. bn_size * k features in the bottleneck layer) 107 | drop_rate (float) - dropout rate after each dense layer 108 | num_classes (int) - number of classification classes 109 | """ 110 | 111 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 112 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, 113 | filter_size=1, pool_only=True): 114 | 115 | super(DenseNet, self).__init__() 116 | 117 | # First convolution 118 | if(pool_only): 119 | self.features = nn.Sequential(OrderedDict([ 120 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 121 | ('norm0', nn.BatchNorm2d(num_init_features)), 122 | ('relu0', nn.ReLU(inplace=True)), 123 | ('max0', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)), 124 | ('pool0', Downsample(filt_size=filter_size, stride=2, channels=num_init_features)), 125 | ])) 126 | else: 127 | self.features = nn.Sequential(OrderedDict([ 128 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=1, padding=3, bias=False)), 129 | ('norm0', nn.BatchNorm2d(num_init_features)), 130 | ('relu0', nn.ReLU(inplace=True)), 131 | ('ds0', Downsample(filt_size=filter_size, stride=2, channels=num_init_features)), 132 | ('max0', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)), 133 | ('pool0', Downsample(filt_size=filter_size, stride=2, channels=num_init_features)), 134 | ])) 135 | 136 | # Each denseblock 137 | num_features = num_init_features 138 | for i, num_layers in enumerate(block_config): 139 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 140 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 141 | self.features.add_module('denseblock%d' % (i + 1), block) 142 | num_features = num_features + num_layers * growth_rate 143 | if i != len(block_config) - 1: 144 | trans = 
_Transition(num_input_features=num_features, num_output_features=num_features // 2, filter_size=filter_size) 145 | self.features.add_module('transition%d' % (i + 1), trans) 146 | num_features = num_features // 2 147 | 148 | # Final batch norm 149 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 150 | 151 | # Linear layer 152 | self.classifier = nn.Linear(num_features, num_classes) 153 | 154 | # Official init from torch repo. 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | if(m.in_channels!=m.out_channels or m.out_channels!=m.groups or m.bias is not None): 158 | # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics 159 | nn.init.kaiming_normal_(m.weight) 160 | else: 161 | print('Not initializing') 162 | elif isinstance(m, nn.BatchNorm2d): 163 | nn.init.constant_(m.weight, 1) 164 | nn.init.constant_(m.bias, 0) 165 | elif isinstance(m, nn.Linear): 166 | nn.init.constant_(m.bias, 0) 167 | 168 | def forward(self, x): 169 | features = self.features(x) 170 | out = F.relu(features, inplace=True) 171 | out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1) 172 | out = self.classifier(out) 173 | return out 174 | 175 | 176 | def _load_state_dict(model, model_url): 177 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 178 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 179 | # They are also in the checkpoints in model_urls. This pattern is used 180 | # to find such keys. 181 | pattern = re.compile( 182 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 183 | state_dict = model_zoo.load_url(model_url) 184 | for key in list(state_dict.keys()): 185 | res = pattern.match(key) 186 | if res: 187 | new_key = res.group(1) + res.group(2) 188 | state_dict[new_key] = state_dict[key] 189 | del state_dict[key] 190 | model.load_state_dict(state_dict) 191 | 192 | 193 | def densenet121(pretrained=False, filter_size=1, pool_only=True, **kwargs): 194 | r"""Densenet-121 model from 195 | `"Densely Connected Convolutional Networks" `_ 196 | Args: 197 | pretrained (bool): If True, returns a model pre-trained on ImageNet 198 | """ 199 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 200 | filter_size=filter_size, pool_only=pool_only, **kwargs) 201 | if pretrained: 202 | _load_state_dict(model, model_urls['densenet121']) 203 | return model 204 | 205 | 206 | def densenet169(pretrained=False, filter_size=1, pool_only=True, **kwargs): 207 | r"""Densenet-169 model from 208 | `"Densely Connected Convolutional Networks" `_ 209 | Args: 210 | pretrained (bool): If True, returns a model pre-trained on ImageNet 211 | """ 212 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 213 | filter_size=filter_size, pool_only=pool_only, **kwargs) 214 | if pretrained: 215 | _load_state_dict(model, model_urls['densenet169']) 216 | return model 217 | 218 | 219 | def densenet201(pretrained=False, filter_size=1, pool_only=True, **kwargs): 220 | r"""Densenet-201 model from 221 | `"Densely Connected Convolutional Networks" `_ 222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | """ 225 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 226 | filter_size=filter_size, pool_only=pool_only, **kwargs) 227 | if pretrained: 228 | _load_state_dict(model, model_urls['densenet201']) 229 | return model 230 
| 231 | 232 | def densenet161(pretrained=False, filter_size=1, pool_only=True, **kwargs): 233 | r"""Densenet-161 model from 234 | `"Densely Connected Convolutional Networks" `_ 235 | Args: 236 | pretrained (bool): If True, returns a model pre-trained on ImageNet 237 | """ 238 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 239 | filter_size=filter_size, pool_only=pool_only, **kwargs) 240 | if pretrained: 241 | _load_state_dict(model, model_urls['densenet161']) 242 | return model -------------------------------------------------------------------------------- /code/models_lpf/downsample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, Adobe Inc. All rights reserved. 2 | # 3 | # This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4 | # 4.0 International Public License. To view a copy of this license, visit 5 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 6 | 7 | import torch 8 | import torch.nn.parallel 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from IPython import embed 13 | 14 | class Downsample(nn.Module): 15 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): 16 | super(Downsample, self).__init__() 17 | self.filt_size = filt_size 18 | self.pad_off = pad_off 19 | self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))] 20 | self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes] 21 | self.stride = stride 22 | self.off = int((self.stride-1)/2.) 23 | self.channels = channels 24 | 25 | if(self.filt_size==1): 26 | a = np.array([1.,]) 27 | elif(self.filt_size==2): 28 | a = np.array([1., 1.]) 29 | elif(self.filt_size==3): 30 | a = np.array([1., 2., 1.]) 31 | elif(self.filt_size==4): 32 | a = np.array([1., 3., 3., 1.]) 33 | elif(self.filt_size==5): 34 | a = np.array([1., 4., 6., 4., 1.]) 35 | elif(self.filt_size==6): 36 | a = np.array([1., 5., 10., 10., 5., 1.]) 37 | elif(self.filt_size==7): 38 | a = np.array([1., 6., 15., 20., 15., 6., 1.]) 39 | 40 | filt = torch.Tensor(a[:,None]*a[None,:]) 41 | filt = filt/torch.sum(filt) 42 | self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1))) 43 | 44 | self.pad = get_pad_layer(pad_type)(self.pad_sizes) 45 | 46 | def forward(self, inp): 47 | if(self.filt_size==1): 48 | if(self.pad_off==0): 49 | return inp[:,:,::self.stride,::self.stride] 50 | else: 51 | return self.pad(inp)[:,:,::self.stride,::self.stride] 52 | else: 53 | return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) 54 | 55 | def get_pad_layer(pad_type): 56 | if(pad_type in ['refl','reflect']): 57 | PadLayer = nn.ReflectionPad2d 58 | elif(pad_type in ['repl','replicate']): 59 | PadLayer = nn.ReplicationPad2d 60 | elif(pad_type=='zero'): 61 | PadLayer = nn.ZeroPad2d 62 | else: 63 | print('Pad type [%s] not recognized'%pad_type) 64 | return PadLayer 65 | 66 | class Downsample1D(nn.Module): 67 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): 68 | super(Downsample1D, self).__init__() 69 | self.filt_size = filt_size 70 | self.pad_off = pad_off 71 | self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))] 72 | self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes] 73 | self.stride = stride 74 | self.off = int((self.stride - 1) / 2.) 
75 | self.channels = channels 76 | 77 | # print('Filter size [%i]' % filt_size) 78 | if(self.filt_size == 1): 79 | a = np.array([1., ]) 80 | elif(self.filt_size == 2): 81 | a = np.array([1., 1.]) 82 | elif(self.filt_size == 3): 83 | a = np.array([1., 2., 1.]) 84 | elif(self.filt_size == 4): 85 | a = np.array([1., 3., 3., 1.]) 86 | elif(self.filt_size == 5): 87 | a = np.array([1., 4., 6., 4., 1.]) 88 | elif(self.filt_size == 6): 89 | a = np.array([1., 5., 10., 10., 5., 1.]) 90 | elif(self.filt_size == 7): 91 | a = np.array([1., 6., 15., 20., 15., 6., 1.]) 92 | 93 | filt = torch.Tensor(a) 94 | filt = filt / torch.sum(filt) 95 | self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1))) 96 | 97 | self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes) 98 | 99 | def forward(self, inp): 100 | if(self.filt_size == 1): 101 | if(self.pad_off == 0): 102 | return inp[:, :, ::self.stride] 103 | else: 104 | return self.pad(inp)[:, :, ::self.stride] 105 | else: 106 | return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1]) 107 | 108 | def get_pad_layer_1d(pad_type): 109 | if(pad_type in ['refl', 'reflect']): 110 | PadLayer = nn.ReflectionPad1d 111 | elif(pad_type in ['repl', 'replicate']): 112 | PadLayer = nn.ReplicationPad1d 113 | elif(pad_type == 'zero'): 114 | PadLayer = nn.ZeroPad1d 115 | else: 116 | print('Pad type [%s] not recognized' % pad_type) 117 | return PadLayer 118 | -------------------------------------------------------------------------------- /code/models_lpf/mobilenet.py: -------------------------------------------------------------------------------- 1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models. 2 | # Copyright (c) 2017 Torch Contributors. 3 | # The Pytorch examples are available under the BSD 3-Clause License. 4 | 5 | # ========================================================================================== 6 | 7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved. 8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit 10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 11 | 12 | # ========================================================================================== 13 | 14 | # BSD-3 License 15 | 16 | # Redistribution and use in source and binary forms, with or without 17 | # modification, are permitted provided that the following conditions are met: 18 | 19 | # * Redistributions of source code must retain the above copyright notice, this 20 | # list of conditions and the following disclaimer. 21 | 22 | # * Redistributions in binary form must reproduce the above copyright notice, 23 | # this list of conditions and the following disclaimer in the documentation 24 | # and/or other materials provided with the distribution. 25 | 26 | # * Neither the name of the copyright holder nor the names of its 27 | # contributors may be used to endorse or promote products derived from 28 | # this software without specific prior written permission. 29 | 30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | 40 | from torch import nn 41 | from models_lpf import * 42 | 43 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 44 | 45 | 46 | # model_urls = { 47 | # 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 48 | # } 49 | 50 | 51 | class ConvBNReLU(nn.Sequential): 52 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): 53 | padding = (kernel_size - 1) // 2 54 | super(ConvBNReLU, self).__init__( 55 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 56 | nn.BatchNorm2d(out_planes), 57 | nn.ReLU6(inplace=True) 58 | ) 59 | 60 | 61 | class InvertedResidual(nn.Module): 62 | def __init__(self, inp, oup, stride, expand_ratio, filter_size=1): 63 | super(InvertedResidual, self).__init__() 64 | self.stride = stride 65 | assert stride in [1, 2] 66 | 67 | hidden_dim = int(round(inp * expand_ratio)) 68 | self.use_res_connect = self.stride == 1 and inp == oup 69 | 70 | layers = [] 71 | if expand_ratio != 1: 72 | # pw 73 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) 74 | if(stride==1): 75 | layers.extend([ 76 | # dw 77 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), 78 | # pw-linear 79 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 80 | nn.BatchNorm2d(oup), 81 | ]) 82 | else: 83 | layers.extend([ 84 | # dw 85 | ConvBNReLU(hidden_dim, hidden_dim, stride=1, groups=hidden_dim), 86 | Downsample(filt_size=filter_size, stride=stride, channels=hidden_dim), 87 | # pw-linear 88 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 89 | nn.BatchNorm2d(oup), 90 | ]) 91 | self.conv = nn.Sequential(*layers) 92 | 93 | def forward(self, x): 94 | if self.use_res_connect: 95 | return x + self.conv(x) 96 | else: 97 | return self.conv(x) 98 | 99 | 100 | class MobileNetV2(nn.Module): 101 | def __init__(self, num_classes=1000, width_mult=1.0, filter_size=1): 102 | super(MobileNetV2, self).__init__() 103 | block = InvertedResidual 104 | input_channel = 32 105 | last_channel = 1280 106 | inverted_residual_setting = [ 107 | # t, c, n, s 108 | [1, 16, 1, 1], 109 | [6, 24, 2, 2], 110 | [6, 32, 3, 2], 111 | [6, 64, 4, 2], 112 | [6, 96, 3, 1], 113 | [6, 160, 3, 2], 114 | [6, 320, 1, 1], 115 | ] 116 | 117 | # building first layer 118 | input_channel = int(input_channel * width_mult) 119 | self.last_channel = int(last_channel * max(1.0, width_mult)) 120 | features = [ConvBNReLU(3, input_channel, stride=2)] 121 | # building inverted residual blocks 122 | for t, c, n, s in inverted_residual_setting: 123 | output_channel = int(c * width_mult) 124 | for i in range(n): 125 | stride = s if i == 0 else 1 126 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, filter_size=filter_size)) 127 | input_channel = output_channel 128 | # building last several layers 129 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1)) 130 | # make it nn.Sequential 131 | self.features = nn.Sequential(*features) 132 | 133 | # building classifier 134 | self.classifier = nn.Sequential( 135 | # nn.Dropout(0.2), 
136 | nn.Linear(self.last_channel, num_classes), 137 | ) 138 | 139 | # weight initialization 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 143 | if m.bias is not None: 144 | nn.init.zeros_(m.bias) 145 | elif isinstance(m, nn.BatchNorm2d): 146 | nn.init.ones_(m.weight) 147 | nn.init.zeros_(m.bias) 148 | elif isinstance(m, nn.Linear): 149 | nn.init.normal_(m.weight, 0, 0.01) 150 | nn.init.zeros_(m.bias) 151 | 152 | def forward(self, x): 153 | x = self.features(x) 154 | x = x.mean([2, 3]) 155 | x = self.classifier(x) 156 | return x 157 | 158 | 159 | def mobilenet_v2(pretrained=False, progress=True, filter_size=1, **kwargs): 160 | """ 161 | Constructs a MobileNetV2 architecture from 162 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | progress (bool): If True, displays a progress bar of the download to stderr 166 | """ 167 | model = MobileNetV2(filter_size=filter_size, **kwargs) 168 | # if pretrained: 169 | # state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 170 | # progress=progress) 171 | # model.load_state_dict(state_dict) 172 | return model 173 | -------------------------------------------------------------------------------- /code/models_lpf/resnet.py: -------------------------------------------------------------------------------- 1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models. 2 | # Copyright (c) 2017 Torch Contributors. 3 | # The Pytorch examples are available under the BSD 3-Clause License. 4 | # 5 | # ========================================================================================== 6 | # 7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved. 8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit 10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 11 | # 12 | # ========================================================================================== 13 | # 14 | # BSD-3 License 15 | # 16 | # Redistribution and use in source and binary forms, with or without 17 | # modification, are permitted provided that the following conditions are met: 18 | # 19 | # * Redistributions of source code must retain the above copyright notice, this 20 | # list of conditions and the following disclaimer. 21 | # 22 | # * Redistributions in binary form must reproduce the above copyright notice, 23 | # this list of conditions and the following disclaimer in the documentation 24 | # and/or other materials provided with the distribution. 25 | # 26 | # * Neither the name of the copyright holder nor the names of its 27 | # contributors may be used to endorse or promote products derived from 28 | # this software without specific prior written permission. 29 | # 30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | 40 | import torch.nn as nn 41 | import torch.utils.model_zoo as model_zoo 42 | from models_lpf import * 43 | 44 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 45 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d'] 46 | 47 | 48 | # model_urls = { 49 | # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 50 | # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 51 | # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 52 | # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 53 | # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 54 | # } 55 | 56 | 57 | def conv3x3(in_planes, out_planes, stride=1, groups=1): 58 | """3x3 convolution with padding""" 59 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 60 | padding=1, groups=groups, bias=False) 61 | 62 | def conv1x1(in_planes, out_planes, stride=1): 63 | """1x1 convolution""" 64 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 65 | 66 | class BasicBlock(nn.Module): 67 | expansion = 1 68 | 69 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1): 70 | super(BasicBlock, self).__init__() 71 | if norm_layer is None: 72 | norm_layer = nn.BatchNorm2d 73 | if groups != 1: 74 | raise ValueError('BasicBlock only supports groups=1') 75 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 76 | self.conv1 = conv3x3(inplanes, planes) 77 | self.bn1 = norm_layer(planes) 78 | self.relu = nn.ReLU(inplace=True) 79 | if(stride==1): 80 | self.conv2 = conv3x3(planes,planes) 81 | else: 82 | self.conv2 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes), 83 | conv3x3(planes, planes),) 84 | self.bn2 = norm_layer(planes) 85 | self.downsample = downsample 86 | self.stride = stride 87 | 88 | def forward(self, x): 89 | identity = x 90 | 91 | out = self.conv1(x) 92 | out = self.bn1(out) 93 | out = self.relu(out) 94 | 95 | out = self.conv2(out) 96 | out = self.bn2(out) 97 | 98 | if self.downsample is not None: 99 | identity = self.downsample(x) 100 | 101 | out += identity 102 | out = self.relu(out) 103 | 104 | return out 105 | 106 | 107 | class Bottleneck(nn.Module): 108 | expansion = 4 109 | 110 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1): 111 | super(Bottleneck, self).__init__() 112 | if norm_layer is None: 113 | norm_layer = nn.BatchNorm2d 114 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 115 | self.conv1 = conv1x1(inplanes, planes) 116 | self.bn1 = norm_layer(planes) 117 | self.conv2 = conv3x3(planes, planes, groups) # stride moved 118 | self.bn2 = norm_layer(planes) 119 | if(stride==1): 120 | self.conv3 = conv1x1(planes, planes * self.expansion) 121 | else: 122 | self.conv3 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes), 123 | conv1x1(planes, 
planes * self.expansion)) 124 | self.bn3 = norm_layer(planes * self.expansion) 125 | self.relu = nn.ReLU(inplace=True) 126 | self.downsample = downsample 127 | self.stride = stride 128 | 129 | def forward(self, x): 130 | identity = x 131 | 132 | out = self.conv1(x) 133 | out = self.bn1(out) 134 | out = self.relu(out) 135 | 136 | out = self.conv2(out) 137 | out = self.bn2(out) 138 | out = self.relu(out) 139 | 140 | out = self.conv3(out) 141 | out = self.bn3(out) 142 | 143 | if self.downsample is not None: 144 | identity = self.downsample(x) 145 | 146 | out += identity 147 | out = self.relu(out) 148 | 149 | return out 150 | 151 | 152 | class ResNet(nn.Module): 153 | 154 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 155 | groups=1, width_per_group=64, norm_layer=None, filter_size=1, pool_only=True): 156 | super(ResNet, self).__init__() 157 | if norm_layer is None: 158 | norm_layer = nn.BatchNorm2d 159 | planes = [int(width_per_group * groups * 2 ** i) for i in range(4)] 160 | self.inplanes = planes[0] 161 | 162 | if(pool_only): 163 | self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=2, padding=3, bias=False) 164 | else: 165 | self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=1, padding=3, bias=False) 166 | self.bn1 = norm_layer(planes[0]) 167 | self.relu = nn.ReLU(inplace=True) 168 | 169 | if(pool_only): 170 | self.maxpool = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=1), 171 | Downsample(filt_size=filter_size, stride=2, channels=planes[0])]) 172 | else: 173 | self.maxpool = nn.Sequential(*[Downsample(filt_size=filter_size, stride=2, channels=planes[0]), 174 | nn.MaxPool2d(kernel_size=2, stride=1), 175 | Downsample(filt_size=filter_size, stride=2, channels=planes[0])]) 176 | 177 | self.layer1 = self._make_layer(block, planes[0], layers[0], groups=groups, norm_layer=norm_layer) 178 | self.layer2 = self._make_layer(block, planes[1], layers[1], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) 179 | self.layer3 = self._make_layer(block, planes[2], layers[2], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) 180 | self.layer4 = self._make_layer(block, planes[3], layers[3], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size) 181 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 182 | self.fc = nn.Linear(planes[3] * block.expansion, num_classes) 183 | 184 | for m in self.modules(): 185 | if isinstance(m, nn.Conv2d): 186 | if(m.in_channels!=m.out_channels or m.out_channels!=m.groups or m.bias is not None): 187 | # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics 188 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 189 | else: 190 | print('Not initializing') 191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 192 | nn.init.constant_(m.weight, 1) 193 | nn.init.constant_(m.bias, 0) 194 | 195 | # Zero-initialize the last BN in each residual branch, 196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 198 | if zero_init_residual: 199 | for m in self.modules(): 200 | if isinstance(m, Bottleneck): 201 | nn.init.constant_(m.bn3.weight, 0) 202 | elif isinstance(m, BasicBlock): 203 | nn.init.constant_(m.bn2.weight, 0) 204 | 205 | def _make_layer(self, block, planes, blocks, stride=1, groups=1, norm_layer=None, filter_size=1): 206 | if norm_layer is None: 207 | norm_layer = nn.BatchNorm2d 208 | downsample = None 209 | if stride != 1 or self.inplanes != planes * block.expansion: 210 | # downsample = nn.Sequential( 211 | # conv1x1(self.inplanes, planes * block.expansion, stride, filter_size=filter_size), 212 | # norm_layer(planes * block.expansion), 213 | # ) 214 | 215 | downsample = [Downsample(filt_size=filter_size, stride=stride, channels=self.inplanes),] if(stride !=1) else [] 216 | downsample += [conv1x1(self.inplanes, planes * block.expansion, 1), 217 | norm_layer(planes * block.expansion)] 218 | # print(downsample) 219 | downsample = nn.Sequential(*downsample) 220 | 221 | layers = [] 222 | layers.append(block(self.inplanes, planes, stride, downsample, groups, norm_layer, filter_size=filter_size)) 223 | self.inplanes = planes * block.expansion 224 | for _ in range(1, blocks): 225 | layers.append(block(self.inplanes, planes, groups=groups, norm_layer=norm_layer, filter_size=filter_size)) 226 | 227 | return nn.Sequential(*layers) 228 | 229 | def forward(self, x): 230 | x = self.conv1(x) 231 | x = self.bn1(x) 232 | x = self.relu(x) 233 | x = self.maxpool(x) 234 | 235 | x = self.layer1(x) 236 | x = self.layer2(x) 237 | x = self.layer3(x) 238 | x = self.layer4(x) 239 | 240 | x = self.avgpool(x) 241 | x = x.view(x.size(0), -1) 242 | x = self.fc(x) 243 | 244 | return x 245 | 246 | 247 | def resnet18(pretrained=False, filter_size=1, pool_only=True, **kwargs): 248 | """Constructs a ResNet-18 model. 249 | Args: 250 | pretrained (bool): If True, returns a model pre-trained on ImageNet 251 | """ 252 | model = ResNet(BasicBlock, [2, 2, 2, 2], filter_size=filter_size, pool_only=pool_only, **kwargs) 253 | if pretrained: 254 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 255 | return model 256 | 257 | 258 | def resnet34(pretrained=False, filter_size=1, pool_only=True, **kwargs): 259 | """Constructs a ResNet-34 model. 260 | Args: 261 | pretrained (bool): If True, returns a model pre-trained on ImageNet 262 | """ 263 | model = ResNet(BasicBlock, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) 264 | if pretrained: 265 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 266 | return model 267 | 268 | 269 | def resnet50(pretrained=False, filter_size=1, pool_only=True, **kwargs): 270 | """Constructs a ResNet-50 model. 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | """ 274 | model = ResNet(Bottleneck, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) 275 | if pretrained: 276 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 277 | return model 278 | 279 | 280 | def resnet101(pretrained=False, filter_size=1, pool_only=True, **kwargs): 281 | """Constructs a ResNet-101 model. 
282 | Args: 283 | pretrained (bool): If True, returns a model pre-trained on ImageNet 284 | """ 285 | model = ResNet(Bottleneck, [3, 4, 23, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) 286 | if pretrained: 287 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 288 | return model 289 | 290 | 291 | def resnet152(pretrained=False, filter_size=1, pool_only=True, **kwargs): 292 | """Constructs a ResNet-152 model. 293 | Args: 294 | pretrained (bool): If True, returns a model pre-trained on ImageNet 295 | """ 296 | model = ResNet(Bottleneck, [3, 8, 36, 3], filter_size=filter_size, pool_only=pool_only, **kwargs) 297 | if pretrained: 298 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 299 | return model 300 | 301 | 302 | def resnext50_32x4d(pretrained=False, filter_size=1, pool_only=True, **kwargs): 303 | model = ResNet(Bottleneck, [3, 4, 6, 3], groups=4, width_per_group=32, filter_size=filter_size, pool_only=pool_only, **kwargs) 304 | # if pretrained: 305 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 306 | return model 307 | 308 | 309 | def resnext101_32x8d(pretrained=False, filter_size=1, pool_only=True, **kwargs): 310 | model = ResNet(Bottleneck, [3, 4, 23, 3], groups=8, width_per_group=32, filter_size=filter_size, pool_only=pool_only, **kwargs) 311 | # if pretrained: 312 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 313 | return model -------------------------------------------------------------------------------- /code/models_lpf/vgg.py: -------------------------------------------------------------------------------- 1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models. 2 | # Copyright (c) 2017 Torch Contributors. 3 | # The Pytorch examples are available under the BSD 3-Clause License. 4 | # 5 | # ========================================================================================== 6 | # 7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved. 8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit 10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode. 11 | # 12 | # ========================================================================================== 13 | # 14 | # BSD-3 License 15 | # 16 | # Redistribution and use in source and binary forms, with or without 17 | # modification, are permitted provided that the following conditions are met: 18 | # 19 | # * Redistributions of source code must retain the above copyright notice, this 20 | # list of conditions and the following disclaimer. 21 | # 22 | # * Redistributions in binary form must reproduce the above copyright notice, 23 | # this list of conditions and the following disclaimer in the documentation 24 | # and/or other materials provided with the distribution. 25 | # 26 | # * Neither the name of the copyright holder nor the names of its 27 | # contributors may be used to endorse or promote products derived from 28 | # this software without specific prior written permission. 29 | # 30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 33 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 39 | 40 | import torch.nn as nn 41 | import torch.utils.model_zoo as model_zoo 42 | from models_lpf import * 43 | 44 | __all__ = [ 45 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 46 | 'vgg19_bn', 'vgg19', 47 | ] 48 | 49 | 50 | # model_urls = { 51 | # 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 52 | # 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 53 | # 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 54 | # 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 55 | # 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 56 | # 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 57 | # 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 58 | # 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 59 | # } 60 | 61 | 62 | class VGG(nn.Module): 63 | 64 | def __init__(self, features, num_classes=1000, init_weights=True): 65 | super(VGG, self).__init__() 66 | self.features = features 67 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 68 | self.classifier = nn.Sequential( 69 | nn.Linear(512 * 7 * 7, 4096), 70 | nn.ReLU(True), 71 | nn.Dropout(), 72 | nn.Linear(4096, 4096), 73 | nn.ReLU(True), 74 | nn.Dropout(), 75 | nn.Linear(4096, num_classes), 76 | ) 77 | if init_weights: 78 | self._initialize_weights() 79 | 80 | def forward(self, x): 81 | x = self.features(x) 82 | # print(x.shape) 83 | x = self.avgpool(x) 84 | x = x.view(x.size(0), -1) 85 | x = self.classifier(x) 86 | return x 87 | 88 | def _initialize_weights(self): 89 | for m in self.modules(): 90 | if isinstance(m, nn.Conv2d): 91 | if(m.in_channels!=m.out_channels or m.out_channels!=m.groups or m.bias is not None): 92 | # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics 93 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 94 | if m.bias is not None: 95 | nn.init.constant_(m.bias, 0) 96 | else: 97 | print('Not initializing') 98 | elif isinstance(m, nn.BatchNorm2d): 99 | nn.init.constant_(m.weight, 1) 100 | nn.init.constant_(m.bias, 0) 101 | elif isinstance(m, nn.Linear): 102 | nn.init.normal_(m.weight, 0, 0.01) 103 | nn.init.constant_(m.bias, 0) 104 | 105 | 106 | def make_layers(cfg, batch_norm=False, filter_size=1): 107 | layers = [] 108 | in_channels = 3 109 | for v in cfg: 110 | if v == 'M': 111 | # layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 112 | layers += [nn.MaxPool2d(kernel_size=2, stride=1), Downsample(filt_size=filter_size, stride=2, channels=in_channels)] 113 | else: 114 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 115 | if batch_norm: 116 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 117 | else: 118 | layers += [conv2d, nn.ReLU(inplace=True)] 119 | in_channels = v 120 | return nn.Sequential(*layers) 121 | 122 | 123 | cfg = { 124 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 125 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 
512, 512, 'M', 512, 512, 'M'], 126 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 127 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 128 | } 129 | 130 | def vgg11(pretrained=False, filter_size=1, **kwargs): 131 | """VGG 11-layer model (configuration "A") 132 | 133 | Args: 134 | pretrained (bool): If True, returns a model pre-trained on ImageNet 135 | """ 136 | if pretrained: 137 | kwargs['init_weights'] = False 138 | model = VGG(make_layers(cfg['A'], filter_size=filter_size), **kwargs) 139 | if pretrained: 140 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11'])) 141 | return model 142 | 143 | 144 | def vgg11_bn(pretrained=False, filter_size=1, **kwargs): 145 | """VGG 11-layer model (configuration "A") with batch normalization 146 | 147 | Args: 148 | pretrained (bool): If True, returns a model pre-trained on ImageNet 149 | """ 150 | if pretrained: 151 | kwargs['init_weights'] = False 152 | model = VGG(make_layers(cfg['A'], filter_size=filter_size, batch_norm=True), **kwargs) 153 | if pretrained: 154 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn'])) 155 | return model 156 | 157 | 158 | def vgg13(pretrained=False, filter_size=1, **kwargs): 159 | """VGG 13-layer model (configuration "B") 160 | 161 | Args: 162 | pretrained (bool): If True, returns a model pre-trained on ImageNet 163 | """ 164 | if pretrained: 165 | kwargs['init_weights'] = False 166 | model = VGG(make_layers(cfg['B'], filter_size=filter_size), **kwargs) 167 | if pretrained: 168 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13'])) 169 | return model 170 | 171 | 172 | def vgg13_bn(pretrained=False, filter_size=1, **kwargs): 173 | """VGG 13-layer model (configuration "B") with batch normalization 174 | 175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | """ 178 | if pretrained: 179 | kwargs['init_weights'] = False 180 | model = VGG(make_layers(cfg['B'], filter_size=filter_size, batch_norm=True), **kwargs) 181 | if pretrained: 182 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn'])) 183 | return model 184 | 185 | 186 | def vgg16(pretrained=False, filter_size=1, **kwargs): 187 | """VGG 16-layer model (configuration "D") 188 | 189 | Args: 190 | pretrained (bool): If True, returns a model pre-trained on ImageNet 191 | """ 192 | if pretrained: 193 | kwargs['init_weights'] = False 194 | model = VGG(make_layers(cfg['D'], filter_size=filter_size), **kwargs) 195 | if pretrained: 196 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16'])) 197 | return model 198 | 199 | 200 | def vgg16_bn(pretrained=False, filter_size=1, **kwargs): 201 | """VGG 16-layer model (configuration "D") with batch normalization 202 | 203 | Args: 204 | pretrained (bool): If True, returns a model pre-trained on ImageNet 205 | """ 206 | if pretrained: 207 | kwargs['init_weights'] = False 208 | model = VGG(make_layers(cfg['D'], filter_size=filter_size, batch_norm=True), **kwargs) 209 | if pretrained: 210 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn'])) 211 | return model 212 | 213 | 214 | def vgg19(pretrained=False, filter_size=1, **kwargs): 215 | """VGG 19-layer model (configuration "E") 216 | 217 | Args: 218 | pretrained (bool): If True, returns a model pre-trained on ImageNet 219 | """ 220 | if pretrained: 221 | kwargs['init_weights'] = False 222 | model = VGG(make_layers(cfg['E'], filter_size=filter_size), **kwargs) 223 | if 
pretrained: 224 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) 225 | return model 226 | 227 | 228 | def vgg19_bn(pretrained=False, filter_size=1, **kwargs): 229 | """VGG 19-layer model (configuration 'E') with batch normalization 230 | 231 | Args: 232 | pretrained (bool): If True, returns a model pre-trained on ImageNet 233 | """ 234 | if pretrained: 235 | kwargs['init_weights'] = False 236 | model = VGG(make_layers(cfg['E'], filter_size=filter_size, batch_norm=True), **kwargs) 237 | if pretrained: 238 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn'])) 239 | return model 240 | 241 | -------------------------------------------------------------------------------- /code/render.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import argparse 5 | from config import cfg 6 | from utils import * 7 | from modeling import build_model 8 | import cv2 9 | from tqdm import tqdm 10 | 11 | from imageio_ffmpeg import write_frames 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument( 16 | "--config", default=None, type=str, help="config file path." 17 | ) 18 | parser.add_argument( 19 | "--dataset", default=None, type=str, help="dataset path." 20 | ) 21 | parser.add_argument( 22 | "--model", default=None, type=str, help="checkpoint model path." 23 | ) 24 | parser.add_argument( 25 | "--output_path", default='./out', type=str, help="image / videos output path." 26 | ) 27 | parser.add_argument( 28 | "--render_video", action='store_true', default=False, help="render around view videos." 29 | ) 30 | parser.add_argument( 31 | "--camera_path", default=None, type=str, help="path to cameras sequence for rendering around-view videos." 32 | ) 33 | 34 | 35 | 36 | args = parser.parse_args() 37 | 38 | os.makedirs(args.output_path, exist_ok=True) 39 | 40 | config_path = args.config 41 | dataset_path = args.dataset 42 | model_path = args.model 43 | 44 | # load config 45 | cfg.merge_from_file(config_path) 46 | tree_depth = cfg.MODEL.TREE_DEPTH 47 | img_size = cfg.INPUT.SIZE_TEST 48 | 49 | # load canonical volume 50 | coords = torch.load(os.path.join(dataset_path, 'volumes/coords_init.pth'), map_location='cpu').float().cuda() 51 | skeleton_init = torch.from_numpy( 52 | campose_to_extrinsic(np.loadtxt(os.path.join(dataset_path, 'bones/Bones_%04d.inf' % 0)))).float() 53 | bone_parents = torch.from_numpy(np.load(os.path.join(dataset_path, 'bones/bone_parents.npy'))).long() 54 | pose_init = get_local_eular(skeleton_init, bone_parents) 55 | 56 | skinning_weights = torch.load(os.path.join(dataset_path, 'volumes/volume_weights.pth'), 57 | map_location='cpu').float().cuda() 58 | joint_index = torch.load(os.path.join(dataset_path, 'volumes/volume_indices.pth'), 59 | map_location='cpu').int().cuda() 60 | volume_radius = torch.from_numpy(np.loadtxt(os.path.join(dataset_path, 'volumes/radius.txt'))).float() 61 | 62 | # load cameras 63 | camposes = np.loadtxt(os.path.join(dataset_path, 'CamPose.inf')) 64 | Ts = torch.Tensor(campose_to_extrinsic(camposes)) 65 | center = torch.mean(coords, dim=0).cpu() 66 | Ks = torch.Tensor(read_intrinsics(os.path.join(dataset_path, 'Intrinsic.inf'))) 67 | K = Ks[0] 68 | K[:2, :3] /= 2 69 | 70 | # load motions 71 | sequences = [] 72 | seq_nums = [] 73 | with open(os.path.join(dataset_path, 'sequences'), 'r') as f: 74 | lines = f.readlines() 75 | for line in lines: 76 | seq, num = line.strip().split(' ') 77 | sequences.append(seq) 78 | seq_nums.append(int(num)) 79 | 
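The loop above expects the plain-text `sequences` file to list one motion per line as `<motion_name> <frame_count>`, separated by a single space. A minimal, self-contained sketch of the same parsing; the motion names and frame counts below are hypothetical:

```python
# Stand-alone sketch of the `sequences` parsing above; the entries are hypothetical.
lines = ["run 120\n", "walk 200\n"]        # what readlines() on a `sequences` file could return
sequences, seq_nums = [], []
for line in lines:
    seq, num = line.strip().split(' ')     # each line: "<motion_name> <frame_count>"
    sequences.append(seq)
    seq_nums.append(int(num))
assert sequences == ['run', 'walk'] and seq_nums == [120, 200]
```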
80 | pose_features = [] 81 | bone_matrices = [] 82 | skeletons = [] 83 | seq_trees = [] 84 | 85 | pose_feature_pad = torch.cat([skeleton_init[:, :3, 3], pose_init], dim=1) 86 | 87 | for seq_id, seq in enumerate(sequences): 88 | seq_num = seq_nums[seq_id] 89 | skeleton, pose_feature_s, bone_matrice = [], [], [] 90 | seq_trees.append([None] * seq_num) 91 | for i in range(seq_num): 92 | pose = torch.from_numpy( 93 | campose_to_extrinsic(np.loadtxt( 94 | os.path.join(dataset_path, 'bones/%s/Bones_%04d.inf' % (seq, i))))).float() 95 | matrices = torch.matmul(skeleton_init, torch.inverse(pose)) 96 | skeleton.append(pose.unsqueeze(0)) 97 | bone_matrice.append(matrices.unsqueeze(0)) 98 | 99 | pose = get_local_eular(pose, bone_parents) 100 | delta_pose = pose - pose_init 101 | pose_feature = torch.cat([pose_feature_pad, delta_pose], dim=1) 102 | if torch.any(torch.isnan(pose_feature)): 103 | pose_feature = torch.where(torch.isnan(pose_feature), torch.zeros_like(pose_feature).float(), 104 | pose_feature) 105 | 106 | pose_feature_s.append(pose_feature.unsqueeze(0)) 107 | 108 | skeleton = torch.cat(skeleton, dim=0).float() 109 | pose_feature_s = torch.cat(pose_feature_s, dim=0).float() 110 | bone_matrice = torch.cat(bone_matrice, dim=0).float() 111 | 112 | skeletons.append(skeleton.cuda()) 113 | pose_features.append(pose_feature_s.cuda()) 114 | bone_matrices.append(bone_matrice.cuda()) 115 | 116 | # load models 117 | model = build_model(cfg, flut_size=coords.shape[0], bone_feature_dim=pose_features[0].shape[2]).cuda() 118 | checkpoint = torch.load(model_path, map_location='cpu') 119 | model.load_state_dict(checkpoint['model']) 120 | model.eval() 121 | 122 | # render 123 | if not args.render_video: 124 | seq_id = 0 125 | frame_id = 0 126 | cam_id = 0 127 | cur_T = Ts[cam_id] 128 | if seq_trees[seq_id][frame_id] is None: 129 | transformation_matrices = torch.inverse(bone_matrices[seq_id][frame_id]) 130 | skeleton = skeletons[seq_id][frame_id][:, :3, 3] 131 | t, _, _ = warp_build_octree(coords, transformation_matrices=transformation_matrices, 132 | skeleton=skeleton, vb_weights=skinning_weights, 133 | vb_indices=joint_index, radius=volume_radius, 134 | max_depth=tree_depth) 135 | seq_trees[seq_id][frame_id] = t 136 | 137 | rgb, mask, feature = render_image(cfg, model, K, cur_T, (img_size[1], img_size[0]), 138 | tree=seq_trees[seq_id][frame_id], 139 | matrices=bone_matrices[seq_id][frame_id], 140 | joint_features=pose_features[seq_id][frame_id], 141 | bg=None, skinning_weights=skinning_weights, 142 | joint_index=joint_index) 143 | 144 | cv2.imwrite('rgb.jpg', cv2.cvtColor(rgb * 255, cv2.COLOR_BGR2RGB)) 145 | cv2.imwrite('feature.jpg', feature * 255) 146 | else: 147 | assert args.camera_path is not None, 'Please provide cameras trajectory.' 
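For reference, the per-bone motion feature assembled in the sequence loop above concatenates the rest-pose joint position (3 values), the rest-pose local Euler angles (3), and the per-frame Euler offset from the rest pose (3), giving nine channels per bone; this is also the `bone_feature_dim` passed to `build_model`. A minimal shape sketch, using an identity rest pose and a hypothetical bone count:

```python
# Shape sketch only: identity rest pose, zero motion, hypothetical bone count.
import torch

num_bones = 4
skeleton_init = torch.eye(4).repeat(num_bones, 1, 1)   # rest-pose bone transforms (4x4 each)
pose_init = torch.zeros(num_bones, 3)                  # rest-pose local Euler angles
delta_pose = torch.zeros(num_bones, 3)                 # per-frame offset from the rest pose

pose_feature_pad = torch.cat([skeleton_init[:, :3, 3], pose_init], dim=1)  # (num_bones, 6)
pose_feature = torch.cat([pose_feature_pad, delta_pose], dim=1)            # (num_bones, 9)
assert pose_feature.shape == (num_bones, 9)
```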
148 |     camposes = np.loadtxt(os.path.join(args.camera_path, 'CamPose_spiral.inf'))
149 |     Ts = torch.Tensor(campose_to_extrinsic(camposes))
150 |     center = torch.mean(coords, dim=0).cpu()
151 |     Ks = torch.Tensor(read_intrinsics(os.path.join(args.camera_path, 'Intrinsic_spiral.inf')))
152 | 
153 |     for seq_id in range(len(sequences)):
154 | 
155 |         writer_raw_rgb = write_frames(os.path.join(args.output_path, '%s_rgb.mp4' % sequences[seq_id]), img_size, fps=30, macro_block_size=8, quality=6)  # size is (width, height)
156 |         writer_raw_alpha = write_frames(os.path.join(args.output_path, '%s_alpha.mp4' % sequences[seq_id]), img_size, fps=30, macro_block_size=8, quality=6)  # size is (width, height)
157 |         writer_raw_feature = write_frames(os.path.join(args.output_path, '%s_feature.mp4' % sequences[seq_id]), img_size, fps=30, macro_block_size=8, quality=6)  # size is (width, height)
158 |         writer_raw_rgb.send(None)
159 |         writer_raw_alpha.send(None)
160 |         writer_raw_feature.send(None)
161 | 
162 |         for cam_id in tqdm(range(Ts.shape[0]), unit=" frame", desc=f"Rendering video"):
163 |             frame_id = cam_id % seq_nums[seq_id]
164 | 
165 |             if seq_trees[seq_id][frame_id] is None:
166 |                 transformation_matrices = torch.inverse(bone_matrices[seq_id][frame_id])
167 |                 skeleton = skeletons[seq_id][frame_id][:, :3, 3]
168 |                 t, wtime, btime = warp_build_octree(coords, transformation_matrices=transformation_matrices,
169 |                                                     skeleton=skeleton, vb_weights=skinning_weights,
170 |                                                     vb_indices=joint_index, radius=volume_radius,
171 |                                                     max_depth=tree_depth)
172 |                 # cache the tree for faster rendering if gpu memory is enough
173 |                 # seq_trees[seq_id][frame_id] = t
174 |             else:
175 |                 t = seq_trees[seq_id][frame_id]
176 | 
177 |             rgb, mask, feature = render_image(cfg, model, Ks[cam_id], Ts[cam_id], (img_size[1], img_size[0]),
178 |                                               tree=t,  # tree built (or cached) above; the cache entry may still be None
179 |                                               matrices=bone_matrices[seq_id][frame_id],
180 |                                               joint_features=pose_features[seq_id][frame_id],
181 |                                               bg=None, skinning_weights=skinning_weights,
182 |                                               joint_index=joint_index)
183 | 
184 |             img = rgb * 255
185 |             feature = feature * 255
186 |             alpha = torch.from_numpy(mask).unsqueeze(-1).repeat(1, 1, 3).numpy() * 255
187 |             img = img.copy(order='C')
188 |             writer_raw_rgb.send(img.astype(np.uint8))
189 |             writer_raw_alpha.send(alpha.astype(np.uint8))
190 |             writer_raw_feature.send(feature.astype(np.uint8))
191 | 
192 |             cv2.imwrite(os.path.join(args.output_path, '%s_rgb_%04d.jpg' % (sequences[seq_id], cam_id)),
193 |                         cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
194 |             cv2.imwrite(os.path.join(args.output_path, '%s_alpha_%04d.jpg' % (sequences[seq_id], cam_id)), alpha)
195 |             cv2.imwrite(os.path.join(args.output_path, '%s_feature_%04d.jpg' % (sequences[seq_id], cam_id)), feature)
196 | 
197 |         writer_raw_rgb.close()
198 |         writer_raw_alpha.close()
199 |         writer_raw_feature.close()
200 | 
-------------------------------------------------------------------------------- /code/utils/__init__.py: --------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 | 
7 | from .ray_sampling import ray_sampling, patch_sampling
8 | from .spherical_harmonics import computeRGB
9 | from .llinear_transform import compute_skinning_weights, compute_transformation, transform_coords
10 | from .build_octree import build_octree, load_octree, generate_transformation_matrices
11 | from .bone_parsing import get_local_transformation, get_local_eular, rotation_matrix_to_eular, eular_to_rotation_matrix
12 | from
.utils import * 13 | from .rendering import * -------------------------------------------------------------------------------- /code/utils/bone_parsing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from pytorch3d.transforms import so3_log_map, so3_rotation_angle, so3_relative_angle, matrix_to_euler_angles, euler_angles_to_matrix 5 | 6 | 7 | def get_local_transformation(transformations, parents): 8 | parents_t = torch.cat([torch.eye(4).unsqueeze(0), transformations[parents[1:]]]) 9 | local_t = torch.matmul(torch.inverse(parents_t), transformations) 10 | 11 | return local_t 12 | 13 | 14 | def get_local_eular(transformations, parents, axis='XYZ'): 15 | local_t = get_local_transformation(transformations, parents) 16 | local_t = torch.clamp(local_t[:, :3, :3], max=1.) 17 | local_eular = matrix_to_euler_angles(local_t, axis) 18 | 19 | return local_eular 20 | 21 | 22 | def rotation_matrix_to_eular(rotations, axis='XYZ'): 23 | return matrix_to_euler_angles(rotations[:, :3, :3], axis) 24 | 25 | 26 | def eular_to_rotation_matrix(eular, axis='XYZ'): 27 | return euler_angles_to_matrix(eular, axis) 28 | 29 | -------------------------------------------------------------------------------- /code/utils/build_octree.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | 5 | import svox_t 6 | 7 | def build_octree(coordinates, radius=None, center=None, skeleton=None, data_dim=91, max_depth=10, data_format='SH9'): 8 | coordinates = coordinates.contiguous().cuda().detach() 9 | 10 | maxs_ = torch.max(coordinates, dim=0).values 11 | mins_ = torch.min(coordinates, dim=0).values 12 | radius_ = (maxs_ - mins_) / 2 13 | center_ = mins_ + radius_ 14 | 15 | if radius is None: 16 | radius = radius_ 17 | 18 | if center is None: 19 | center = center_ 20 | # print(center, radius) 21 | t = svox_t.N3Tree(N=2, 22 | data_dim=data_dim, 23 | depth_limit=max_depth, 24 | init_reserve=5000, 25 | init_refine=0, 26 | # geom_resize_fact=1.5, 27 | radius=list(radius), 28 | center=list(center), 29 | data_format=data_format, 30 | extra_data=skeleton, ).cuda() 31 | 32 | for i in range(max_depth): 33 | t[coordinates].refine() 34 | 35 | t.shrink_to_fit() 36 | t.construct_tree(coordinates) 37 | 38 | return t.cpu() 39 | 40 | def load_octree(tree_path): 41 | return svox_t.N3Tree.load(tree_path) 42 | 43 | def generate_transformation_matrices(matrices, skinning_weights, joint_index): 44 | return svox_t.blend_transformation_matrix(matrices, skinning_weights, joint_index) -------------------------------------------------------------------------------- /code/utils/llinear_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | def compute_skinning_weights(pos, tmp_vertices, tmp_weights, n_influence=3, n_binding=10, chunks=10000, debug=False): 5 | if pos.shape[0] > chunks: 6 | coords = pos.split(chunks, dim=0) 7 | else: 8 | coords = [pos] 9 | 10 | weights = [] 11 | 12 | for coord in coords: 13 | dis = coord[:, None, :] - tmp_vertices 14 | dis = dis.pow(2).sum(-1).sqrt() 15 | dis_min, indice = torch.topk(dis, n_influence, largest=False, dim=1) 16 | w = dis_min.max(-1)[0][..., None] - dis_min 17 | w = torch.softmax(w, dim=-1) 18 | weight = torch.sum(tmp_weights[indice] * w[..., None], dim=1) 19 | weights.append(weight) 20 | weights = torch.cat(weights, dim=0) 21 | 22 | weights, indices = 
torch.topk(weights, n_binding, largest=True, dim=1) 23 | weights = weights / weights.sum(-1).unsqueeze(-1) 24 | 25 | return weights, indices 26 | 27 | 28 | 29 | def compute_transformation(src_pose, dst_pose, weights, indices, chunks=10000): 30 | if weights is None or indices is None: 31 | print('Please provide skinning weights.') 32 | return None 33 | 34 | transformation_matrices = [] 35 | if weights.shape[0] > chunks: 36 | weights, indices = weights.split(chunks, dim=0), indices.split(chunks, dim=0) 37 | else: 38 | weights, indices = [weights], [indices] 39 | 40 | for i in range(len(weights)): 41 | weight, indice = weights[i], indices[i] 42 | src_matrices, dst_matricrs = src_pose[indice], dst_pose[indice] 43 | transformation_matrix = torch.matmul(dst_matricrs, src_matrices) 44 | transformation_matrix = torch.sum(transformation_matrix * weight[..., None, None], dim=1) 45 | 46 | transformation_matrices.append(transformation_matrix) 47 | 48 | transformation_matrices = torch.cat(transformation_matrices, dim=0) 49 | 50 | return transformation_matrices.float() 51 | 52 | 53 | def transform_coords(coords, matrices, chunks=10000): 54 | if coords.shape[0] > chunks: 55 | coords, matrices = coords.split(chunks, dim=0), matrices.split(chunks, dim=0) 56 | else: 57 | coords, matrices = [coords], [matrices] 58 | 59 | transformed_coords = [] 60 | 61 | for i in range(len(coords)): 62 | coord, ms = coords[i], matrices[i] 63 | coord = coord.unsqueeze(-1) 64 | coord = torch.cat([coord, torch.ones((coord.size(0), 1, coord.size(2)), device=coord.device)], dim=1) 65 | coord = torch.matmul(ms, coord) 66 | 67 | transformed_coords.append(coord) 68 | 69 | transformed_coords = torch.cat(transformed_coords, dim=0).squeeze()[:, :3] 70 | 71 | return transformed_coords 72 | 73 | 74 | -------------------------------------------------------------------------------- /code/utils/logger.py: -------------------------------------------------------------------------------- 1 | # encoding: utf-8 2 | """ 3 | @author: sherlock 4 | @contact: sherlockliao01@gmail.com 5 | """ 6 | 7 | import logging 8 | import os 9 | import sys 10 | 11 | 12 | def setup_logger(name, save_dir, distributed_rank): 13 | logger = logging.getLogger(name) 14 | logger.setLevel(logging.DEBUG) 15 | # don't log results for the non-master process 16 | if distributed_rank > 0: 17 | return logger 18 | ch = logging.StreamHandler(stream=sys.stdout) 19 | ch.setLevel(logging.DEBUG) 20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 21 | ch.setFormatter(formatter) 22 | logger.addHandler(ch) 23 | 24 | if save_dir: 25 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w') 26 | fh.setLevel(logging.DEBUG) 27 | fh.setFormatter(formatter) 28 | logger.addHandler(fh) 29 | 30 | return logger 31 | -------------------------------------------------------------------------------- /code/utils/ray_sampling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | 4 | ''' 5 | Sample rays from views (and images) with/without masks 6 | 7 | -------------------------- 8 | INPUT Tensors 9 | Ks: intrinsics of cameras (M,3,3) 10 | Ts: extrinsic of cameras (M,4,4) 11 | image_size: the size of image [H,W] 12 | images: (M,C,H,W) 13 | mask_threshold: a float threshold to mask rays 14 | masks:(M,H,W) 15 | depth:(M,H,W) 16 | ------------------- 17 | OUPUT: 18 | list of rays: (N,6) dirs(3) + pos(3) 19 | RGB: (N,C) 20 | ''' 21 | 22 | 23 | def ray_sampling(Ks, Ts, image_size, masks=None, 
mask_threshold=0.5, images=None, depth=None, 24 | far_depth=None, fine_depth=None): 25 | h = image_size[0] 26 | w = image_size[1] 27 | M = Ks.size(0) 28 | 29 | x = torch.linspace(0, h - 1, steps=h, device=Ks.device) 30 | y = torch.linspace(0, w - 1, steps=w, device=Ks.device) 31 | 32 | grid_x, grid_y = torch.meshgrid(x, y) 33 | coordinates = torch.stack([grid_y, grid_x]).unsqueeze(0).repeat(M, 1, 1, 1) # (M,2,H,W) 34 | coordinates = torch.cat([coordinates, torch.ones(coordinates.size(0), 1, coordinates.size(2), 35 | coordinates.size(3), device=Ks.device)], dim=1).permute(0, 2, 3, 36 | 1).unsqueeze( 37 | -1) 38 | 39 | inv_Ks = torch.inverse(Ks) 40 | 41 | dirs = torch.matmul(inv_Ks, coordinates) # (M,H,W,3,1) 42 | dirs = dirs / torch.norm(dirs, dim=3, keepdim=True) 43 | 44 | # zs to depth 45 | td = dirs.view(M, h, w, -1) 46 | if depth is not None: 47 | depth = depth.to(Ks.device) 48 | depth = depth / td[:, :, :, 2] 49 | if far_depth is not None: 50 | far_depth = far_depth.to(Ks.device) 51 | far_depth = far_depth / td[:, :, :, 2] 52 | else: 53 | far_depth = depth 54 | if fine_depth is not None: 55 | fine_depth = fine_depth.to(Ks.device) 56 | fine_depth = fine_depth / td[:, :, :, 2] 57 | else: 58 | fine_depth = depth 59 | 60 | dirs = torch.cat([dirs, torch.zeros(dirs.size(0), coordinates.size(1), 61 | coordinates.size(2), 1, 1, device=Ks.device)], dim=3) # (M,H,W,4,1) 62 | 63 | dirs = torch.matmul(Ts, dirs) # (M,H,W,4,1) 64 | dirs = dirs[:, :, :, 0:3, 0] # (M,H,W,3) 65 | 66 | pos = Ts[:, 0:3, 3] # (M,3) 67 | pos = pos.unsqueeze(1).unsqueeze(1).repeat(1, h, w, 1) 68 | 69 | rays = torch.cat([dirs, pos], dim=3) # (M,H,W,6) 70 | if depth is not None: 71 | depth = torch.stack([depth, far_depth, fine_depth], dim=3) # (M,H,W,2) 72 | ds = None 73 | 74 | if images is not None: 75 | images = images.to(Ks.device) 76 | rgbs = images.permute(0, 2, 3, 1) # (M,H,W,C) 77 | else: 78 | rgbs = None 79 | 80 | if masks is not None: 81 | rays = rays[masks > mask_threshold, :] 82 | if rgbs is not None: 83 | rgbs = rgbs[masks > mask_threshold, :] 84 | if depth is not None: 85 | ds = depth[masks > mask_threshold, :] 86 | else: 87 | rays = rays.reshape((-1, rays.size(3))) 88 | if rgbs is not None: 89 | rgbs = rgbs.reshape((-1, rgbs.size(3))) 90 | if depth is not None: 91 | ds = depth.reshape((-1, depth.size(3))) 92 | 93 | return rays, rgbs, ds 94 | 95 | 96 | def patch_sampling(Ks, Ts, image_size, patch_size=32, images=None, depth=None, far_depth=None, fine_depth=None, 97 | alpha=None, keep_bg=False): 98 | h = image_size[0] 99 | w = image_size[1] 100 | M = Ks.size(0) 101 | 102 | h_num, w_num = math.ceil(h / patch_size), math.ceil(w / patch_size) 103 | ray_patches, rgb_patches, depth_patches, alpha_patches = [], [], [], [] 104 | 105 | x = torch.linspace(0, h - 1, steps=h, device=Ks.device) 106 | y = torch.linspace(0, w - 1, steps=w, device=Ks.device) 107 | 108 | grid_x, grid_y = torch.meshgrid(x, y) 109 | coordinates = torch.stack([grid_y, grid_x]).unsqueeze(0).repeat(M, 1, 1, 1) # (M,2,H,W) 110 | coordinates = torch.cat([coordinates, torch.ones(coordinates.size(0), 1, coordinates.size(2), 111 | coordinates.size(3), device=Ks.device)], dim=1).permute(0, 2, 3, 112 | 1).unsqueeze( 113 | -1) 114 | 115 | inv_Ks = torch.inverse(Ks) 116 | 117 | dirs = torch.matmul(inv_Ks, coordinates) # (M,H,W,3,1) 118 | dirs = dirs / torch.norm(dirs, dim=3, keepdim=True) 119 | 120 | # zs to depth 121 | td = dirs.view(M, h, w, -1) 122 | if depth is not None: 123 | depth = depth / td[:, :, :, 2] 124 | if far_depth is not None: 125 | far_depth = 
far_depth / td[:, :, :, 2] 126 | else: 127 | far_depth = depth 128 | if fine_depth is not None: 129 | fine_depth = fine_depth / td[:, :, :, 2] 130 | else: 131 | fine_depth = depth 132 | 133 | dirs = torch.cat([dirs, torch.zeros(dirs.size(0), coordinates.size(1), 134 | coordinates.size(2), 1, 1, device=Ks.device)], dim=3) # (M,H,W,4,1) 135 | 136 | dirs = torch.matmul(Ts, dirs) # (M,H,W,4,1) 137 | dirs = dirs[:, :, :, 0:3, 0] # (M,H,W,3) 138 | 139 | pos = Ts[:, 0:3, 3] # (M,3) 140 | pos = pos.unsqueeze(1).unsqueeze(1).repeat(1, h, w, 1) 141 | 142 | rays = torch.cat([dirs, pos], dim=3) # (M,H,W,6) 143 | 144 | if images is not None: 145 | rgbs = images.permute(0, 2, 3, 1) # (M,H,W,C) 146 | else: 147 | rgbs = None 148 | 149 | depth = torch.stack([depth, far_depth, fine_depth], dim=3) # (M,H,W,2) 150 | 151 | padded_h, padded_w = h_num * patch_size, w_num * patch_size 152 | ray_padded, rgb_padded, depth_padded, alpha_padded = torch.zeros([rays.size(0), padded_h, padded_w, rays.size(3)], 153 | device=rays.device), \ 154 | torch.zeros([rgbs.size(0), padded_h, padded_w, rgbs.size(3)], 155 | device=rgbs.device), \ 156 | torch.zeros([depth.size(0), padded_h, padded_w, depth.size(3)], 157 | device=depth.device), \ 158 | torch.zeros([alpha.size(0), padded_h, padded_w], 159 | device=alpha.device) 160 | ray_padded[:, :h, :w, :] = rays 161 | rgb_padded[:, :h, :w, :] = rgbs 162 | depth_padded[:, :h, :w, :] = depth 163 | alpha_padded[:, :h, :w] = alpha 164 | 165 | rays = ray_padded.view(int(padded_h / patch_size), patch_size, int(padded_w / patch_size), patch_size, -1).permute(0, 2, 1, 3, 4).reshape(-1, patch_size, patch_size, ray_padded.shape[-1]) 166 | rgbs = rgb_padded.view(int(padded_h / patch_size), patch_size, int(padded_w / patch_size), patch_size, -1).permute(0, 2, 1, 3, 4).reshape(-1, patch_size, patch_size, rgb_padded.shape[-1]) 167 | ds = depth_padded.view(int(padded_h / patch_size), patch_size, int(padded_w / patch_size), patch_size, -1).permute(0, 2, 1, 3, 4).reshape(-1, patch_size, patch_size, depth_padded.shape[-1]) 168 | als = alpha_padded.view(int(padded_h / patch_size), patch_size, int(padded_w / patch_size), patch_size).permute(0, 2, 1, 3).reshape(-1, patch_size, patch_size) 169 | 170 | mask = ds.reshape(ds.shape[0], -1).sum(1) > 0 171 | 172 | return rays[mask], rgbs[mask], ds[mask], als[mask] 173 | -------------------------------------------------------------------------------- /code/utils/rendering.py: -------------------------------------------------------------------------------- 1 | from utils import ray_sampling 2 | import collections 3 | import time 4 | import torch 5 | import svox_t 6 | from utils import generate_transformation_matrices 7 | import numpy as np 8 | import math 9 | 10 | Sh_Rays = collections.namedtuple('Sh_Rays', 'origins dirs viewdirs') 11 | 12 | 13 | def render_image(cfg, model, K, T, img_size=(450, 800), tree=None, matrices=None, joint_features=None, 14 | bg=None, skinning_weights=None, joint_index=None): 15 | torch.cuda.synchronize() 16 | s = time.time() 17 | h, w = img_size[0], img_size[1] 18 | rays, _, _ = ray_sampling(K.unsqueeze(0).cuda(), T.unsqueeze(0).cuda(), img_size) 19 | 20 | with torch.no_grad(): 21 | joint_features = None if not cfg.MODEL.USE_MOTION else joint_features 22 | matrices = generate_transformation_matrices(matrices=matrices, skinning_weights=skinning_weights, 23 | joint_index=joint_index) 24 | 25 | with torch.cuda.amp.autocast(enabled=False): 26 | features = model.tree_renderer(tree, rays, matrices).reshape(1, h, w, -1).permute(0, 3, 1, 2) 
27 | 28 | if cfg.MODEL.USE_MOTION: 29 | motion_feature = model.tree_renderer.motion_feature_render(tree, joint_features, skinning_weights, 30 | joint_index, 31 | rays) 32 | 33 | 34 | motion_feature = motion_feature.reshape(1, h, w, -1).permute(0, 3, 1, 2) 35 | else: 36 | motion_feature = features[:, :9, ...] 37 | 38 | with torch.cuda.amp.autocast(enabled=True): 39 | features_in = features[:, :-1, ...] 40 | if cfg.MODEL.USE_MOTION: 41 | features_in = torch.cat([features[:, :-1, ...], motion_feature], dim=1) 42 | rgba_out = model.render_net(features_in, features[:, -1:, ...]) 43 | 44 | rgba_volume = torch.cat([features[:, :3, ...], features[:, -1:, ...]], dim=1) 45 | 46 | rgb = rgba_out[0, :-1, ...] 47 | alpha = rgba_out[0, -1:, ...] 48 | img_volume = rgba_volume[0, :3, ...].permute(1, 2, 0) 49 | 50 | if model.use_render_net: 51 | rgb = torch.nn.Hardtanh()(rgb) 52 | rgb = (rgb + 1) / 2 53 | 54 | alpha = torch.nn.Hardtanh()(alpha) 55 | alpha = (alpha + 1) / 2 56 | alpha = torch.clamp(alpha, min=0, max=1.) 57 | 58 | if bg is not None: 59 | if bg.max() > 1: 60 | bg = bg / 255 61 | comp_img = rgb * alpha + (1 - alpha) * bg 62 | else: 63 | comp_img = rgb * alpha + (1 - alpha) 64 | 65 | img_unet = comp_img.permute(1, 2, 0).float().cpu().numpy() 66 | 67 | return img_unet, alpha.squeeze().float().detach().cpu().numpy(), img_volume.float().detach().cpu().numpy() 68 | 69 | 70 | def build_octree(coordinates, radius=None, center=None, skeleton=None, data_dim=91, 71 | max_depth=10, data_format='SH9'): 72 | coordinates = coordinates.contiguous().cuda().detach() 73 | 74 | maxs_ = torch.max(coordinates, dim=0).values 75 | mins_ = torch.min(coordinates, dim=0).values 76 | radius_ = (maxs_ - mins_) / 2 77 | center_ = mins_ + radius_ 78 | 79 | if radius is None: 80 | radius = radius_ 81 | 82 | if center is None: 83 | center = center_ 84 | 85 | t = svox_t.N3Tree(N=2, 86 | data_dim=data_dim, 87 | depth_limit=max_depth, 88 | init_reserve=30000, 89 | init_refine=0, 90 | # geom_resize_fact=1.5, 91 | radius=list(radius), 92 | center=list(center), 93 | data_format=data_format, 94 | extra_data=skeleton.contiguous(), 95 | map_location=coordinates.device) 96 | 97 | for i in range(max_depth): 98 | t[coordinates].refine() 99 | 100 | t.shrink_to_fit() 101 | t.construct_tree(coordinates) 102 | 103 | return t 104 | 105 | 106 | def warp_build_octree(coords, transformation_matrices, skeleton=None, vb_weights=None, vb_indices=None, radius=None, 107 | max_depth=8): 108 | torch.cuda.synchronize() 109 | s = time.time() 110 | vs, ms = svox_t.warp_vertices(transformation_matrices, coords, vb_weights, vb_indices) 111 | torch.cuda.synchronize() 112 | wtime = time.time() - s 113 | 114 | torch.cuda.synchronize() 115 | s = time.time() 116 | t = build_octree(vs, radius=radius, max_depth=max_depth, skeleton=skeleton) 117 | torch.cuda.synchronize() 118 | btime = time.time() - s 119 | 120 | return t, wtime, btime 121 | 122 | def rodrigues_rotation_matrix(axis, theta): 123 | axis = np.asarray(axis) 124 | theta = np.asarray(theta) 125 | axis = axis/math.sqrt(np.dot(axis, axis)) 126 | a = math.cos(theta/2.0) 127 | b, c, d = -axis*math.sin(theta/2.0) 128 | aa, bb, cc, dd = a*a, b*b, c*c, d*d 129 | bc, ad, ac, ab, bd, cd = b*c, a*d, a*c, a*b, b*d, c*d 130 | return np.array([[aa+bb-cc-dd, 2*(bc+ad), 2*(bd-ac)], 131 | [2*(bc-ad), aa+cc-bb-dd, 2*(cd+ab)], 132 | [2*(bd+ac), 2*(cd-ab), aa+dd-bb-cc]]) 133 | -------------------------------------------------------------------------------- /code/utils/spherical_harmonics.py: 
/code/utils/spherical_harmonics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | def ComputeSH(dirs):
5 |     '''
6 |     dirs: b*3
7 |     '''
8 |     x = dirs[..., 0]
9 |     y = dirs[..., 1]
10 |     z = dirs[..., 2]
11 |     xx = dirs[..., 0] ** 2
12 |     yy = dirs[..., 1] ** 2
13 |     zz = dirs[..., 2] ** 2
14 | 
15 |     xy = dirs[..., 0] * dirs[..., 1]
16 |     yz = dirs[..., 1] * dirs[..., 2]
17 |     xz = dirs[..., 0] * dirs[..., 2]
18 | 
19 |     sh = torch.zeros((dirs.shape[0], 25)).to(dirs.device)
20 | 
21 |     sh[:, 0] = 0.282095
22 | 
23 |     sh[:, 1] = -0.4886025119029199 * y
24 |     sh[:, 2] = 0.4886025119029199 * z
25 |     sh[:, 3] = -0.4886025119029199 * x
26 | 
27 |     sh[:, 4] = 1.0925484305920792 * xy
28 |     sh[:, 5] = -1.0925484305920792 * yz
29 |     sh[:, 6] = 0.31539156525252005 * (2.0 * zz - xx - yy)
30 |     sh[:, 7] = -1.0925484305920792 * xz
31 |     # sh2p2
32 |     sh[:, 8] = 0.5462742152960396 * (xx - yy)
33 | 
34 |     sh[:, 9] = -0.5900435899266435 * y * (3 * xx - yy)
35 |     sh[:, 10] = 2.890611442640554 * xy * z
36 |     sh[:, 11] = -0.4570457994644658 * y * (4 * zz - xx - yy)
37 |     sh[:, 12] = 0.3731763325901154 * z * (2 * zz - 3 * xx - 3 * yy)
38 |     sh[:, 13] = -0.4570457994644658 * x * (4 * zz - xx - yy)
39 |     sh[:, 14] = 1.445305721320277 * z * (xx - yy)
40 |     sh[:, 15] = -0.5900435899266435 * x * (xx - 3 * yy)
41 | 
42 |     sh[:, 16] = 2.5033429417967046 * xy * (xx - yy)
43 |     sh[:, 17] = -1.7701307697799304 * yz * (3 * xx - yy)
44 |     sh[:, 18] = 0.9461746957575601 * xy * (7 * zz - 1.0)
45 |     sh[:, 19] = -0.6690465435572892 * yz * (7 * zz - 3.0)
46 |     sh[:, 20] = 0.10578554691520431 * (zz * (35 * zz - 30) + 3)
47 |     sh[:, 21] = -0.6690465435572892 * xz * (7 * zz - 3)
48 |     sh[:, 22] = 0.47308734787878004 * (xx - yy) * (7 * zz - 1.0)
49 |     sh[:, 23] = -1.7701307697799304 * xz * (xx - 3 * yy)
50 |     sh[:, 24] = 0.6258357354491761 * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
51 | 
52 |     return sh
53 | 
54 | 
55 | def computeRGB(dirs, coeff):
56 |     '''
57 |     dirs: n * 3
58 |     coeff: n * n_components * f
59 |     '''
60 |     n_components = coeff.shape[1]
61 | 
62 |     return (ComputeSH(dirs)[..., :n_components].unsqueeze(-1) * coeff).sum(1)
--------------------------------------------------------------------------------
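`ComputeSH` evaluates the first 25 real spherical harmonic basis functions (degrees 0 to 4) for a batch of directions, and `computeRGB` contracts the leading `n_components` basis values against per-point SH coefficients to produce a view-dependent color. A minimal shape-checking sketch follows; the batch size and coefficient layout are assumptions chosen for illustration.

```
import torch
# ComputeSH / computeRGB as defined in code/utils/spherical_harmonics.py above.

n = 4096                                                          # number of sample points (assumed)
dirs = torch.nn.functional.normalize(torch.randn(n, 3), dim=-1)   # unit viewing directions
coeff = torch.randn(n, 9, 3)                                      # 9 SH coefficients (up to degree 2) per RGB channel

basis = ComputeSH(dirs)        # (n, 25) basis values per direction
rgb = computeRGB(dirs, coeff)  # (n, 3): sum over k of basis[:, k] * coeff[:, k, :]
```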
/code/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from pytorch3d.transforms import so3_log_map, so3_rotation_angle, so3_relative_angle, matrix_to_euler_angles, euler_angles_to_matrix
5 | 
6 | 
7 | def get_local_transformation(transformations, parents):
8 |     parents_t = torch.cat([torch.eye(4).unsqueeze(0), transformations[parents[1:]]])
9 |     local_t = torch.matmul(torch.inverse(parents_t), transformations)
10 | 
11 |     return local_t
12 | 
13 | 
14 | def get_local_eular(transformations, parents, axis='XYZ'):
15 |     local_t = get_local_transformation(transformations, parents)
16 |     local_t = torch.clamp(local_t[:, :3, :3], max=1.)
17 |     local_eular = matrix_to_euler_angles(local_t, axis)
18 | 
19 |     return local_eular
20 | 
21 | 
22 | def rotation_matrix_to_eular(rotations, axis='XYZ'):
23 |     return matrix_to_euler_angles(rotations[:, :3, :3], axis)
24 | 
25 | 
26 | def eular_to_rotation_matrix(eular, axis='XYZ'):
27 |     return euler_angles_to_matrix(eular, axis)
28 | 
29 | 
30 | def campose_to_extrinsic(camposes):
31 |     if camposes.shape[1] != 12:
32 |         raise Exception("wrong campose data structure!")
33 | 
34 | 
35 |     res = np.zeros((camposes.shape[0], 4, 4))
36 | 
37 |     res[:, 0:3, 2] = camposes[:, 0:3]
38 |     res[:, 0:3, 0] = camposes[:, 3:6]
39 |     res[:, 0:3, 1] = camposes[:, 6:9]
40 |     res[:, 0:3, 3] = camposes[:, 9:12]
41 |     res[:, 3, 3] = 1.0
42 | 
43 |     return res
44 | 
45 | 
46 | def read_intrinsics(fn_instrinsic):
47 |     fo = open(fn_instrinsic)
48 |     data = fo.readlines()
49 |     i = 0
50 |     Ks = []
51 |     while i < len(data):
52 |         if len(data[i]) > 6:
53 |             tmp = data[i].split()
54 |             tmp = [float(i) for i in tmp]
55 |             a = np.array(tmp)
56 |             i = i + 1
57 |             tmp = data[i].split()
58 |             tmp = [float(i) for i in tmp]
59 |             b = np.array(tmp)
60 |             i = i + 1
61 |             tmp = data[i].split()
62 |             tmp = [float(i) for i in tmp]
63 |             c = np.array(tmp)
64 |             res = np.vstack([a, b, c])
65 |             Ks.append(res)
66 | 
67 |         i = i + 1
68 |     Ks = np.stack(Ks)
69 |     fo.close()
70 | 
71 |     return Ks
72 | 
--------------------------------------------------------------------------------
/medias/featured1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HaiminLuo/Artemis/9a0aee4af849c384b49a1a2af354f47549d05c85/medias/featured1.png
--------------------------------------------------------------------------------
/medias/overview_v3-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HaiminLuo/Artemis/9a0aee4af849c384b49a1a2af354f47549d05c85/medias/overview_v3-1.png
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | yacs
2 | numpy
3 | torch
4 | torchvision
5 | pillow
6 | argparse
7 | opencv-python
8 | imageio
9 | imageio-ffmpeg
10 | IPython
11 | tqdm
--------------------------------------------------------------------------------
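The camera helpers in `code/utils/utils.py` convert the dataset's camera files into arrays the renderer can use: `read_intrinsics` parses blocks of three rows into per-camera 3x3 intrinsic matrices, and `campose_to_extrinsic` reorders each 12-value pose row into a 4x4 [R T] extrinsic matrix. The sketch below is a loading example under two assumptions: the data is unpacked under a local `cat/` directory, and `CamPose.inf` is plain whitespace-separated text that `np.loadtxt` can read.

```
import numpy as np
import torch
# campose_to_extrinsic / read_intrinsics as defined in code/utils/utils.py above.

dataset_root = 'cat'                                    # placeholder path to one unpacked DFA animal
Ks = read_intrinsics(f'{dataset_root}/Intrinsic.inf')   # (num_cams, 3, 3) intrinsic matrices
camposes = np.loadtxt(f'{dataset_root}/CamPose.inf')    # (num_cams, 12) pose rows (format assumed)
Ts = campose_to_extrinsic(camposes)                     # (num_cams, 4, 4) camera-to-world extrinsics

K = torch.from_numpy(Ks[0]).float()
T = torch.from_numpy(Ts[0]).float()
# K and T can then be passed to render_image(cfg, model, K, T, ...) in code/utils/rendering.py.
```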