├── README.md
├── code
│   ├── config
│   │   ├── __init__.py
│   │   └── defaults.py
│   ├── configs
│   │   └── configs.yml
│   ├── modeling
│   │   ├── UNet
│   │   │   ├── __init__.py
│   │   │   ├── unet_model.py
│   │   │   └── unet_parts.py
│   │   ├── __init__.py
│   │   ├── model.py
│   │   └── tree_render.py
│   ├── models_lpf
│   │   ├── __init__.py
│   │   ├── alexnet.py
│   │   ├── densenet.py
│   │   ├── downsample.py
│   │   ├── mobilenet.py
│   │   ├── resnet.py
│   │   └── vgg.py
│   ├── render.py
│   └── utils
│       ├── __init__.py
│       ├── bone_parsing.py
│       ├── build_octree.py
│       ├── llinear_transform.py
│       ├── logger.py
│       ├── ray_sampling.py
│       ├── rendering.py
│       ├── spherical_harmonics.py
│       └── utils.py
├── medias
│   ├── featured1.png
│   └── overview_v3-1.png
└── requirement.txt
/README.md:
--------------------------------------------------------------------------------
1 | # Artemis: Articulated Neural Pets with Appearance and Motion Synthesis
2 |
3 | SIGGRAPH 2022
4 |
5 | Haimin Luo · Teng Xu · Yuheng Jiang · Chenglin Zhou · Qiwei Qiu · Yingliang Zhang · Wei Yang · Lan Xu · Jingyi Yu
6 |
47 | This repository contains a PyTorch implementation and the **D**ynamic **F**urry **A**nimal (DFA) dataset for the paper: [Artemis: Articulated Neural Pets with Appearance and Motion Synthesis](https://arxiv.org/abs/2202.05628). In this paper, we present **ARTEMIS**, a novel neural modeling and rendering pipeline for generating **ART**iculated neural pets with app**E**arance and **M**otion synthes**IS**.
48 |
49 | ## Overview
50 | ![Overview of Artemis](medias/overview_v3-1.png)
51 |
52 | ## Installation
53 | Create a virtual environment and install the requirements as follows:
54 | ```
55 | conda create -n artemis python=3.7
56 | conda activate artemis
57 | pip install -r requirement.txt
58 | ```
59 | - Install pytorch3d following the official [installation steps](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md).
60 | - Install the CUDA-based PyTorch extension for animatable volumes: [svox_t](https://github.com/HaiminLuo/svox_t).
61 |
62 |
63 | The code has been tested on Ubuntu 18.04 with PyTorch 1.12.1.
64 |
65 | ## Dynamic Furry Animals (DFA) Dataset
66 | The DFA dataset can be found at [DFA datasets](https://shanghaitecheducn-my.sharepoint.com/:f:/g/personal/luohm_shanghaitech_edu_cn/Et60lJpJdp5DoyQF7uzP6jgB_JEW4LIHixAyXEiVhHT3Vw?e=d09jtz). It contains multi-view renderings and skeletal motions of nine high-quality CGI furry animals. As described in Artemis, each animal's dataset contains $1920\times1080$ RGBA images rendered from 36 cameras placed around the dynamic animal.
67 |
68 | The datasets (e.g., cat) are organized as follows:
69 |
70 | ```
71 | cat
├── img
│   └── run                        - Motion name.
│       └── %d                     - Frame number, starting from 0.
│           ├── img_%04d.jpg       - RGB image for each view; view numbers start from 0.
│           ├── img_%04d_alpha.png - Alpha matte for the corresponding RGB image.
│           └── ...
│
├── volumes
│   ├── coords_init.pth            - Voxel coordinates representing the animal in the rest pose.
│   ├── volume_indices.pth         - Indices of the bones to which the voxels are bound.
│   ├── volume_weights.pth         - Skinning weights of the voxels.
│   └── radius.txt                 - Radius of the volume.
│
├── bones
│   ├── run                        - Motion name.
│   │   ├── Bones_%04d.inf         - Skeletal pose for each frame, starting from 0. Each row stores the 3x4 [R T] matrix column by column, in the order: column 3, column 1, column 2, column 4.
│   │   └── ...
│   ├── bones_parents.npy          - Parent bone index of each bone.
│   └── Bones_0000.inf             - The rest pose.
│
├── CamPose.inf                    - Camera extrinsics. Each row stores the 3x4 [R T] matrix column by column, in the order: column 3, column 1, column 2, column 4, where R*X^{camera}+T=X^{world}.
│
├── Intrinsic.inf                  - Camera intrinsics. Each intrinsic matrix is stored as: "idx \n fx 0 cx \n 0 fy cy \n 0 0 1 \n \n" (idx starts from 0).
│
└── sequences                      - Motion sequences. Each motion is stored as: "motion_name frames\n".
97 | ```
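For reference, below is a minimal loading sketch for the two camera files. It is an assumption based only on the format notes above (one camera per row in `CamPose.inf`, and whitespace-separated `idx` + 3x3 blocks in `Intrinsic.inf`); adjust it if your copy of the data differs.

```
import numpy as np

def load_campose(path):
    # Each row: 12 floats, the 3x4 [R T] matrix stored column by column
    # in the order 3, 1, 2, 4 (see the notes above).
    stored = np.loadtxt(path).reshape(-1, 4, 3)   # (num_cams, stored_column, 3)
    mats = stored.transpose(0, 2, 1)              # (num_cams, 3, 4), columns in stored order
    return mats[:, :, [1, 2, 0, 3]]               # reorder columns back to [R | T]

def load_intrinsics(path):
    # Whitespace-separated blocks: "idx fx 0 cx 0 fy cy 0 0 1" per camera.
    vals = np.array(open(path).read().split(), dtype=np.float64)
    return vals.reshape(-1, 10)[:, 1:].reshape(-1, 3, 3)
```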
98 |
99 | ## Pre-trained model
100 | Our pre-trained NGI models (e.g., `cat`, `wolf`, `panda`, ...) described in our paper can be found at [pre-trained models](https://shanghaitecheducn-my.sharepoint.com/:f:/g/personal/luohm_shanghaitech_edu_cn/EgJC3IrKA1pPjYTFiO2rVswBW4QVHkV54TqGxcxMIr-Bvw?e=H3xhTG).
101 |
102 |
103 | ## Rendering
104 | We provide a rendering script that illustrates how to load motions, cameras, and models, and how to animate and render the NGI animals.
105 |
106 | To render an RGBA image with a specified camera view and motion pose:
107 | ```
108 | cd code
109 |
110 | python render.py --config ../model/cat/configs.yml --model ../model/cat/model.pt --dataset ../dataset/cat --output_path ./output
111 | ```
112 |
113 | To render an around-view video with dynamic motion and a spiral camera trajectory:
114 | ```
115 | cd code
116 |
117 | python render.py --config ../model/cat/configs.yml --model ../model/cat/model.pt --dataset ../dataset/cat --output_path ./output --render_video --camera_path ../model/cat/
118 | ```
119 |
120 |
121 | ## Citation
122 | If you find our code or paper helpful, please consider citing:
123 | ```
124 | @article{10.1145/3528223.3530086,
125 | author = {Luo, Haimin and Xu, Teng and Jiang, Yuheng and Zhou, Chenglin and Qiu, Qiwei and Zhang, Yingliang and Yang, Wei and Xu, Lan and Yu, Jingyi},
126 | title = {Artemis: Articulated Neural Pets with Appearance and Motion Synthesis},
127 | year = {2022},
128 | issue_date = {July 2022},
129 | publisher = {Association for Computing Machinery},
130 | address = {New York, NY, USA},
131 | volume = {41},
132 | number = {4},
133 | issn = {0730-0301},
134 | url = {https://doi.org/10.1145/3528223.3530086},
135 | doi = {10.1145/3528223.3530086},
136 | journal = {ACM Trans. Graph.},
137 | month = {jul},
138 | articleno = {164},
139 | numpages = {19},
140 | keywords = {novel view synthesis, neural rendering, dynamic scene modeling, neural volumetric animal, motion synthesis, neural representation}
141 | }
142 | ```
143 | Please also consider citing another related and interesting work on high-quality, photo-realistic rendering of real fuzzy objects: [Convolutional Neural Opacity Radiance Fields](https://www.computer.org/csdl/proceedings-article/iccp/2021/09466273/1uSSXDOinlu):
144 | ```
145 | @INPROCEEDINGS {9466273,
146 | author = {H. Luo and A. Chen and Q. Zhang and B. Pang and M. Wu and L. Xu and J. Yu},
147 | booktitle = {2021 IEEE International Conference on Computational Photography (ICCP)},
148 | title = {Convolutional Neural Opacity Radiance Fields},
149 | year = {2021},
150 | volume = {},
151 | issn = {},
152 | pages = {1-12},
153 | keywords = {training;photography;telepresence;image color analysis;computational modeling;entertainment industry;image capture},
154 | doi = {10.1109/ICCP51581.2021.9466273},
155 | url = {https://doi.ieeecomputersociety.org/10.1109/ICCP51581.2021.9466273},
156 | publisher = {IEEE Computer Society},
157 | address = {Los Alamitos, CA, USA},
158 | month = {may}
159 | }
160 | ```
--------------------------------------------------------------------------------
/code/config/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .defaults import _C as cfg
8 |
--------------------------------------------------------------------------------
/code/config/defaults.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | # -----------------------------------------------------------------------------
4 | # Convention about Training / Test specific parameters
5 | # -----------------------------------------------------------------------------
6 | # Whenever an argument can be either used for training or for testing, the
7 | # corresponding name will be post-fixed by a _TRAIN for a training parameter,
8 | # or _TEST for a test-specific parameter.
9 | # For example, the number of images during training will be
10 | # IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
11 | # IMAGES_PER_BATCH_TEST
12 |
13 | # -----------------------------------------------------------------------------
14 | # Config definition
15 | # -----------------------------------------------------------------------------
16 |
17 | _C = CN()
18 |
19 | _C.MODEL = CN()
20 | _C.MODEL.DEVICE = "cuda"
21 | _C.MODEL.COARSE_RAY_SAMPLING = 64
22 | _C.MODEL.FINE_RAY_SAMPLING = 80
23 | _C.MODEL.INPORTANCE_RAY_SAMPLE = 0
24 | _C.MODEL.SAMPLE_METHOD = "NEAR_FAR"
25 | _C.MODEL.BOARDER_WEIGHT = 1e10
26 | _C.MODEL.SAME_SPACENET = False
27 | _C.MODEL.DEPTH_FIELD = 1e-1
28 | _C.MODEL.DEPTH_RATIO = 1e-1
29 | _C.MODEL.SAMPLE_INF = False
30 | _C.MODEL.USE_MOTION = False
31 |
32 | _C.MODEL.UNET_DIR = False
33 | _C.MODEL.USE_SH = True
34 | _C.MODEL.SH_DIM = 9
35 | _C.MODEL.SH_FEAT_DIM = 3
36 | _C.MODEL.BONE_FEAT_DIM = 0
37 | _C.MODEL.UNET_FEAT = False
38 | _C.MODEL.USE_RENDER_NET = True
39 | _C.MODEL.USE_LIGHT_RENDERER = False
40 | _C.MODEL.USE_BONE_NET = False
41 | _C.MODEL.USE_WARP_NET = False
42 |
43 | _C.MODEL.SAMPLE_NET = False
44 | _C.MODEL.OPACITY_NET = False
45 | _C.MODEL.NOISE_STD = 0.0
46 |
47 | _C.MODEL.TKERNEL_INC_RAW = True
48 | _C.MODEL.ENCODE_POS_DIM = 10
49 | _C.MODEL.ENCODE_DIR_DIM = 4
50 | _C.MODEL.GAUSSIAN_SIGMA = 10.
51 | _C.MODEL.KERNEL_TYPE = "POS"
52 | _C.MODEL.PAIR_TRAINING = False
53 | _C.MODEL.TREE_DEPTH = 8
54 | _C.MODEL.RANDOM_INI = True
55 |
56 | # -----------------------------------------------------------------------------
57 | # INPUT
58 | # -----------------------------------------------------------------------------
59 | _C.INPUT = CN()
60 | # Size of the image during training
61 | _C.INPUT.SIZE_TRAIN = [400, 250]
62 | # Size of the image during test
63 | _C.INPUT.SIZE_TEST = [400, 250]
64 | # Minimum scale for the image during training
65 | _C.INPUT.MIN_SCALE_TRAIN = 0.5
66 | # Maximum scale for the image during training
67 | _C.INPUT.MAX_SCALE_TRAIN = 1.2
68 | # Random probability for image horizontal flip
69 | _C.INPUT.PROB = 0.5
70 | # Values to be used for image normalization
71 | _C.INPUT.PIXEL_MEAN = [0.1307, ]
72 | # Values to be used for image normalization
73 | _C.INPUT.PIXEL_STD = [0.3081, ]
74 |
75 | # -----------------------------------------------------------------------------
76 | # Dataset
77 | # -----------------------------------------------------------------------------
78 | _C.DATASETS = CN()
79 | # List of the dataset names for training, as present in paths_catalog.py
80 | _C.DATASETS.TRAIN = ""
81 | # List of the dataset names for testing, as present in paths_catalog.py
82 | _C.DATASETS.TEST = ""
83 | _C.DATASETS.SHIFT = 0.0
84 | _C.DATASETS.MAXRATION = 0.0
85 | _C.DATASETS.ROTATION = 0.0
86 | _C.DATASETS.USE_MASK = False
87 | _C.DATASETS.USE_DEPTH = False
88 | _C.DATASETS.USE_ALPHA = False
89 | _C.DATASETS.USE_BG = False
90 | _C.DATASETS.NUM_FRAME = 1
91 | _C.DATASETS.NUM_CAMERA = 1000
92 | _C.DATASETS.TYPE = "NR"
93 | _C.DATASETS.SYNTHESIS = True
94 | _C.DATASETS.NO_BOUNDARY = False
95 | _C.DATASETS.BOUNDARY_WIDTH = 3
96 | _C.DATASETS.PATCH_SIZE = 16
97 | _C.DATASETS.KEEP_BG = False
98 | _C.DATASETS.PAIR_SAMPLE = False
99 |
100 | # -----------------------------------------------------------------------------
101 | # DataLoader
102 | # -----------------------------------------------------------------------------
103 | _C.DATALOADER = CN()
104 | # Number of data loading threads
105 | _C.DATALOADER.NUM_WORKERS = 8
106 |
107 | # ---------------------------------------------------------------------------- #
108 | # Solver
109 | # ---------------------------------------------------------------------------- #
110 | _C.SOLVER = CN()
111 | _C.SOLVER.OPTIMIZER_NAME = "SGD"
112 |
113 | _C.SOLVER.MAX_EPOCHS = 50
114 |
115 | _C.SOLVER.BASE_LR = 0.001
116 | _C.SOLVER.BIAS_LR_FACTOR = 2
117 |
118 | _C.SOLVER.MOMENTUM = 0.9
119 |
120 | _C.SOLVER.WEIGHT_DECAY = 0.0005
121 | _C.SOLVER.WEIGHT_DECAY_BIAS = 0
122 |
123 | _C.SOLVER.GAMMA = 0.1
124 | _C.SOLVER.STEPS = (30000,)
125 |
126 | _C.SOLVER.WARMUP_FACTOR = 1.0 / 3
127 | _C.SOLVER.WARMUP_ITERS = 500
128 | _C.SOLVER.WARMUP_METHOD = "linear"
129 |
130 | _C.SOLVER.LOSS_FN = "L1"
131 |
132 | _C.SOLVER.CHECKPOINT_PERIOD = 10
133 | _C.SOLVER.LOG_PERIOD = 100
134 | _C.SOLVER.BUNCH = 4096
135 | _C.SOLVER.BATCH_SIZE = 64
136 | _C.SOLVER.START_ITERS = 50
137 | _C.SOLVER.END_ITERS = 200
138 | _C.SOLVER.LR_SCALE = 0.1
139 | _C.SOLVER.COARSE_STAGE = 10
140 |
141 | # used for update 3d geometry
142 | _C.SOLVER.START_EPOCHS = 10
143 | _C.SOLVER.EPOCH_STEP = 1
144 | _C.SOLVER.CUBE_RESOLUTION = 512
145 | _C.SOLVER.CURVE_BUNCH = 10000
146 | _C.SOLVER.THRESHOLD = 5.
147 | _C.SOLVER.CURVE_SAMPLE_NUM = 64
148 | _C.SOLVER.CURVE_SAMPLE_SCALE = 1
149 | _C.SOLVER.UPDATE_GEOMETRY = False
150 | _C.SOLVER.UPDATE_RANGE = False
151 | _C.SOLVER.USE_BOARDER_COLOR = False
152 | _C.SOLVER.USE_AMP = False
153 | _C.SOLVER.SEED = 2021
154 |
155 |
156 | # Number of images per batch
157 | # This is global, so if we have 8 GPUs and IMS_PER_BATCH = 16, each GPU will
158 | # see 2 images per batch
159 | _C.SOLVER.IMS_PER_BATCH = 16
160 |
161 |
162 | # Test parameters (IMS_PER_BATCH is also global across GPUs)
163 | _C.TEST = CN()
164 | _C.TEST.IMS_PER_BATCH = 8
165 | _C.TEST.WEIGHT = ""
166 |
167 | # ---------------------------------------------------------------------------- #
168 | # Misc options
169 | # ---------------------------------------------------------------------------- #
170 | _C.OUTPUT_DIR = ""
171 |
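172 | # Minimal usage sketch (an assumption, following the standard yacs workflow implied by
173 | # `from .defaults import _C as cfg` in config/__init__.py): import the defaults, merge an
174 | # experiment YAML such as configs/configs.yml, then freeze the config before use.
175 | #
176 | #   from config import cfg
177 | #   cfg.merge_from_file("configs/configs.yml")   # illustrative path
178 | #   cfg.freeze()
179 | #   print(cfg.SOLVER.OPTIMIZER_NAME, cfg.SOLVER.BASE_LR)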
--------------------------------------------------------------------------------
/code/configs/configs.yml:
--------------------------------------------------------------------------------
1 | SOLVER:
2 | OPTIMIZER_NAME: "Adam"
3 | BASE_LR: 0.0012 # 12
4 | WEIGHT_DECAY: 0. # 0.0000001
5 | IMS_PER_BATCH: 1
6 | START_ITERS: 2000
7 | END_ITERS: 50000
8 | LR_SCALE: 0.08
9 | WARMUP_ITERS: 1000
10 |
11 | LOSS_FN: "L1" # L1 or L2
12 |
13 | MAX_EPOCHS: 200
14 | CHECKPOINT_PERIOD: 5000
15 | LOG_PERIOD: 30
16 | BUNCH: 60000
17 | BATCH_SIZE: 8 # 12
18 | COARSE_STAGE: 0
19 |
20 | USE_AMP: True
21 | SEED: 2021
22 |
23 | INPUT:
24 | SIZE_TRAIN: [960,540]
25 | SIZE_TEST: [960,540]
26 |
27 | DATASETS:
28 | TYPE: "NR" # "NR" "LLFF"
29 | SYNTHESIS: True
30 |
31 | TRAIN: "wolf"
32 | TEST: "wolf"
33 |
34 | NUM_FRAME: 48
35 | NUM_CAMERA: 24 # 24
36 | SHIFT: 0.0
37 | MAXRATION: 0.0
38 | ROTATION: 0.0
39 |
40 | USE_MASK: False
41 | USE_DEPTH: True
42 | USE_ALPHA: True
43 |
44 | PAIR_SAMPLE: False
45 |
46 | DATALOADER:
47 | NUM_WORKERS: 20
48 |
49 | MODEL:
50 | USE_MOTION: True
51 | USE_RENDER_NET: True
52 | USE_LIGHT_RENDERER: False
53 |
54 | USE_SH: True
55 | SH_DIM: 9
56 | SH_FEAT_DIM: 10
57 | BONE_FEAT_DIM: 10
58 | TREE_DEPTH: 8
59 | RANDOM_INI: True
60 |
61 | PAIR_TRAINING: False
62 |
63 | TEST:
64 | IMS_PER_BATCH: 1
65 |
66 | OUTPUT_DIR: ""
67 |
68 |
69 |
70 |
71 |
--------------------------------------------------------------------------------
/code/modeling/UNet/__init__.py:
--------------------------------------------------------------------------------
1 | from .unet_model import UNet, LightUNet
2 |
--------------------------------------------------------------------------------
/code/modeling/UNet/unet_model.py:
--------------------------------------------------------------------------------
1 | # full assembly of the sub-parts to form the complete net
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from .unet_parts import *
6 |
7 | class LightUNet(nn.Module):
8 | def __init__(self, rgb_feat_channels, alpha_feat_channels, n_classes1, n_classes2):
9 | super(LightUNet, self).__init__()
10 | self.inc = inconv(rgb_feat_channels, 24)
11 | self.down1 = down(24, 48)
12 | self.down2 = down(48, 192)
13 | self.down3 = down(192, 192)
14 | self.up1 = up(192+192, 96)
15 | self.up2 = up(96+48, 48)
16 | self.up3 = up(48+24, 30)
17 | self.outc = outconv(30, n_classes1 + n_classes2)
18 |
19 | # self.inc2 = inconv(alpha_feat_channels + n_classes1, 16)
20 | # self.down5 = down(16, 32)
21 | # self.down6 = down(32, 64)
22 | # self.up5 = up(64 + 48 + 32, 32)
23 | # self.up6 = up(32 + 24 + 16, 8)
24 | # self.outc2 = outconv(8, n_classes2)
25 |
26 | def forward(self, rgb_feature, alpha_feature, motion_feature=None):
27 | # if motion_feature is not None:
28 | # rgb_feature = torch.cat([rgb_feature, motion_feature], dim=1)
29 |
30 | x1 = self.inc(rgb_feature) # 24
31 | x2 = self.down1(x1) # 48
32 | x3 = self.down2(x2) # 192
33 | x4 = self.down3(x3) # 192
34 |
35 | x5 = self.up1(x4, x3) # 96
36 | x5 = self.up2(x5, x2) # 48
37 | x5 = self.up3(x5, x1) # 30
38 | x_rgb = self.outc(x5) # n_classes1
39 |
40 | # x = torch.cat([alpha_feature, x_rgb], dim=1)
41 | # x1_2 = self.inc2(x)
42 | # x2_2 = self.down5(x1_2)
43 | # x3_2 = self.down6(x2_2)
44 | #
45 | # x6 = self.up5(x3_2, torch.cat([x2, x2_2], dim=1))
46 | # x6 = self.up6(x6, torch.cat([x1, x1_2], dim=1))
47 | # x_alpha = self.outc2(x6)
48 | #
49 | # x_rgba = torch.cat([x_rgb, x_alpha], dim=1)
50 |
51 | x_rgba = x_rgb
52 | return x_rgba
53 |
54 |
55 | class UNet(nn.Module):
56 | def __init__(self, rgb_feat_channels, alpha_feat_channels, n_classes1, n_classes2):
57 | super(UNet, self).__init__()
58 | self.inc = inconv(rgb_feat_channels, 24)
59 | self.down1 = down(24, 48)
60 | self.down2 = down(48, 96)
61 | self.down3 = down(96, 384)
62 | self.down4 = down(384, 384)
63 | self.up1 = up(768, 192)
64 | self.up2 = up(288, 48)
65 | self.up3 = up(96, 48)
66 | self.up4 = up(72, 30)
67 | self.outc = outconv(30, n_classes1)
68 |
69 | self.inc2 = inconv(alpha_feat_channels + n_classes1, 16)
70 | self.down5 = down(16, 32)
71 | self.down6 = down(32, 64)
72 | self.up5 = up(144, 32)
73 | self.up6 = up(72, 8)
74 | self.outc2 = outconv(8, n_classes2)
75 |
76 | def forward(self, rgb_feature, alpha_feature):
77 | x1 = self.inc(rgb_feature)
78 | x2 = self.down1(x1)
79 | x3 = self.down2(x2)
80 | x4 = self.down3(x3)
81 | x5 = self.down4(x4)
82 |
83 | x6 = self.up1(x5, x4)
84 | x6 = self.up2(x6, x3)
85 | x6 = self.up3(x6, x2)
86 | x6 = self.up4(x6, x1)
87 | x_rgb = self.outc(x6)
88 |
89 | x = torch.cat([alpha_feature, x_rgb], dim=1)
90 | x1_2 = self.inc2(x)
91 | x2_2 = self.down5(x1_2)
92 | x3_2 = self.down6(x2_2)
93 |
94 | x6 = self.up5(x3_2, torch.cat([x2, x2_2], dim=1))
95 | x6 = self.up6(x6, torch.cat([x1, x1_2], dim=1))
96 | x_alpha = self.outc2(x6)
97 |
98 | x = torch.cat([x_rgb, x_alpha], dim=1)
99 | return x
100 |
--------------------------------------------------------------------------------
/code/modeling/UNet/unet_parts.py:
--------------------------------------------------------------------------------
1 | # sub-parts of the U-Net model
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import models_lpf
7 |
8 |
9 | class gated_conv(nn.Module):
10 | def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0):
11 | super(gated_conv, self).__init__()
12 | self.conv2 = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
13 | self.conv2_gate = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding)
14 |
15 | def forward(self, x):
16 | feat = self.conv2(x)
17 | mask = self.conv2_gate(x)
18 |
19 | return torch.sigmoid(mask) * feat
20 |
21 |
22 | class double_conv(nn.Module):
23 | '''(conv => BN => ReLU) * 2'''
24 |
25 | def __init__(self, in_ch, out_ch):
26 | super(double_conv, self).__init__()
27 | self.conv = nn.Sequential(
28 | gated_conv(in_ch, out_ch, 3, padding=1),
29 | nn.BatchNorm2d(out_ch),
30 | nn.ReLU(inplace=True),
31 | gated_conv(out_ch, out_ch, 3, padding=1),
32 | nn.BatchNorm2d(out_ch),
33 | nn.ReLU(inplace=True)
34 | )
35 |
36 | def forward(self, x):
37 | x = self.conv(x)
38 | return x
39 |
40 |
41 | class inconv(nn.Module):
42 | def __init__(self, in_ch, out_ch):
43 | super(inconv, self).__init__()
44 | self.conv = double_conv(in_ch, out_ch)
45 |
46 | def forward(self, x):
47 | x = self.conv(x)
48 | return x
49 |
50 |
51 | class down(nn.Module):
52 | def __init__(self, in_ch, out_ch):
53 | super(down, self).__init__()
54 | self.mpconv = nn.Sequential(
55 | nn.MaxPool2d(2, stride=1),
56 | models_lpf.Downsample(channels=in_ch, filt_size=3, stride=2),
57 | double_conv(in_ch, out_ch)
58 | )
59 |
60 | def forward(self, x):
61 | x = self.mpconv(x)
62 | return x
63 |
64 |
65 | class up(nn.Module):
66 | def __init__(self, in_ch, out_ch, bilinear=True):
67 | super(up, self).__init__()
68 |
69 |         # It would be nice if the upsampling could be learned too,
70 |         # but this machine does not have enough memory to handle all those weights.
71 | if bilinear:
72 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
73 | else:
74 | self.up = nn.ConvTranspose2d(in_ch // 2, in_ch // 2, 2, stride=2)
75 |
76 | self.conv = double_conv(in_ch, out_ch)
77 |
78 | def forward(self, x1, x2):
79 | x1 = self.up(x1)
80 |
81 | # input is CHW
82 | diffY = x2.size()[2] - x1.size()[2]
83 | diffX = x2.size()[3] - x1.size()[3]
84 |
85 | x1 = F.pad(x1, (diffX // 2, diffX - diffX // 2,
86 | diffY // 2, diffY - diffY // 2))
87 |
88 | # for padding issues, see
89 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
90 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
91 |
92 | x = torch.cat([x2, x1], dim=1)
93 | x = self.conv(x)
94 | return x
95 |
96 |
97 | class outconv(nn.Module):
98 | def __init__(self, in_ch, out_ch):
99 | super(outconv, self).__init__()
100 | self.conv = nn.Conv2d(in_ch, out_ch, 1)
101 | self.conv2 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
102 |
103 | def forward(self, x):
104 | x1 = self.conv(x)
105 | x2 = self.conv2(x)
106 | return x1 + x2
107 |
--------------------------------------------------------------------------------
/code/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import GeneralModel, NLayerDiscriminator
2 | from .tree_render import TreeRenderer
3 | import torch
4 | import os
5 |
6 | def build_model(cfg, flut_size=0, bone_feature_dim=0):
7 | if not cfg.MODEL.USE_MOTION:
8 | bone_feature_dim = 0
9 | return GeneralModel(cfg, flut_size=flut_size, use_render_net=cfg.MODEL.USE_RENDER_NET, bone_feature_dim=bone_feature_dim, texture_feature_dim=cfg.MODEL.SH_FEAT_DIM)
10 |
11 |
12 | def build_discriminator(cfg, input_nc):
13 |     return NLayerDiscriminator(input_nc=input_nc, ndf=32, n_layers=5, norm_layer='spectral')
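14 |
15 |
16 | # Hypothetical usage sketch (an assumption based only on the factory signatures above;
17 | # the actual training / rendering scripts may wire these up differently):
18 | #
19 | #   from config import cfg
20 | #   model = build_model(cfg, flut_size=flut_size, bone_feature_dim=bone_feat_dim)
21 | #   disc = build_discriminator(cfg, input_nc=3)   # channel count is illustrative only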
--------------------------------------------------------------------------------
/code/modeling/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from .UNet import UNet, LightUNet
4 | from .tree_render import TreeRenderer
5 |
6 | import functools
7 | import torch.nn.utils.spectral_norm as spectral_norm
8 | from torch.cuda.amp import autocast
9 | from utils import generate_transformation_matrices
10 |
11 | import time
12 |
13 | class GeneralModel(nn.Module):
14 | def __init__(self, cfg, flut_size=0, use_render_net=True, bone_feature_dim=0, texture_feature_dim=0, sh_dim=9):
15 | super(GeneralModel, self).__init__()
16 | self.use_render_net = use_render_net
17 | self.bone_feature_dim = bone_feature_dim
18 |
19 | if not cfg.MODEL.USE_MOTION:
20 | self.bone_feature_dim = 0
21 |
22 | if flut_size != 0:
23 | self.tree_renderer = TreeRenderer(flut_size, background_brightness=0., random_init=cfg.MODEL.RANDOM_INI, feature_dim=texture_feature_dim * sh_dim + 1)
24 |
25 |
26 | if self.use_render_net:
27 | if not cfg.MODEL.USE_LIGHT_RENDERER:
28 | self.render_net = UNet(rgb_feat_channels=texture_feature_dim + self.bone_feature_dim,
29 | alpha_feat_channels=1, n_classes1=3, n_classes2=1)
30 | else:
31 | print('Light')
32 | self.render_net = LightUNet(rgb_feat_channels=texture_feature_dim + self.bone_feature_dim,
33 | alpha_feat_channels=1, n_classes1=3, n_classes2=1)
34 |
35 | def forward(self, rays, tree=None, joint_features=None, skinning_weights=None, joint_index=None, coords=None, transformation_matrices=None):
36 | batch_size, h, w = rays.size(0), rays.size(1), rays.size(2)
37 | joint_num, joint_feature_dim = (joint_features.size(1), joint_features.size(2)) if joint_features is not None else (0, 0)
38 | if isinstance(tree, list):
39 | assert len(tree) == batch_size
40 | else:
41 | tree = [tree]
42 |
43 | if isinstance(transformation_matrices, list):
44 | assert len(transformation_matrices) == batch_size
45 | else:
46 | transformation_matrices = [transformation_matrices]
47 |
48 | # print(rays.shape)
49 |
50 | rgbas, results = [], []
51 | rgb_features, alpha_features, motion_features = [], [], []
52 | features_in = []
53 | vrts = []
54 |
55 | joint_features = joint_features.reshape(-1, joint_feature_dim).type_as(rays) if joint_features is not None else None
56 | # print(joint_features.max())
57 | if self.use_bone_net and joint_features is not None:
58 | joint_features = self.bone_net(joint_features)
59 |
60 | if joint_features is not None:
61 | joint_features = joint_features.reshape(batch_size, joint_num, -1)
62 |
63 | s = time.time()
64 | for i in range(batch_size):
65 | ray = rays[i:i+1, ...].reshape(-1, rays.size(3))
66 |
67 | torch.cuda.synchronize()
68 | s = time.time()
69 | matrices = None
70 | # print(batch_size, i, transformation_matrices[i].max())
71 | if transformation_matrices[i] is not None:
72 | matrices = generate_transformation_matrices(matrices=transformation_matrices[i], skinning_weights=skinning_weights, joint_index=joint_index)
73 | features = self.tree_renderer(tree[i], ray, matrices)
74 | features = features.reshape(1, h, w, -1).permute(0, 3, 1, 2)
75 |
76 | motion_feature = None
77 | if joint_features is not None:
78 | with autocast(enabled=False):
79 | motion_feature = self.tree_renderer.motion_feature_render(tree[i], joint_features[i].float(), skinning_weights,
80 | joint_index,
81 | ray)
82 | motion_feature = motion_feature.reshape(1, h, w, -1).permute(0, 3, 1, 2)
83 | torch.cuda.synchronize()
84 | features_in.append(features)
85 | motion_features.append(motion_feature)
86 |
87 | if coords is not None:
88 | vrts.append(self.tree_renderer.voxel_regularization(tree[i], coords[i]))
89 |
90 | features_in = torch.cat(features_in, dim=0)
91 | motion_features = None if motion_features[0] is None else torch.cat(motion_features, dim=0)
92 |
93 | rgbas = torch.cat([features_in[:, :3, ...], features_in[:, -1:, ...]], dim=1)
94 | torch.cuda.synchronize()
95 | s = time.time()
96 | if self.use_render_net:
97 |             results = self.render_net(torch.cat([features_in[:, :-1, ...], motion_features], dim=1), features_in[:, -1:, ...])
98 | else:
99 | results = rgbas
100 |
101 | if coords is None:
102 | return rgbas, results, motion_features
103 | else:
104 | return rgbas, results, sum(vrts) / float(len(vrts)), motion_features
105 |
106 | def render_volume_feature(self, rays, tree=None, joint_features=None, skinning_weights=None, joint_index=None):
107 | batch_size, h, w = rays.size(0), rays.size(1), rays.size(2)
108 | rays = rays.reshape(-1, rays.size(3))
109 |
110 | features = self.tree_renderer(tree, rays)
111 | features = features.reshape(batch_size, h, w, -1).permute(0, 3, 1, 2)
112 | rgb_feature_map = features[:, :-1, ...]
113 | alpha = features[:, -1:, ...]
114 | rgb = features[:, :3, ...]
115 | rgba = torch.cat([rgb, alpha], dim=1)
116 |
117 | motion_feature = None
118 | if joint_features is not None:
119 | motion_feature = self.tree_renderer.motion_feature_render(tree, joint_features, skinning_weights, joint_index, rays) # joint_features, skinning_weights, joint_index, rays
120 |
121 | return rgba, features, motion_feature
122 |
123 |
124 |
125 | class ConvBlock(nn.Module):
126 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
127 | super(ConvBlock, self).__init__()
128 |
129 | self.conv = nn.Sequential(
130 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
131 | kernel_size=kernel_size, padding=padding, stride=stride),
132 | # nn.BatchNorm2d(out_channels),
133 | nn.InstanceNorm2d(out_channels),
134 | nn.LeakyReLU(0.2, inplace=True)
135 | # nn.ReLU(inplace=True)
136 | )
137 |
138 | def forward(self, x):
139 | return self.conv(x)
140 |
141 |
142 | class NLayerDiscriminator(nn.Module):
143 | """Defines a PatchGAN discriminator"""
144 |
145 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
146 | """Construct a PatchGAN discriminator
147 | Parameters:
148 | input_nc (int) -- the number of channels in input images
149 | ndf (int) -- the number of filters in the last conv layer
150 | n_layers (int) -- the number of conv layers in the discriminator
151 | norm_layer -- normalization layer
152 | """
153 | super(NLayerDiscriminator, self).__init__()
154 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
155 | use_bias = norm_layer.func == nn.InstanceNorm2d
156 | else:
157 | use_bias = norm_layer == nn.InstanceNorm2d
158 |
159 | kw = 4
160 | padw = 1
161 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
162 | nf_mult = 1
163 | nf_mult_prev = 1
164 | for n in range(1, n_layers): # gradually increase the number of filters
165 | nf_mult_prev = nf_mult
166 | nf_mult = min(2 ** n, 8)
167 | sequence += [
168 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
169 | norm_layer(ndf * nf_mult),
170 | nn.LeakyReLU(0.2, True)
171 |         ] if norm_layer != 'spectral' else [
172 | spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias)),
173 | nn.LeakyReLU(0.2, True)
174 | ]
175 |
176 | nf_mult_prev = nf_mult
177 | nf_mult = min(2 ** n_layers, 8)
178 | sequence += [
179 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
180 | norm_layer(ndf * nf_mult),
181 | nn.LeakyReLU(0.2, True)
182 |         ] if norm_layer != 'spectral' else [
183 | spectral_norm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias)),
184 | nn.LeakyReLU(0.2, True)
185 | ]
186 |
187 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
188 | self.model = nn.Sequential(*sequence)
189 |
190 | def forward(self, input):
191 | """Standard forward."""
192 | return self.model(input)
193 |
194 |
195 | class PixelDiscriminator(nn.Module):
196 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)"""
197 |
198 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d):
199 | """Construct a 1x1 PatchGAN discriminator
200 | Parameters:
201 | input_nc (int) -- the number of channels in input images
202 | ndf (int) -- the number of filters in the last conv layer
203 | norm_layer -- normalization layer
204 | """
205 | super(PixelDiscriminator, self).__init__()
206 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
207 | use_bias = norm_layer.func == nn.InstanceNorm2d
208 | else:
209 | use_bias = norm_layer == nn.InstanceNorm2d
210 |
211 | self.net = [
212 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
213 | nn.LeakyReLU(0.2, True),
214 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
215 | norm_layer(ndf * 2),
216 | nn.LeakyReLU(0.2, True),
217 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
218 |
219 | self.net = nn.Sequential(*self.net)
220 |
221 | def forward(self, input):
222 | """Standard forward."""
223 | return self.net(input)
--------------------------------------------------------------------------------
/code/modeling/tree_render.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 | from torch import nn
5 |
6 | import svox_t
7 | import collections
8 | Sh_Rays = collections.namedtuple('Sh_Rays', 'origins dirs viewdirs')
9 |
10 |
11 | class TreeRenderer(nn.Module):
12 |
13 | def __init__(self, flut_size, background_brightness=0., step_size=1e-3, random_init=False, feature_dim=0):
14 | super(TreeRenderer, self).__init__()
15 | self.background_brightness = background_brightness
16 | self.step_size = step_size
17 | print(flut_size)
18 |
19 | features = torch.empty((flut_size, feature_dim))
20 | nn.init.normal_(features)
21 | self.register_parameter("features", nn.Parameter(features))
22 |
23 | def forward(self, tree, rays, transformation_matrices=None, fast_rendering=False):
24 | t = tree.to(rays.device)
25 |
26 | r = svox_t.VolumeRenderer(t, background_brightness=self.background_brightness, step_size=self.step_size)
27 | dirs = rays[..., :3].contiguous()
28 | origins = rays[..., 3:].contiguous()
29 |
30 | sh_rays = Sh_Rays(origins, dirs, dirs)
31 |
32 | res = r(self.features, sh_rays, transformation_matrices=transformation_matrices, fast=fast_rendering)
33 |
34 | return res
35 |
36 | def motion_render(self, tree, rays):
37 | t = tree.to(rays.device)
38 | dirs = rays[..., :3].contiguous()
39 | origins = rays[..., 3:].contiguous()
40 | sh_rays = Sh_Rays(origins, dirs, dirs)
41 | r = svox_t.VolumeRenderer(t, background_brightness=self.background_brightness, step_size=self.step_size)
42 | motion_feature, depth, hit_point, data_idx = r.motion_render(self.features, sh_rays)
43 |
44 | return motion_feature, depth, hit_point, data_idx
45 |
46 | def motion_feature_render(self, tree, joint_features, skinning_weights, joint_index, rays):
47 | t = tree.to(rays.device)
48 | dirs = rays[..., :3].contiguous()
49 | origins = rays[..., 3:].contiguous()
50 | sh_rays = Sh_Rays(origins, dirs, dirs)
51 | r = svox_t.VolumeRenderer(t, background_brightness=self.background_brightness, step_size=self.step_size)
52 |         motion_feature = r.motion_feature_render(self.features, joint_features, skinning_weights, joint_index, sh_rays)
53 |
54 | return motion_feature
55 |
56 | def voxel_regularization(self, tree, coordinate):
57 | tree = tree.to(coordinate.device)
58 | query_features = tree(self.features, coordinate, want_node_ids=False)
59 | vrt = nn.L1Loss()(query_features[..., :-1].detach(), self.features[..., :-1])
60 | # vrt = nn.MSELoss()(query_features.detach(), self.features)
61 | return vrt
62 |
63 | @staticmethod
64 | def MotionRender(tree, features, rays):
65 | t = tree.to(rays.device)
66 | dirs = rays[..., :3].contiguous()
67 | origins = rays[..., 3:].contiguous()
68 | sh_rays = Sh_Rays(origins, dirs, dirs)
69 | r = svox_t.VolumeRenderer(t, background_brightness=0., step_size=1e-3)
70 | motion_feature, depth, hit_point, data_idx = r.motion_render(features, sh_rays)
71 |
72 | return motion_feature, depth, hit_point, data_idx
73 |
74 | @staticmethod
75 | def MotionFeatureRender(tree, features, joint_features, skinning_weights, joint_index, rays):
76 | t = tree.to(rays.device)
77 | dirs = rays[..., :3].contiguous()
78 | origins = rays[..., 3:].contiguous()
79 | sh_rays = Sh_Rays(origins, dirs, dirs)
80 | r = svox_t.VolumeRenderer(t, background_brightness=0., step_size=1e-3)
81 | motion_feature = r.motion_feature_render(features, joint_features, skinning_weights, joint_index, sh_rays)
82 |
83 | return motion_feature
--------------------------------------------------------------------------------
/code/models_lpf/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, Adobe Inc. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
4 | # 4.0 International Public License. To view a copy of this license, visit
5 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
6 |
7 | from .downsample import *
8 | from .alexnet import *
9 | from .densenet import *
10 | from .mobilenet import *
11 | from .resnet import *
12 | from .vgg import *
13 |
--------------------------------------------------------------------------------
/code/models_lpf/alexnet.py:
--------------------------------------------------------------------------------
1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models.
2 | # Copyright (c) 2017 Torch Contributors.
3 | # The Pytorch examples are available under the BSD 3-Clause License.
4 | #
5 | # ==========================================================================================
6 | #
7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved.
8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit
10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
11 | #
12 | # ==========================================================================================
13 | #
14 | # BSD-3 License
15 | #
16 | # Redistribution and use in source and binary forms, with or without
17 | # modification, are permitted provided that the following conditions are met:
18 | #
19 | # * Redistributions of source code must retain the above copyright notice, this
20 | # list of conditions and the following disclaimer.
21 | #
22 | # * Redistributions in binary form must reproduce the above copyright notice,
23 | # this list of conditions and the following disclaimer in the documentation
24 | # and/or other materials provided with the distribution.
25 | #
26 | # * Neither the name of the copyright holder nor the names of its
27 | # contributors may be used to endorse or promote products derived from
28 | # this software without specific prior written permission.
29 | #
30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
39 |
40 | import torch
41 | import torch.nn as nn
42 | import torch.utils.model_zoo as model_zoo
43 | import numpy as np
44 | from models_lpf import *
45 | from IPython import embed
46 |
47 | __all__ = ['AlexNet', 'alexnet']
48 |
49 |
50 | # model_urls = {
51 | # 'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
52 | # }
53 |
54 |
55 | class AlexNet(nn.Module):
56 |
57 | def __init__(self, num_classes=1000, filter_size=1, pool_only=False, relu_first=True):
58 | super(AlexNet, self).__init__()
59 |
60 | if(pool_only): # only apply LPF to pooling layers, so run conv1 at stride 4 as before
61 | first_ds = [nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),]
62 | else:
63 | if(relu_first): # this is the right order
64 | first_ds = [nn.Conv2d(3, 64, kernel_size=11, stride=2, padding=2),
65 | nn.ReLU(inplace=True),
66 | Downsample(filt_size=filter_size, stride=2, channels=64),]
67 | else: # this is the wrong order, since it's equivalent to downsampling the image first
68 | first_ds = [nn.Conv2d(3, 64, kernel_size=11, stride=2, padding=2),
69 | Downsample(filt_size=filter_size, stride=2, channels=64),
70 | nn.ReLU(inplace=True),]
71 |
72 | first_ds += [nn.MaxPool2d(kernel_size=3, stride=1),
73 | Downsample(filt_size=filter_size, stride=2, channels=64),
74 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
75 | nn.ReLU(inplace=True),
76 | nn.MaxPool2d(kernel_size=3, stride=1),
77 | Downsample(filt_size=filter_size, stride=2, channels=192),
78 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
79 | nn.ReLU(inplace=True),
80 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
81 | nn.ReLU(inplace=True),
82 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
83 | nn.ReLU(inplace=True),
84 | nn.MaxPool2d(kernel_size=3, stride=1),
85 | Downsample(filt_size=filter_size, stride=2, channels=256)]
86 | self.features = nn.Sequential(*first_ds)
87 |
88 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
89 | self.classifier = nn.Sequential(
90 | nn.Dropout(),
91 | nn.Linear(256 * 6 * 6, 4096),
92 | nn.ReLU(inplace=True),
93 | nn.Dropout(),
94 | nn.Linear(4096, 4096),
95 | nn.ReLU(inplace=True),
96 | nn.Linear(4096, num_classes),
97 | )
98 |
99 | def forward(self, x):
100 | x = self.features(x)
101 | x = self.avgpool(x)
102 | x = x.view(x.size(0), 256 * 6 * 6)
103 | x = self.classifier(x)
104 | return x
105 |
106 |
107 | def alexnet(pretrained=False, **kwargs):
108 | """AlexNet model architecture from the
109 |     `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
110 |
111 | Args:
112 | pretrained (bool): If True, returns a model pre-trained on ImageNet
113 | """
114 | model = AlexNet(**kwargs)
115 | if pretrained:
116 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
117 | return model
118 |
119 |
120 | # replacing MaxPool with BlurPool layers
121 | class AlexNetNMP(nn.Module):
122 |
123 | def __init__(self, num_classes=1000, filter_size=1):
124 | super(AlexNetNMP, self).__init__()
125 | self.features = nn.Sequential(
126 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
127 | nn.ReLU(inplace=True),
128 | Downsample(filt_size=filter_size, stride=2, channels=64, pad_off=-1, hidden=True),
129 | nn.Conv2d(64, 192, kernel_size=5, padding=2),
130 | nn.ReLU(inplace=True),
131 | Downsample(filt_size=filter_size, stride=2, channels=192, pad_off=-1, hidden=True),
132 | nn.Conv2d(192, 384, kernel_size=3, padding=1),
133 | nn.ReLU(inplace=True),
134 | nn.Conv2d(384, 256, kernel_size=3, padding=1),
135 | nn.ReLU(inplace=True),
136 | nn.Conv2d(256, 256, kernel_size=3, padding=1),
137 | nn.ReLU(inplace=True),
138 | Downsample(filt_size=filter_size, stride=2, channels=256, pad_off=-1, hidden=True),
139 | )
140 | self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
141 | self.classifier = nn.Sequential(
142 | nn.Dropout(),
143 | nn.Linear(256 * 6 * 6, 4096),
144 | nn.ReLU(inplace=True),
145 | nn.Dropout(),
146 | nn.Linear(4096, 4096),
147 | nn.ReLU(inplace=True),
148 | nn.Linear(4096, num_classes),
149 | )
150 |
151 | def forward(self, x):
152 | # embed()
153 | x = self.features(x)
154 | x = self.avgpool(x)
155 | x = x.view(x.size(0), 256 * 6 * 6)
156 | x = self.classifier(x)
157 | return x
158 |
159 |
160 | def alexnetnmp(pretrained=False, **kwargs):
161 | """AlexNet model architecture from the
162 |     `"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper.
163 |
164 | Args:
165 | pretrained (bool): If True, returns a model pre-trained on ImageNet
166 | """
167 | model = AlexNetNMP(**kwargs)
168 | if pretrained:
169 | model.load_state_dict(model_zoo.load_url(model_urls['alexnet']))
170 | return model
171 |
172 |
173 |
174 |
175 | # def __init__(self, num_classes=1000):
176 | # super(AlexNet, self).__init__()
177 | # self.features = nn.Sequential(
178 | # nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
179 | # nn.ReLU(inplace=True),
180 | # nn.MaxPool2d(kernel_size=3, stride=2),
181 | # nn.Conv2d(64, 192, kernel_size=5, padding=2),
182 | # nn.ReLU(inplace=True),
183 | # nn.MaxPool2d(kernel_size=3, stride=2),
184 | # nn.Conv2d(192, 384, kernel_size=3, padding=1),
185 | # nn.ReLU(inplace=True),
186 | # nn.Conv2d(384, 256, kernel_size=3, padding=1),
187 | # nn.ReLU(inplace=True),
188 | # nn.Conv2d(256, 256, kernel_size=3, padding=1),
189 | # nn.ReLU(inplace=True),
190 | # nn.MaxPool2d(kernel_size=3, stride=2),
191 | # )
192 | # self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
193 | # self.classifier = nn.Sequential(
194 | # nn.Dropout(),
195 | # nn.Linear(256 * 6 * 6, 4096),
196 | # nn.ReLU(inplace=True),
197 | # nn.Dropout(),
198 | # nn.Linear(4096, 4096),
199 | # nn.ReLU(inplace=True),
200 | # nn.Linear(4096, num_classes),
201 | # )
202 |
203 |
204 |
205 |
--------------------------------------------------------------------------------
/code/models_lpf/densenet.py:
--------------------------------------------------------------------------------
1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models.
2 | # Copyright (c) 2017 Torch Contributors.
3 | # The Pytorch examples are available under the BSD 3-Clause License.
4 | #
5 | # ==========================================================================================
6 | #
7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved.
8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit
10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
11 | #
12 | # ==========================================================================================
13 | #
14 | # BSD-3 License
15 | #
16 | # Redistribution and use in source and binary forms, with or without
17 | # modification, are permitted provided that the following conditions are met:
18 | #
19 | # * Redistributions of source code must retain the above copyright notice, this
20 | # list of conditions and the following disclaimer.
21 | #
22 | # * Redistributions in binary form must reproduce the above copyright notice,
23 | # this list of conditions and the following disclaimer in the documentation
24 | # and/or other materials provided with the distribution.
25 | #
26 | # * Neither the name of the copyright holder nor the names of its
27 | # contributors may be used to endorse or promote products derived from
28 | # this software without specific prior written permission.
29 | #
30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
39 |
40 | import re
41 | import torch
42 | import torch.nn as nn
43 | import torch.nn.functional as F
44 | import torch.utils.model_zoo as model_zoo
45 | from collections import OrderedDict
46 | from models_lpf import *
47 |
48 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
49 |
50 |
51 | # model_urls = {
52 | # 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth',
53 | # 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth',
54 | # 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth',
55 | # 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth',
56 | # }
57 |
58 |
59 | class _DenseLayer(nn.Sequential):
60 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
61 | super(_DenseLayer, self).__init__()
62 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
63 | self.add_module('relu1', nn.ReLU(inplace=True)),
64 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size *
65 | growth_rate, kernel_size=1, stride=1, bias=False)),
66 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
67 | self.add_module('relu2', nn.ReLU(inplace=True)),
68 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
69 | kernel_size=3, stride=1, padding=1, bias=False)),
70 | self.drop_rate = drop_rate
71 |
72 | def forward(self, x):
73 | new_features = super(_DenseLayer, self).forward(x)
74 | if self.drop_rate > 0:
75 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
76 | return torch.cat([x, new_features], 1)
77 |
78 |
79 | class _DenseBlock(nn.Sequential):
80 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate):
81 | super(_DenseBlock, self).__init__()
82 | for i in range(num_layers):
83 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate)
84 | self.add_module('denselayer%d' % (i + 1), layer)
85 |
86 |
87 | class _Transition(nn.Sequential):
88 | def __init__(self, num_input_features, num_output_features, filter_size=1):
89 | super(_Transition, self).__init__()
90 | self.add_module('norm', nn.BatchNorm2d(num_input_features))
91 | self.add_module('relu', nn.ReLU(inplace=True))
92 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
93 | kernel_size=1, stride=1, bias=False))
94 | # self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))
95 | self.add_module('pool', Downsample(filt_size=filter_size, stride=2, channels=num_output_features))
96 |
97 |
98 | class DenseNet(nn.Module):
99 | r"""Densenet-BC model class, based on
100 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
101 | Args:
102 | growth_rate (int) - how many filters to add each layer (`k` in paper)
103 | block_config (list of 4 ints) - how many layers in each pooling block
104 | num_init_features (int) - the number of filters to learn in the first convolution layer
105 | bn_size (int) - multiplicative factor for number of bottle neck layers
106 | (i.e. bn_size * k features in the bottleneck layer)
107 | drop_rate (float) - dropout rate after each dense layer
108 | num_classes (int) - number of classification classes
109 | """
110 |
111 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
112 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000,
113 | filter_size=1, pool_only=True):
114 |
115 | super(DenseNet, self).__init__()
116 |
117 | # First convolution
118 | if(pool_only):
119 | self.features = nn.Sequential(OrderedDict([
120 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
121 | ('norm0', nn.BatchNorm2d(num_init_features)),
122 | ('relu0', nn.ReLU(inplace=True)),
123 | ('max0', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)),
124 | ('pool0', Downsample(filt_size=filter_size, stride=2, channels=num_init_features)),
125 | ]))
126 | else:
127 | self.features = nn.Sequential(OrderedDict([
128 | ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=1, padding=3, bias=False)),
129 | ('norm0', nn.BatchNorm2d(num_init_features)),
130 | ('relu0', nn.ReLU(inplace=True)),
131 | ('ds0', Downsample(filt_size=filter_size, stride=2, channels=num_init_features)),
132 | ('max0', nn.MaxPool2d(kernel_size=3, stride=1, padding=1)),
133 | ('pool0', Downsample(filt_size=filter_size, stride=2, channels=num_init_features)),
134 | ]))
135 |
136 | # Each denseblock
137 | num_features = num_init_features
138 | for i, num_layers in enumerate(block_config):
139 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features,
140 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate)
141 | self.features.add_module('denseblock%d' % (i + 1), block)
142 | num_features = num_features + num_layers * growth_rate
143 | if i != len(block_config) - 1:
144 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2, filter_size=filter_size)
145 | self.features.add_module('transition%d' % (i + 1), trans)
146 | num_features = num_features // 2
147 |
148 | # Final batch norm
149 | self.features.add_module('norm5', nn.BatchNorm2d(num_features))
150 |
151 | # Linear layer
152 | self.classifier = nn.Linear(num_features, num_classes)
153 |
154 | # Official init from torch repo.
155 | for m in self.modules():
156 | if isinstance(m, nn.Conv2d):
157 | if(m.in_channels!=m.out_channels or m.out_channels!=m.groups or m.bias is not None):
158 | # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics
159 | nn.init.kaiming_normal_(m.weight)
160 | else:
161 | print('Not initializing')
162 | elif isinstance(m, nn.BatchNorm2d):
163 | nn.init.constant_(m.weight, 1)
164 | nn.init.constant_(m.bias, 0)
165 | elif isinstance(m, nn.Linear):
166 | nn.init.constant_(m.bias, 0)
167 |
168 | def forward(self, x):
169 | features = self.features(x)
170 | out = F.relu(features, inplace=True)
171 | out = F.adaptive_avg_pool2d(out, (1, 1)).view(features.size(0), -1)
172 | out = self.classifier(out)
173 | return out
174 |
175 |
176 | def _load_state_dict(model, model_url):
177 |     # '.'s are no longer allowed in module names, but previous _DenseLayer
178 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
179 | # They are also in the checkpoints in model_urls. This pattern is used
180 | # to find such keys.
181 | pattern = re.compile(
182 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
183 | state_dict = model_zoo.load_url(model_url)
184 | for key in list(state_dict.keys()):
185 | res = pattern.match(key)
186 | if res:
187 | new_key = res.group(1) + res.group(2)
188 | state_dict[new_key] = state_dict[key]
189 | del state_dict[key]
190 | model.load_state_dict(state_dict)
191 |
192 |
193 | def densenet121(pretrained=False, filter_size=1, pool_only=True, **kwargs):
194 | r"""Densenet-121 model from
195 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
196 | Args:
197 | pretrained (bool): If True, returns a model pre-trained on ImageNet
198 | """
199 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
200 | filter_size=filter_size, pool_only=pool_only, **kwargs)
201 | if pretrained:
202 | _load_state_dict(model, model_urls['densenet121'])
203 | return model
204 |
205 |
206 | def densenet169(pretrained=False, filter_size=1, pool_only=True, **kwargs):
207 | r"""Densenet-169 model from
208 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
209 | Args:
210 | pretrained (bool): If True, returns a model pre-trained on ImageNet
211 | """
212 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
213 | filter_size=filter_size, pool_only=pool_only, **kwargs)
214 | if pretrained:
215 | _load_state_dict(model, model_urls['densenet169'])
216 | return model
217 |
218 |
219 | def densenet201(pretrained=False, filter_size=1, pool_only=True, **kwargs):
220 | r"""Densenet-201 model from
221 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
222 | Args:
223 | pretrained (bool): If True, returns a model pre-trained on ImageNet
224 | """
225 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
226 | filter_size=filter_size, pool_only=pool_only, **kwargs)
227 | if pretrained:
228 | _load_state_dict(model, model_urls['densenet201'])
229 | return model
230 |
231 |
232 | def densenet161(pretrained=False, filter_size=1, pool_only=True, **kwargs):
233 | r"""Densenet-161 model from
234 |     `"Densely Connected Convolutional Networks" <https://arxiv.org/abs/1608.06993>`_
235 | Args:
236 | pretrained (bool): If True, returns a model pre-trained on ImageNet
237 | """
238 | model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
239 | filter_size=filter_size, pool_only=pool_only, **kwargs)
240 | if pretrained:
241 | _load_state_dict(model, model_urls['densenet161'])
242 | return model
--------------------------------------------------------------------------------
/code/models_lpf/downsample.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, Adobe Inc. All rights reserved.
2 | #
3 | # This work is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
4 | # 4.0 International Public License. To view a copy of this license, visit
5 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
6 |
7 | import torch
8 | import torch.nn.parallel
9 | import numpy as np
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from IPython import embed
13 |
14 | class Downsample(nn.Module):
15 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
16 | super(Downsample, self).__init__()
17 | self.filt_size = filt_size
18 | self.pad_off = pad_off
19 | self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
20 | self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
21 | self.stride = stride
22 | self.off = int((self.stride-1)/2.)
23 | self.channels = channels
24 |
25 | if(self.filt_size==1):
26 | a = np.array([1.,])
27 | elif(self.filt_size==2):
28 | a = np.array([1., 1.])
29 | elif(self.filt_size==3):
30 | a = np.array([1., 2., 1.])
31 | elif(self.filt_size==4):
32 | a = np.array([1., 3., 3., 1.])
33 | elif(self.filt_size==5):
34 | a = np.array([1., 4., 6., 4., 1.])
35 | elif(self.filt_size==6):
36 | a = np.array([1., 5., 10., 10., 5., 1.])
37 | elif(self.filt_size==7):
38 | a = np.array([1., 6., 15., 20., 15., 6., 1.])
39 |
40 | filt = torch.Tensor(a[:,None]*a[None,:])
41 | filt = filt/torch.sum(filt)
42 | self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))
43 |
44 | self.pad = get_pad_layer(pad_type)(self.pad_sizes)
45 |
46 | def forward(self, inp):
47 | if(self.filt_size==1):
48 | if(self.pad_off==0):
49 | return inp[:,:,::self.stride,::self.stride]
50 | else:
51 | return self.pad(inp)[:,:,::self.stride,::self.stride]
52 | else:
53 | return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
54 |
55 | def get_pad_layer(pad_type):
56 | if(pad_type in ['refl','reflect']):
57 | PadLayer = nn.ReflectionPad2d
58 | elif(pad_type in ['repl','replicate']):
59 | PadLayer = nn.ReplicationPad2d
60 | elif(pad_type=='zero'):
61 | PadLayer = nn.ZeroPad2d
62 | else:
63 | print('Pad type [%s] not recognized'%pad_type)
64 | return PadLayer
65 |
66 | class Downsample1D(nn.Module):
67 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
68 | super(Downsample1D, self).__init__()
69 | self.filt_size = filt_size
70 | self.pad_off = pad_off
71 | self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
72 | self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
73 | self.stride = stride
74 | self.off = int((self.stride - 1) / 2.)
75 | self.channels = channels
76 |
77 | # print('Filter size [%i]' % filt_size)
78 | if(self.filt_size == 1):
79 | a = np.array([1., ])
80 | elif(self.filt_size == 2):
81 | a = np.array([1., 1.])
82 | elif(self.filt_size == 3):
83 | a = np.array([1., 2., 1.])
84 | elif(self.filt_size == 4):
85 | a = np.array([1., 3., 3., 1.])
86 | elif(self.filt_size == 5):
87 | a = np.array([1., 4., 6., 4., 1.])
88 | elif(self.filt_size == 6):
89 | a = np.array([1., 5., 10., 10., 5., 1.])
90 | elif(self.filt_size == 7):
91 | a = np.array([1., 6., 15., 20., 15., 6., 1.])
92 |
93 | filt = torch.Tensor(a)
94 | filt = filt / torch.sum(filt)
95 | self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))
96 |
97 | self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
98 |
99 | def forward(self, inp):
100 | if(self.filt_size == 1):
101 | if(self.pad_off == 0):
102 | return inp[:, :, ::self.stride]
103 | else:
104 | return self.pad(inp)[:, :, ::self.stride]
105 | else:
106 | return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
107 |
108 | def get_pad_layer_1d(pad_type):
109 | if(pad_type in ['refl', 'reflect']):
110 | PadLayer = nn.ReflectionPad1d
111 | elif(pad_type in ['repl', 'replicate']):
112 | PadLayer = nn.ReplicationPad1d
113 | elif(pad_type == 'zero'):
114 | PadLayer = nn.ZeroPad1d
115 | else:
116 | print('Pad type [%s] not recognized' % pad_type)
117 | return PadLayer
118 |
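
A minimal usage sketch of the `Downsample` layer above (not part of the original file), assuming the repo's `code` directory is on `PYTHONPATH` and the dependencies in `requirement.txt` are installed:

```python
import torch
from models_lpf.downsample import Downsample

# Anti-aliased downsampling: low-pass filter with a 3x3 binomial kernel,
# then subsample by 2, as a drop-in replacement for strided pooling.
blurpool = Downsample(pad_type='reflect', filt_size=3, stride=2, channels=64)

x = torch.randn(1, 64, 56, 56)   # (N, C, H, W) feature map
y = blurpool(x)
print(y.shape)                   # torch.Size([1, 64, 28, 28])
```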
--------------------------------------------------------------------------------
/code/models_lpf/mobilenet.py:
--------------------------------------------------------------------------------
1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models.
2 | # Copyright (c) 2017 Torch Contributors.
3 | # The Pytorch examples are available under the BSD 3-Clause License.
4 |
5 | # ==========================================================================================
6 |
7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved.
8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit
10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
11 |
12 | # ==========================================================================================
13 |
14 | # BSD-3 License
15 |
16 | # Redistribution and use in source and binary forms, with or without
17 | # modification, are permitted provided that the following conditions are met:
18 |
19 | # * Redistributions of source code must retain the above copyright notice, this
20 | # list of conditions and the following disclaimer.
21 |
22 | # * Redistributions in binary form must reproduce the above copyright notice,
23 | # this list of conditions and the following disclaimer in the documentation
24 | # and/or other materials provided with the distribution.
25 |
26 | # * Neither the name of the copyright holder nor the names of its
27 | # contributors may be used to endorse or promote products derived from
28 | # this software without specific prior written permission.
29 |
30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
39 |
40 | from torch import nn
41 | from models_lpf import *
42 |
43 | __all__ = ['MobileNetV2', 'mobilenet_v2']
44 |
45 |
46 | # model_urls = {
47 | # 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
48 | # }
49 |
50 |
51 | class ConvBNReLU(nn.Sequential):
52 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
53 | padding = (kernel_size - 1) // 2
54 | super(ConvBNReLU, self).__init__(
55 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
56 | nn.BatchNorm2d(out_planes),
57 | nn.ReLU6(inplace=True)
58 | )
59 |
60 |
61 | class InvertedResidual(nn.Module):
62 | def __init__(self, inp, oup, stride, expand_ratio, filter_size=1):
63 | super(InvertedResidual, self).__init__()
64 | self.stride = stride
65 | assert stride in [1, 2]
66 |
67 | hidden_dim = int(round(inp * expand_ratio))
68 | self.use_res_connect = self.stride == 1 and inp == oup
69 |
70 | layers = []
71 | if expand_ratio != 1:
72 | # pw
73 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
74 | if(stride==1):
75 | layers.extend([
76 | # dw
77 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim),
78 | # pw-linear
79 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
80 | nn.BatchNorm2d(oup),
81 | ])
82 | else:
83 | layers.extend([
84 | # dw
85 | ConvBNReLU(hidden_dim, hidden_dim, stride=1, groups=hidden_dim),
86 | Downsample(filt_size=filter_size, stride=stride, channels=hidden_dim),
87 | # pw-linear
88 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
89 | nn.BatchNorm2d(oup),
90 | ])
91 | self.conv = nn.Sequential(*layers)
92 |
93 | def forward(self, x):
94 | if self.use_res_connect:
95 | return x + self.conv(x)
96 | else:
97 | return self.conv(x)
98 |
99 |
100 | class MobileNetV2(nn.Module):
101 | def __init__(self, num_classes=1000, width_mult=1.0, filter_size=1):
102 | super(MobileNetV2, self).__init__()
103 | block = InvertedResidual
104 | input_channel = 32
105 | last_channel = 1280
106 | inverted_residual_setting = [
107 | # t, c, n, s
108 | [1, 16, 1, 1],
109 | [6, 24, 2, 2],
110 | [6, 32, 3, 2],
111 | [6, 64, 4, 2],
112 | [6, 96, 3, 1],
113 | [6, 160, 3, 2],
114 | [6, 320, 1, 1],
115 | ]
116 |
117 | # building first layer
118 | input_channel = int(input_channel * width_mult)
119 | self.last_channel = int(last_channel * max(1.0, width_mult))
120 | features = [ConvBNReLU(3, input_channel, stride=2)]
121 | # building inverted residual blocks
122 | for t, c, n, s in inverted_residual_setting:
123 | output_channel = int(c * width_mult)
124 | for i in range(n):
125 | stride = s if i == 0 else 1
126 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, filter_size=filter_size))
127 | input_channel = output_channel
128 | # building last several layers
129 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
130 | # make it nn.Sequential
131 | self.features = nn.Sequential(*features)
132 |
133 | # building classifier
134 | self.classifier = nn.Sequential(
135 | # nn.Dropout(0.2),
136 | nn.Linear(self.last_channel, num_classes),
137 | )
138 |
139 | # weight initialization
140 | for m in self.modules():
141 | if isinstance(m, nn.Conv2d):
142 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
143 | if m.bias is not None:
144 | nn.init.zeros_(m.bias)
145 | elif isinstance(m, nn.BatchNorm2d):
146 | nn.init.ones_(m.weight)
147 | nn.init.zeros_(m.bias)
148 | elif isinstance(m, nn.Linear):
149 | nn.init.normal_(m.weight, 0, 0.01)
150 | nn.init.zeros_(m.bias)
151 |
152 | def forward(self, x):
153 | x = self.features(x)
154 | x = x.mean([2, 3])
155 | x = self.classifier(x)
156 | return x
157 |
158 |
159 | def mobilenet_v2(pretrained=False, progress=True, filter_size=1, **kwargs):
160 | """
161 | Constructs a MobileNetV2 architecture from
162 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
163 | Args:
164 | pretrained (bool): If True, returns a model pre-trained on ImageNet
165 | progress (bool): If True, displays a progress bar of the download to stderr
166 | """
167 | model = MobileNetV2(filter_size=filter_size, **kwargs)
168 | # if pretrained:
169 | # state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
170 | # progress=progress)
171 | # model.load_state_dict(state_dict)
172 | return model
173 |
--------------------------------------------------------------------------------
/code/models_lpf/resnet.py:
--------------------------------------------------------------------------------
1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models.
2 | # Copyright (c) 2017 Torch Contributors.
3 | # The Pytorch examples are available under the BSD 3-Clause License.
4 | #
5 | # ==========================================================================================
6 | #
7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved.
8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit
10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
11 | #
12 | # ==========================================================================================
13 | #
14 | # BSD-3 License
15 | #
16 | # Redistribution and use in source and binary forms, with or without
17 | # modification, are permitted provided that the following conditions are met:
18 | #
19 | # * Redistributions of source code must retain the above copyright notice, this
20 | # list of conditions and the following disclaimer.
21 | #
22 | # * Redistributions in binary form must reproduce the above copyright notice,
23 | # this list of conditions and the following disclaimer in the documentation
24 | # and/or other materials provided with the distribution.
25 | #
26 | # * Neither the name of the copyright holder nor the names of its
27 | # contributors may be used to endorse or promote products derived from
28 | # this software without specific prior written permission.
29 | #
30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
39 |
40 | import torch.nn as nn
41 | import torch.utils.model_zoo as model_zoo
42 | from models_lpf import *
43 |
44 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
45 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']
46 |
47 |
48 | # model_urls = {
49 | # 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
50 | # 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
51 | # 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
52 | # 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
53 | # 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
54 | # }
55 |
56 |
57 | def conv3x3(in_planes, out_planes, stride=1, groups=1):
58 | """3x3 convolution with padding"""
59 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
60 | padding=1, groups=groups, bias=False)
61 |
62 | def conv1x1(in_planes, out_planes, stride=1):
63 | """1x1 convolution"""
64 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
65 |
66 | class BasicBlock(nn.Module):
67 | expansion = 1
68 |
69 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1):
70 | super(BasicBlock, self).__init__()
71 | if norm_layer is None:
72 | norm_layer = nn.BatchNorm2d
73 | if groups != 1:
74 | raise ValueError('BasicBlock only supports groups=1')
75 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
76 | self.conv1 = conv3x3(inplanes, planes)
77 | self.bn1 = norm_layer(planes)
78 | self.relu = nn.ReLU(inplace=True)
79 | if(stride==1):
80 | self.conv2 = conv3x3(planes,planes)
81 | else:
82 | self.conv2 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes),
83 | conv3x3(planes, planes),)
84 | self.bn2 = norm_layer(planes)
85 | self.downsample = downsample
86 | self.stride = stride
87 |
88 | def forward(self, x):
89 | identity = x
90 |
91 | out = self.conv1(x)
92 | out = self.bn1(out)
93 | out = self.relu(out)
94 |
95 | out = self.conv2(out)
96 | out = self.bn2(out)
97 |
98 | if self.downsample is not None:
99 | identity = self.downsample(x)
100 |
101 | out += identity
102 | out = self.relu(out)
103 |
104 | return out
105 |
106 |
107 | class Bottleneck(nn.Module):
108 | expansion = 4
109 |
110 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, norm_layer=None, filter_size=1):
111 | super(Bottleneck, self).__init__()
112 | if norm_layer is None:
113 | norm_layer = nn.BatchNorm2d
114 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
115 | self.conv1 = conv1x1(inplanes, planes)
116 | self.bn1 = norm_layer(planes)
117 |         self.conv2 = conv3x3(planes, planes, groups=groups)  # stride moved to the Downsample in self.conv3; keyword avoids passing groups as stride
118 | self.bn2 = norm_layer(planes)
119 | if(stride==1):
120 | self.conv3 = conv1x1(planes, planes * self.expansion)
121 | else:
122 | self.conv3 = nn.Sequential(Downsample(filt_size=filter_size, stride=stride, channels=planes),
123 | conv1x1(planes, planes * self.expansion))
124 | self.bn3 = norm_layer(planes * self.expansion)
125 | self.relu = nn.ReLU(inplace=True)
126 | self.downsample = downsample
127 | self.stride = stride
128 |
129 | def forward(self, x):
130 | identity = x
131 |
132 | out = self.conv1(x)
133 | out = self.bn1(out)
134 | out = self.relu(out)
135 |
136 | out = self.conv2(out)
137 | out = self.bn2(out)
138 | out = self.relu(out)
139 |
140 | out = self.conv3(out)
141 | out = self.bn3(out)
142 |
143 | if self.downsample is not None:
144 | identity = self.downsample(x)
145 |
146 | out += identity
147 | out = self.relu(out)
148 |
149 | return out
150 |
151 |
152 | class ResNet(nn.Module):
153 |
154 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
155 | groups=1, width_per_group=64, norm_layer=None, filter_size=1, pool_only=True):
156 | super(ResNet, self).__init__()
157 | if norm_layer is None:
158 | norm_layer = nn.BatchNorm2d
159 | planes = [int(width_per_group * groups * 2 ** i) for i in range(4)]
160 | self.inplanes = planes[0]
161 |
162 | if(pool_only):
163 | self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=2, padding=3, bias=False)
164 | else:
165 | self.conv1 = nn.Conv2d(3, planes[0], kernel_size=7, stride=1, padding=3, bias=False)
166 | self.bn1 = norm_layer(planes[0])
167 | self.relu = nn.ReLU(inplace=True)
168 |
169 | if(pool_only):
170 | self.maxpool = nn.Sequential(*[nn.MaxPool2d(kernel_size=2, stride=1),
171 | Downsample(filt_size=filter_size, stride=2, channels=planes[0])])
172 | else:
173 | self.maxpool = nn.Sequential(*[Downsample(filt_size=filter_size, stride=2, channels=planes[0]),
174 | nn.MaxPool2d(kernel_size=2, stride=1),
175 | Downsample(filt_size=filter_size, stride=2, channels=planes[0])])
176 |
177 | self.layer1 = self._make_layer(block, planes[0], layers[0], groups=groups, norm_layer=norm_layer)
178 | self.layer2 = self._make_layer(block, planes[1], layers[1], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)
179 | self.layer3 = self._make_layer(block, planes[2], layers[2], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)
180 | self.layer4 = self._make_layer(block, planes[3], layers[3], stride=2, groups=groups, norm_layer=norm_layer, filter_size=filter_size)
181 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
182 | self.fc = nn.Linear(planes[3] * block.expansion, num_classes)
183 |
184 | for m in self.modules():
185 | if isinstance(m, nn.Conv2d):
186 | if(m.in_channels!=m.out_channels or m.out_channels!=m.groups or m.bias is not None):
187 | # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics
188 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
189 | else:
190 | print('Not initializing')
191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
192 | nn.init.constant_(m.weight, 1)
193 | nn.init.constant_(m.bias, 0)
194 |
195 | # Zero-initialize the last BN in each residual branch,
196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
198 | if zero_init_residual:
199 | for m in self.modules():
200 | if isinstance(m, Bottleneck):
201 | nn.init.constant_(m.bn3.weight, 0)
202 | elif isinstance(m, BasicBlock):
203 | nn.init.constant_(m.bn2.weight, 0)
204 |
205 | def _make_layer(self, block, planes, blocks, stride=1, groups=1, norm_layer=None, filter_size=1):
206 | if norm_layer is None:
207 | norm_layer = nn.BatchNorm2d
208 | downsample = None
209 | if stride != 1 or self.inplanes != planes * block.expansion:
210 | # downsample = nn.Sequential(
211 | # conv1x1(self.inplanes, planes * block.expansion, stride, filter_size=filter_size),
212 | # norm_layer(planes * block.expansion),
213 | # )
214 |
215 | downsample = [Downsample(filt_size=filter_size, stride=stride, channels=self.inplanes),] if(stride !=1) else []
216 | downsample += [conv1x1(self.inplanes, planes * block.expansion, 1),
217 | norm_layer(planes * block.expansion)]
218 | # print(downsample)
219 | downsample = nn.Sequential(*downsample)
220 |
221 | layers = []
222 | layers.append(block(self.inplanes, planes, stride, downsample, groups, norm_layer, filter_size=filter_size))
223 | self.inplanes = planes * block.expansion
224 | for _ in range(1, blocks):
225 | layers.append(block(self.inplanes, planes, groups=groups, norm_layer=norm_layer, filter_size=filter_size))
226 |
227 | return nn.Sequential(*layers)
228 |
229 | def forward(self, x):
230 | x = self.conv1(x)
231 | x = self.bn1(x)
232 | x = self.relu(x)
233 | x = self.maxpool(x)
234 |
235 | x = self.layer1(x)
236 | x = self.layer2(x)
237 | x = self.layer3(x)
238 | x = self.layer4(x)
239 |
240 | x = self.avgpool(x)
241 | x = x.view(x.size(0), -1)
242 | x = self.fc(x)
243 |
244 | return x
245 |
246 |
247 | def resnet18(pretrained=False, filter_size=1, pool_only=True, **kwargs):
248 | """Constructs a ResNet-18 model.
249 | Args:
250 | pretrained (bool): If True, returns a model pre-trained on ImageNet
251 | """
252 | model = ResNet(BasicBlock, [2, 2, 2, 2], filter_size=filter_size, pool_only=pool_only, **kwargs)
253 | if pretrained:
254 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
255 | return model
256 |
257 |
258 | def resnet34(pretrained=False, filter_size=1, pool_only=True, **kwargs):
259 | """Constructs a ResNet-34 model.
260 | Args:
261 | pretrained (bool): If True, returns a model pre-trained on ImageNet
262 | """
263 | model = ResNet(BasicBlock, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs)
264 | if pretrained:
265 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
266 | return model
267 |
268 |
269 | def resnet50(pretrained=False, filter_size=1, pool_only=True, **kwargs):
270 | """Constructs a ResNet-50 model.
271 | Args:
272 | pretrained (bool): If True, returns a model pre-trained on ImageNet
273 | """
274 | model = ResNet(Bottleneck, [3, 4, 6, 3], filter_size=filter_size, pool_only=pool_only, **kwargs)
275 | if pretrained:
276 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
277 | return model
278 |
279 |
280 | def resnet101(pretrained=False, filter_size=1, pool_only=True, **kwargs):
281 | """Constructs a ResNet-101 model.
282 | Args:
283 | pretrained (bool): If True, returns a model pre-trained on ImageNet
284 | """
285 | model = ResNet(Bottleneck, [3, 4, 23, 3], filter_size=filter_size, pool_only=pool_only, **kwargs)
286 | if pretrained:
287 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
288 | return model
289 |
290 |
291 | def resnet152(pretrained=False, filter_size=1, pool_only=True, **kwargs):
292 | """Constructs a ResNet-152 model.
293 | Args:
294 | pretrained (bool): If True, returns a model pre-trained on ImageNet
295 | """
296 | model = ResNet(Bottleneck, [3, 8, 36, 3], filter_size=filter_size, pool_only=pool_only, **kwargs)
297 | if pretrained:
298 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
299 | return model
300 |
301 |
302 | def resnext50_32x4d(pretrained=False, filter_size=1, pool_only=True, **kwargs):
303 | model = ResNet(Bottleneck, [3, 4, 6, 3], groups=4, width_per_group=32, filter_size=filter_size, pool_only=pool_only, **kwargs)
304 | # if pretrained:
305 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
306 | return model
307 |
308 |
309 | def resnext101_32x8d(pretrained=False, filter_size=1, pool_only=True, **kwargs):
310 | model = ResNet(Bottleneck, [3, 4, 23, 3], groups=8, width_per_group=32, filter_size=filter_size, pool_only=pool_only, **kwargs)
311 | # if pretrained:
312 | # model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
313 | return model
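
A minimal sketch (not in the original file) of building one of the anti-aliased ResNets defined above with random weights; it assumes the repo's `code` directory is on `PYTHONPATH`:

```python
import torch
from models_lpf.resnet import resnet18

# filter_size selects the blur kernel used by the Downsample layers
# (1 = no blurring; 3 or 5 = progressively stronger low-pass filtering).
# Note: pretrained=True relies on the model_urls dict, which is commented out above.
model = resnet18(pretrained=False, filter_size=3, pool_only=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```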
--------------------------------------------------------------------------------
/code/models_lpf/vgg.py:
--------------------------------------------------------------------------------
1 | # This code is built from the PyTorch examples repository: https://github.com/pytorch/vision/tree/master/torchvision/models.
2 | # Copyright (c) 2017 Torch Contributors.
3 | # The Pytorch examples are available under the BSD 3-Clause License.
4 | #
5 | # ==========================================================================================
6 | #
7 | # Adobe’s modifications are Copyright 2019 Adobe. All rights reserved.
8 | # Adobe’s modifications are licensed under the Creative Commons Attribution-NonCommercial-ShareAlike
9 | # 4.0 International Public License (CC-NC-SA-4.0). To view a copy of the license, visit
10 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode.
11 | #
12 | # ==========================================================================================
13 | #
14 | # BSD-3 License
15 | #
16 | # Redistribution and use in source and binary forms, with or without
17 | # modification, are permitted provided that the following conditions are met:
18 | #
19 | # * Redistributions of source code must retain the above copyright notice, this
20 | # list of conditions and the following disclaimer.
21 | #
22 | # * Redistributions in binary form must reproduce the above copyright notice,
23 | # this list of conditions and the following disclaimer in the documentation
24 | # and/or other materials provided with the distribution.
25 | #
26 | # * Neither the name of the copyright holder nor the names of its
27 | # contributors may be used to endorse or promote products derived from
28 | # this software without specific prior written permission.
29 | #
30 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
31 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
32 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
33 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
34 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
35 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
36 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
37 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
38 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
39 |
40 | import torch.nn as nn
41 | import torch.utils.model_zoo as model_zoo
42 | from models_lpf import *
43 |
44 | __all__ = [
45 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
46 | 'vgg19_bn', 'vgg19',
47 | ]
48 |
49 |
50 | # model_urls = {
51 | # 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
52 | # 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
53 | # 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
54 | # 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
55 | # 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
56 | # 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
57 | # 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
58 | # 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
59 | # }
60 |
61 |
62 | class VGG(nn.Module):
63 |
64 | def __init__(self, features, num_classes=1000, init_weights=True):
65 | super(VGG, self).__init__()
66 | self.features = features
67 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
68 | self.classifier = nn.Sequential(
69 | nn.Linear(512 * 7 * 7, 4096),
70 | nn.ReLU(True),
71 | nn.Dropout(),
72 | nn.Linear(4096, 4096),
73 | nn.ReLU(True),
74 | nn.Dropout(),
75 | nn.Linear(4096, num_classes),
76 | )
77 | if init_weights:
78 | self._initialize_weights()
79 |
80 | def forward(self, x):
81 | x = self.features(x)
82 | # print(x.shape)
83 | x = self.avgpool(x)
84 | x = x.view(x.size(0), -1)
85 | x = self.classifier(x)
86 | return x
87 |
88 | def _initialize_weights(self):
89 | for m in self.modules():
90 | if isinstance(m, nn.Conv2d):
91 | if(m.in_channels!=m.out_channels or m.out_channels!=m.groups or m.bias is not None):
92 | # don't want to reinitialize downsample layers, code assuming normal conv layers will not have these characteristics
93 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
94 | if m.bias is not None:
95 | nn.init.constant_(m.bias, 0)
96 | else:
97 | print('Not initializing')
98 | elif isinstance(m, nn.BatchNorm2d):
99 | nn.init.constant_(m.weight, 1)
100 | nn.init.constant_(m.bias, 0)
101 | elif isinstance(m, nn.Linear):
102 | nn.init.normal_(m.weight, 0, 0.01)
103 | nn.init.constant_(m.bias, 0)
104 |
105 |
106 | def make_layers(cfg, batch_norm=False, filter_size=1):
107 | layers = []
108 | in_channels = 3
109 | for v in cfg:
110 | if v == 'M':
111 | # layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
112 | layers += [nn.MaxPool2d(kernel_size=2, stride=1), Downsample(filt_size=filter_size, stride=2, channels=in_channels)]
113 | else:
114 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
115 | if batch_norm:
116 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
117 | else:
118 | layers += [conv2d, nn.ReLU(inplace=True)]
119 | in_channels = v
120 | return nn.Sequential(*layers)
121 |
122 |
123 | cfg = {
124 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
125 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
126 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
127 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
128 | }
129 |
130 | def vgg11(pretrained=False, filter_size=1, **kwargs):
131 | """VGG 11-layer model (configuration "A")
132 |
133 | Args:
134 | pretrained (bool): If True, returns a model pre-trained on ImageNet
135 | """
136 | if pretrained:
137 | kwargs['init_weights'] = False
138 | model = VGG(make_layers(cfg['A'], filter_size=filter_size), **kwargs)
139 | if pretrained:
140 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
141 | return model
142 |
143 |
144 | def vgg11_bn(pretrained=False, filter_size=1, **kwargs):
145 | """VGG 11-layer model (configuration "A") with batch normalization
146 |
147 | Args:
148 | pretrained (bool): If True, returns a model pre-trained on ImageNet
149 | """
150 | if pretrained:
151 | kwargs['init_weights'] = False
152 | model = VGG(make_layers(cfg['A'], filter_size=filter_size, batch_norm=True), **kwargs)
153 | if pretrained:
154 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
155 | return model
156 |
157 |
158 | def vgg13(pretrained=False, filter_size=1, **kwargs):
159 | """VGG 13-layer model (configuration "B")
160 |
161 | Args:
162 | pretrained (bool): If True, returns a model pre-trained on ImageNet
163 | """
164 | if pretrained:
165 | kwargs['init_weights'] = False
166 | model = VGG(make_layers(cfg['B'], filter_size=filter_size), **kwargs)
167 | if pretrained:
168 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
169 | return model
170 |
171 |
172 | def vgg13_bn(pretrained=False, filter_size=1, **kwargs):
173 | """VGG 13-layer model (configuration "B") with batch normalization
174 |
175 | Args:
176 | pretrained (bool): If True, returns a model pre-trained on ImageNet
177 | """
178 | if pretrained:
179 | kwargs['init_weights'] = False
180 | model = VGG(make_layers(cfg['B'], filter_size=filter_size, batch_norm=True), **kwargs)
181 | if pretrained:
182 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
183 | return model
184 |
185 |
186 | def vgg16(pretrained=False, filter_size=1, **kwargs):
187 | """VGG 16-layer model (configuration "D")
188 |
189 | Args:
190 | pretrained (bool): If True, returns a model pre-trained on ImageNet
191 | """
192 | if pretrained:
193 | kwargs['init_weights'] = False
194 | model = VGG(make_layers(cfg['D'], filter_size=filter_size), **kwargs)
195 | if pretrained:
196 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
197 | return model
198 |
199 |
200 | def vgg16_bn(pretrained=False, filter_size=1, **kwargs):
201 | """VGG 16-layer model (configuration "D") with batch normalization
202 |
203 | Args:
204 | pretrained (bool): If True, returns a model pre-trained on ImageNet
205 | """
206 | if pretrained:
207 | kwargs['init_weights'] = False
208 | model = VGG(make_layers(cfg['D'], filter_size=filter_size, batch_norm=True), **kwargs)
209 | if pretrained:
210 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
211 | return model
212 |
213 |
214 | def vgg19(pretrained=False, filter_size=1, **kwargs):
215 | """VGG 19-layer model (configuration "E")
216 |
217 | Args:
218 | pretrained (bool): If True, returns a model pre-trained on ImageNet
219 | """
220 | if pretrained:
221 | kwargs['init_weights'] = False
222 | model = VGG(make_layers(cfg['E'], filter_size=filter_size), **kwargs)
223 | if pretrained:
224 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
225 | return model
226 |
227 |
228 | def vgg19_bn(pretrained=False, filter_size=1, **kwargs):
229 | """VGG 19-layer model (configuration 'E') with batch normalization
230 |
231 | Args:
232 | pretrained (bool): If True, returns a model pre-trained on ImageNet
233 | """
234 | if pretrained:
235 | kwargs['init_weights'] = False
236 | model = VGG(make_layers(cfg['E'], filter_size=filter_size, batch_norm=True), **kwargs)
237 | if pretrained:
238 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
239 | return model
240 |
241 |
--------------------------------------------------------------------------------
/code/render.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import argparse
5 | from config import cfg
6 | from utils import *
7 | from modeling import build_model
8 | import cv2
9 | from tqdm import tqdm
10 |
11 | from imageio_ffmpeg import write_frames
12 |
13 | parser = argparse.ArgumentParser()
14 |
15 | parser.add_argument(
16 | "--config", default=None, type=str, help="config file path."
17 | )
18 | parser.add_argument(
19 | "--dataset", default=None, type=str, help="dataset path."
20 | )
21 | parser.add_argument(
22 | "--model", default=None, type=str, help="checkpoint model path."
23 | )
24 | parser.add_argument(
25 | "--output_path", default='./out', type=str, help="image / videos output path."
26 | )
27 | parser.add_argument(
28 | "--render_video", action='store_true', default=False, help="render around view videos."
29 | )
30 | parser.add_argument(
31 | "--camera_path", default=None, type=str, help="path to cameras sequence for rendering around-view videos."
32 | )
33 |
34 |
35 |
36 | args = parser.parse_args()
37 |
38 | os.makedirs(args.output_path, exist_ok=True)
39 |
40 | config_path = args.config
41 | dataset_path = args.dataset
42 | model_path = args.model
43 |
44 | # load config
45 | cfg.merge_from_file(config_path)
46 | tree_depth = cfg.MODEL.TREE_DEPTH
47 | img_size = cfg.INPUT.SIZE_TEST
48 |
49 | # load canonical volume
50 | coords = torch.load(os.path.join(dataset_path, 'volumes/coords_init.pth'), map_location='cpu').float().cuda()
51 | skeleton_init = torch.from_numpy(
52 | campose_to_extrinsic(np.loadtxt(os.path.join(dataset_path, 'bones/Bones_%04d.inf' % 0)))).float()
53 | bone_parents = torch.from_numpy(np.load(os.path.join(dataset_path, 'bones/bone_parents.npy'))).long()
54 | pose_init = get_local_eular(skeleton_init, bone_parents)
55 |
56 | skinning_weights = torch.load(os.path.join(dataset_path, 'volumes/volume_weights.pth'),
57 | map_location='cpu').float().cuda()
58 | joint_index = torch.load(os.path.join(dataset_path, 'volumes/volume_indices.pth'),
59 | map_location='cpu').int().cuda()
60 | volume_radius = torch.from_numpy(np.loadtxt(os.path.join(dataset_path, 'volumes/radius.txt'))).float()
61 |
62 | # load cameras
63 | camposes = np.loadtxt(os.path.join(dataset_path, 'CamPose.inf'))
64 | Ts = torch.Tensor(campose_to_extrinsic(camposes))
65 | center = torch.mean(coords, dim=0).cpu()
66 | Ks = torch.Tensor(read_intrinsics(os.path.join(dataset_path, 'Intrinsic.inf')))
67 | K = Ks[0]
68 | K[:2, :3] /= 2
69 |
70 | # load motions
71 | sequences = []
72 | seq_nums = []
73 | with open(os.path.join(dataset_path, 'sequences'), 'r') as f:
74 | lines = f.readlines()
75 | for line in lines:
76 | seq, num = line.strip().split(' ')
77 | sequences.append(seq)
78 | seq_nums.append(int(num))
79 |
80 | pose_features = []
81 | bone_matrices = []
82 | skeletons = []
83 | seq_trees = []
84 |
85 | pose_feature_pad = torch.cat([skeleton_init[:, :3, 3], pose_init], dim=1)
86 |
87 | for seq_id, seq in enumerate(sequences):
88 | seq_num = seq_nums[seq_id]
89 | skeleton, pose_feature_s, bone_matrice = [], [], []
90 | seq_trees.append([None] * seq_num)
91 | for i in range(seq_num):
92 | pose = torch.from_numpy(
93 | campose_to_extrinsic(np.loadtxt(
94 | os.path.join(dataset_path, 'bones/%s/Bones_%04d.inf' % (seq, i))))).float()
95 | matrices = torch.matmul(skeleton_init, torch.inverse(pose))
96 | skeleton.append(pose.unsqueeze(0))
97 | bone_matrice.append(matrices.unsqueeze(0))
98 |
99 | pose = get_local_eular(pose, bone_parents)
100 | delta_pose = pose - pose_init
101 | pose_feature = torch.cat([pose_feature_pad, delta_pose], dim=1)
102 | if torch.any(torch.isnan(pose_feature)):
103 | pose_feature = torch.where(torch.isnan(pose_feature), torch.zeros_like(pose_feature).float(),
104 | pose_feature)
105 |
106 | pose_feature_s.append(pose_feature.unsqueeze(0))
107 |
108 | skeleton = torch.cat(skeleton, dim=0).float()
109 | pose_feature_s = torch.cat(pose_feature_s, dim=0).float()
110 | bone_matrice = torch.cat(bone_matrice, dim=0).float()
111 |
112 | skeletons.append(skeleton.cuda())
113 | pose_features.append(pose_feature_s.cuda())
114 | bone_matrices.append(bone_matrice.cuda())
115 |
116 | # load models
117 | model = build_model(cfg, flut_size=coords.shape[0], bone_feature_dim=pose_features[0].shape[2]).cuda()
118 | checkpoint = torch.load(model_path, map_location='cpu')
119 | model.load_state_dict(checkpoint['model'])
120 | model.eval()
121 |
122 | # render
123 | if not args.render_video:
124 | seq_id = 0
125 | frame_id = 0
126 | cam_id = 0
127 | cur_T = Ts[cam_id]
128 | if seq_trees[seq_id][frame_id] is None:
129 | transformation_matrices = torch.inverse(bone_matrices[seq_id][frame_id])
130 | skeleton = skeletons[seq_id][frame_id][:, :3, 3]
131 | t, _, _ = warp_build_octree(coords, transformation_matrices=transformation_matrices,
132 | skeleton=skeleton, vb_weights=skinning_weights,
133 | vb_indices=joint_index, radius=volume_radius,
134 | max_depth=tree_depth)
135 | seq_trees[seq_id][frame_id] = t
136 |
137 | rgb, mask, feature = render_image(cfg, model, K, cur_T, (img_size[1], img_size[0]),
138 | tree=seq_trees[seq_id][frame_id],
139 | matrices=bone_matrices[seq_id][frame_id],
140 | joint_features=pose_features[seq_id][frame_id],
141 | bg=None, skinning_weights=skinning_weights,
142 | joint_index=joint_index)
143 |
144 |     cv2.imwrite(os.path.join(args.output_path, 'rgb.jpg'), cv2.cvtColor(rgb * 255, cv2.COLOR_BGR2RGB))
145 |     cv2.imwrite(os.path.join(args.output_path, 'feature.jpg'), feature * 255)
146 | else:
147 |     assert args.camera_path is not None, 'Please provide a camera trajectory.'
148 | camposes = np.loadtxt(os.path.join(args.camera_path, 'CamPose_spiral.inf'))
149 | Ts = torch.Tensor(campose_to_extrinsic(camposes))
150 | center = torch.mean(coords, dim=0).cpu()
151 | Ks = torch.Tensor(read_intrinsics(os.path.join(args.camera_path, 'Intrinsic_spiral.inf')))
152 |
153 | for seq_id in range(len(sequences)):
154 |
155 | writer_raw_rgb = write_frames(os.path.join(args.output_path, '%s_rgb.mp4' % sequences[seq_id]), img_size, fps=30, macro_block_size=8, quality=6) # size is (width, height)
156 | writer_raw_alpha = write_frames(os.path.join(args.output_path, '%s_alpha.mp4' % sequences[seq_id]), img_size, fps=30, macro_block_size=8, quality=6) # size is (width, height)
157 | writer_raw_feature = write_frames(os.path.join(args.output_path, '%s_feature.mp4' % sequences[seq_id]), img_size, fps=30, macro_block_size=8, quality=6) # size is (width, height)
158 | writer_raw_rgb.send(None)
159 | writer_raw_alpha.send(None)
160 | writer_raw_feature.send(None)
161 |
162 | for cam_id in tqdm(range(Ts.shape[0]), unit=" frame", desc=f"Rendering video"):
163 | frame_id = cam_id % seq_nums[seq_id]
164 |
165 | if seq_trees[seq_id][frame_id] is None:
166 | transformation_matrices = torch.inverse(bone_matrices[seq_id][frame_id])
167 | skeleton = skeletons[seq_id][frame_id][:, :3, 3]
168 | t, wtime, btime = warp_build_octree(coords, transformation_matrices=transformation_matrices,
169 | skeleton=skeleton, vb_weights=skinning_weights,
170 | vb_indices=joint_index, radius=volume_radius,
171 | max_depth=tree_depth)
172 | # cache the tree for faster rendering if gpu memory is enough
173 | # seq_trees[seq_id][frame_id] = t
174 | else:
175 | t = seq_trees[seq_id][frame_id]
176 |
177 | rgb, mask, feature = render_image(cfg, model, Ks[cam_id], Ts[cam_id], (img_size[1], img_size[0]),
178 |                                           tree=t,  # use the tree built above; seq_trees[...] may still be None when caching is disabled
179 | matrices=bone_matrices[seq_id][frame_id],
180 | joint_features=pose_features[seq_id][frame_id],
181 | bg=None, skinning_weights=skinning_weights,
182 | joint_index=joint_index)
183 |
184 | img = rgb * 255
185 | feature = feature * 255
186 | alpha = torch.from_numpy(mask).unsqueeze(-1).repeat(1, 1, 3).numpy() * 255
187 | img = img.copy(order='C')
188 | writer_raw_rgb.send(img.astype(np.uint8))
189 | writer_raw_alpha.send(alpha.astype(np.uint8))
190 | writer_raw_feature.send(feature.astype(np.uint8))
191 |
192 | cv2.imwrite(os.path.join(args.output_path, '%s_rgb_%04d.jpg' % (sequences[seq_id], cam_id)),
193 | cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
194 | cv2.imwrite(os.path.join(args.output_path, '%s_alpha_%04d.jpg' % (sequences[seq_id], cam_id)), alpha)
195 | cv2.imwrite(os.path.join(args.output_path, '%s_feature_%04d.jpg' % (sequences[seq_id], cam_id)), feature)
196 |
197 | writer_raw_rgb.close()
198 | writer_raw_alpha.close()
199 | writer_raw_feature.close()
200 |
--------------------------------------------------------------------------------
/code/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | from .ray_sampling import ray_sampling, patch_sampling
8 | from .spherical_harmonics import computeRGB
9 | from .llinear_transform import compute_skinning_weights, compute_transformation, transform_coords
10 | from .build_octree import build_octree, load_octree, generate_transformation_matrices
11 | from .bone_parsing import get_local_transformation, get_local_eular, rotation_matrix_to_eular, eular_to_rotation_matrix
12 | from .utils import *
13 | from .rendering import *
--------------------------------------------------------------------------------
/code/utils/bone_parsing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from pytorch3d.transforms import so3_log_map, so3_rotation_angle, so3_relative_angle, matrix_to_euler_angles, euler_angles_to_matrix
5 |
6 |
7 | def get_local_transformation(transformations, parents):
8 | parents_t = torch.cat([torch.eye(4).unsqueeze(0), transformations[parents[1:]]])
9 | local_t = torch.matmul(torch.inverse(parents_t), transformations)
10 |
11 | return local_t
12 |
13 |
14 | def get_local_eular(transformations, parents, axis='XYZ'):
15 | local_t = get_local_transformation(transformations, parents)
16 | local_t = torch.clamp(local_t[:, :3, :3], max=1.)
17 | local_eular = matrix_to_euler_angles(local_t, axis)
18 |
19 | return local_eular
20 |
21 |
22 | def rotation_matrix_to_eular(rotations, axis='XYZ'):
23 | return matrix_to_euler_angles(rotations[:, :3, :3], axis)
24 |
25 |
26 | def eular_to_rotation_matrix(eular, axis='XYZ'):
27 | return euler_angles_to_matrix(eular, axis)
28 |
29 |
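
A toy sketch (not part of the original file) of `get_local_eular`, which converts global bone transforms into per-bone local Euler angles; it assumes pytorch3d and the other `utils` dependencies (e.g. svox_t, pulled in by the package `__init__`) are installed, and uses a made-up 3-bone chain:

```python
import torch
from utils.bone_parsing import get_local_eular

# Hypothetical 3-bone chain: index 0 is the root (its parent entry is ignored),
# bone 1 is a child of bone 0, bone 2 is a child of bone 1.
parents = torch.tensor([0, 0, 1])

# Global 4x4 rigid transforms per bone; identities here, so every local
# rotation decomposes to zero Euler angles.
transforms = torch.eye(4).unsqueeze(0).repeat(3, 1, 1)

local_euler = get_local_eular(transforms, parents)  # (3, 3) XYZ angles
print(local_euler)
```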
--------------------------------------------------------------------------------
/code/utils/build_octree.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 |
5 | import svox_t
6 |
7 | def build_octree(coordinates, radius=None, center=None, skeleton=None, data_dim=91, max_depth=10, data_format='SH9'):
8 | coordinates = coordinates.contiguous().cuda().detach()
9 |
10 | maxs_ = torch.max(coordinates, dim=0).values
11 | mins_ = torch.min(coordinates, dim=0).values
12 | radius_ = (maxs_ - mins_) / 2
13 | center_ = mins_ + radius_
14 |
15 | if radius is None:
16 | radius = radius_
17 |
18 | if center is None:
19 | center = center_
20 | # print(center, radius)
21 | t = svox_t.N3Tree(N=2,
22 | data_dim=data_dim,
23 | depth_limit=max_depth,
24 | init_reserve=5000,
25 | init_refine=0,
26 | # geom_resize_fact=1.5,
27 | radius=list(radius),
28 | center=list(center),
29 | data_format=data_format,
30 | extra_data=skeleton, ).cuda()
31 |
32 | for i in range(max_depth):
33 | t[coordinates].refine()
34 |
35 | t.shrink_to_fit()
36 | t.construct_tree(coordinates)
37 |
38 | return t.cpu()
39 |
40 | def load_octree(tree_path):
41 | return svox_t.N3Tree.load(tree_path)
42 |
43 | def generate_transformation_matrices(matrices, skinning_weights, joint_index):
44 | return svox_t.blend_transformation_matrix(matrices, skinning_weights, joint_index)
--------------------------------------------------------------------------------
/code/utils/llinear_transform.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 |
4 | def compute_skinning_weights(pos, tmp_vertices, tmp_weights, n_influence=3, n_binding=10, chunks=10000, debug=False):
5 | if pos.shape[0] > chunks:
6 | coords = pos.split(chunks, dim=0)
7 | else:
8 | coords = [pos]
9 |
10 | weights = []
11 |
12 | for coord in coords:
13 | dis = coord[:, None, :] - tmp_vertices
14 | dis = dis.pow(2).sum(-1).sqrt()
15 | dis_min, indice = torch.topk(dis, n_influence, largest=False, dim=1)
16 | w = dis_min.max(-1)[0][..., None] - dis_min
17 | w = torch.softmax(w, dim=-1)
18 | weight = torch.sum(tmp_weights[indice] * w[..., None], dim=1)
19 | weights.append(weight)
20 | weights = torch.cat(weights, dim=0)
21 |
22 | weights, indices = torch.topk(weights, n_binding, largest=True, dim=1)
23 | weights = weights / weights.sum(-1).unsqueeze(-1)
24 |
25 | return weights, indices
26 |
27 |
28 |
29 | def compute_transformation(src_pose, dst_pose, weights, indices, chunks=10000):
30 | if weights is None or indices is None:
31 | print('Please provide skinning weights.')
32 | return None
33 |
34 | transformation_matrices = []
35 | if weights.shape[0] > chunks:
36 | weights, indices = weights.split(chunks, dim=0), indices.split(chunks, dim=0)
37 | else:
38 | weights, indices = [weights], [indices]
39 |
40 | for i in range(len(weights)):
41 | weight, indice = weights[i], indices[i]
42 |         src_matrices, dst_matrices = src_pose[indice], dst_pose[indice]
43 |         transformation_matrix = torch.matmul(dst_matrices, src_matrices)
44 | transformation_matrix = torch.sum(transformation_matrix * weight[..., None, None], dim=1)
45 |
46 | transformation_matrices.append(transformation_matrix)
47 |
48 | transformation_matrices = torch.cat(transformation_matrices, dim=0)
49 |
50 | return transformation_matrices.float()
51 |
52 |
53 | def transform_coords(coords, matrices, chunks=10000):
54 | if coords.shape[0] > chunks:
55 | coords, matrices = coords.split(chunks, dim=0), matrices.split(chunks, dim=0)
56 | else:
57 | coords, matrices = [coords], [matrices]
58 |
59 | transformed_coords = []
60 |
61 | for i in range(len(coords)):
62 | coord, ms = coords[i], matrices[i]
63 | coord = coord.unsqueeze(-1)
64 | coord = torch.cat([coord, torch.ones((coord.size(0), 1, coord.size(2)), device=coord.device)], dim=1)
65 | coord = torch.matmul(ms, coord)
66 |
67 | transformed_coords.append(coord)
68 |
69 | transformed_coords = torch.cat(transformed_coords, dim=0).squeeze()[:, :3]
70 |
71 | return transformed_coords
72 |
73 |
74 |
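
A small sketch (not part of the original file) of `compute_skinning_weights` on random stand-in data for the template vertices and their bone weights; importing from the `utils` package assumes its dependencies (including svox_t) are installed:

```python
import torch
from utils.llinear_transform import compute_skinning_weights

num_points, num_vertices, num_bones = 5000, 2000, 12
pos = torch.rand(num_points, 3)             # query points, e.g. voxel centers
tmp_vertices = torch.rand(num_vertices, 3)  # template mesh vertices
tmp_weights = torch.softmax(torch.rand(num_vertices, num_bones), dim=-1)

# Each query point blends the weights of its 3 nearest template vertices,
# then keeps its 10 strongest bone bindings, renormalized to sum to 1.
weights, indices = compute_skinning_weights(pos, tmp_vertices, tmp_weights,
                                            n_influence=3, n_binding=10)
print(weights.shape, indices.shape)  # (5000, 10) (5000, 10)
```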
--------------------------------------------------------------------------------
/code/utils/logger.py:
--------------------------------------------------------------------------------
1 | # encoding: utf-8
2 | """
3 | @author: sherlock
4 | @contact: sherlockliao01@gmail.com
5 | """
6 |
7 | import logging
8 | import os
9 | import sys
10 |
11 |
12 | def setup_logger(name, save_dir, distributed_rank):
13 | logger = logging.getLogger(name)
14 | logger.setLevel(logging.DEBUG)
15 | # don't log results for the non-master process
16 | if distributed_rank > 0:
17 | return logger
18 | ch = logging.StreamHandler(stream=sys.stdout)
19 | ch.setLevel(logging.DEBUG)
20 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
21 | ch.setFormatter(formatter)
22 | logger.addHandler(ch)
23 |
24 | if save_dir:
25 | fh = logging.FileHandler(os.path.join(save_dir, "log.txt"), mode='w')
26 | fh.setLevel(logging.DEBUG)
27 | fh.setFormatter(formatter)
28 | logger.addHandler(fh)
29 |
30 | return logger
31 |
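
A short usage sketch (not in the original file) for `setup_logger`; the log directory is a made-up path and must exist before the file handler is attached:

```python
import os
from utils.logger import setup_logger

os.makedirs('./logs', exist_ok=True)
# Rank 0 logs to stdout and ./logs/log.txt; non-zero ranks get a bare logger.
logger = setup_logger('artemis', './logs', distributed_rank=0)
logger.info('logger initialized')
```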
--------------------------------------------------------------------------------
/code/utils/ray_sampling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 |
4 | '''
5 | Sample rays from views (and images) with/without masks
6 |
7 | --------------------------
8 | INPUT Tensors
9 | Ks: intrinsics of cameras (M,3,3)
10 | Ts: extrinsic of cameras (M,4,4)
11 | image_size: the size of image [H,W]
12 | images: (M,C,H,W)
13 | mask_threshold: a float threshold to mask rays
14 | masks:(M,H,W)
15 | depth:(M,H,W)
16 | -------------------
17 | OUTPUT:
18 | list of rays: (N,6) dirs(3) + pos(3)
19 | RGB: (N,C)
20 | '''
21 |
22 |
23 | def ray_sampling(Ks, Ts, image_size, masks=None, mask_threshold=0.5, images=None, depth=None,
24 | far_depth=None, fine_depth=None):
25 | h = image_size[0]
26 | w = image_size[1]
27 | M = Ks.size(0)
28 |
29 | x = torch.linspace(0, h - 1, steps=h, device=Ks.device)
30 | y = torch.linspace(0, w - 1, steps=w, device=Ks.device)
31 |
32 | grid_x, grid_y = torch.meshgrid(x, y)
33 | coordinates = torch.stack([grid_y, grid_x]).unsqueeze(0).repeat(M, 1, 1, 1) # (M,2,H,W)
34 | coordinates = torch.cat([coordinates, torch.ones(coordinates.size(0), 1, coordinates.size(2),
35 | coordinates.size(3), device=Ks.device)], dim=1).permute(0, 2, 3,
36 | 1).unsqueeze(
37 | -1)
38 |
39 | inv_Ks = torch.inverse(Ks)
40 |
41 | dirs = torch.matmul(inv_Ks, coordinates) # (M,H,W,3,1)
42 | dirs = dirs / torch.norm(dirs, dim=3, keepdim=True)
43 |
44 | # zs to depth
45 | td = dirs.view(M, h, w, -1)
46 | if depth is not None:
47 | depth = depth.to(Ks.device)
48 | depth = depth / td[:, :, :, 2]
49 | if far_depth is not None:
50 | far_depth = far_depth.to(Ks.device)
51 | far_depth = far_depth / td[:, :, :, 2]
52 | else:
53 | far_depth = depth
54 | if fine_depth is not None:
55 | fine_depth = fine_depth.to(Ks.device)
56 | fine_depth = fine_depth / td[:, :, :, 2]
57 | else:
58 | fine_depth = depth
59 |
60 | dirs = torch.cat([dirs, torch.zeros(dirs.size(0), coordinates.size(1),
61 | coordinates.size(2), 1, 1, device=Ks.device)], dim=3) # (M,H,W,4,1)
62 |
63 | dirs = torch.matmul(Ts, dirs) # (M,H,W,4,1)
64 | dirs = dirs[:, :, :, 0:3, 0] # (M,H,W,3)
65 |
66 | pos = Ts[:, 0:3, 3] # (M,3)
67 | pos = pos.unsqueeze(1).unsqueeze(1).repeat(1, h, w, 1)
68 |
69 | rays = torch.cat([dirs, pos], dim=3) # (M,H,W,6)
70 | if depth is not None:
71 |         depth = torch.stack([depth, far_depth, fine_depth], dim=3)  # (M,H,W,3): near, far, fine depth
72 | ds = None
73 |
74 | if images is not None:
75 | images = images.to(Ks.device)
76 | rgbs = images.permute(0, 2, 3, 1) # (M,H,W,C)
77 | else:
78 | rgbs = None
79 |
80 | if masks is not None:
81 | rays = rays[masks > mask_threshold, :]
82 | if rgbs is not None:
83 | rgbs = rgbs[masks > mask_threshold, :]
84 | if depth is not None:
85 | ds = depth[masks > mask_threshold, :]
86 | else:
87 | rays = rays.reshape((-1, rays.size(3)))
88 | if rgbs is not None:
89 | rgbs = rgbs.reshape((-1, rgbs.size(3)))
90 | if depth is not None:
91 | ds = depth.reshape((-1, depth.size(3)))
92 |
93 | return rays, rgbs, ds
94 |
95 |
96 | def patch_sampling(Ks, Ts, image_size, patch_size=32, images=None, depth=None, far_depth=None, fine_depth=None,
97 | alpha=None, keep_bg=False):
98 | h = image_size[0]
99 | w = image_size[1]
100 | M = Ks.size(0)
101 |
102 | h_num, w_num = math.ceil(h / patch_size), math.ceil(w / patch_size)
103 | ray_patches, rgb_patches, depth_patches, alpha_patches = [], [], [], []
104 |
105 | x = torch.linspace(0, h - 1, steps=h, device=Ks.device)
106 | y = torch.linspace(0, w - 1, steps=w, device=Ks.device)
107 |
108 | grid_x, grid_y = torch.meshgrid(x, y)
109 | coordinates = torch.stack([grid_y, grid_x]).unsqueeze(0).repeat(M, 1, 1, 1) # (M,2,H,W)
110 | coordinates = torch.cat([coordinates, torch.ones(coordinates.size(0), 1, coordinates.size(2),
111 | coordinates.size(3), device=Ks.device)], dim=1).permute(0, 2, 3,
112 | 1).unsqueeze(
113 | -1)
114 |
115 | inv_Ks = torch.inverse(Ks)
116 |
117 | dirs = torch.matmul(inv_Ks, coordinates) # (M,H,W,3,1)
118 | dirs = dirs / torch.norm(dirs, dim=3, keepdim=True)
119 |
120 | # zs to depth
121 | td = dirs.view(M, h, w, -1)
122 | if depth is not None:
123 | depth = depth / td[:, :, :, 2]
124 | if far_depth is not None:
125 | far_depth = far_depth / td[:, :, :, 2]
126 | else:
127 | far_depth = depth
128 | if fine_depth is not None:
129 | fine_depth = fine_depth / td[:, :, :, 2]
130 | else:
131 | fine_depth = depth
132 |
133 | dirs = torch.cat([dirs, torch.zeros(dirs.size(0), coordinates.size(1),
134 | coordinates.size(2), 1, 1, device=Ks.device)], dim=3) # (M,H,W,4,1)
135 |
136 | dirs = torch.matmul(Ts, dirs) # (M,H,W,4,1)
137 | dirs = dirs[:, :, :, 0:3, 0] # (M,H,W,3)
138 |
139 | pos = Ts[:, 0:3, 3] # (M,3)
140 | pos = pos.unsqueeze(1).unsqueeze(1).repeat(1, h, w, 1)
141 |
142 | rays = torch.cat([dirs, pos], dim=3) # (M,H,W,6)
143 |
144 | if images is not None:
145 | rgbs = images.permute(0, 2, 3, 1) # (M,H,W,C)
146 | else:
147 | rgbs = None
148 |
149 |     depth = torch.stack([depth, far_depth, fine_depth], dim=3)  # (M,H,W,3): near, far, fine depth
150 |
151 | padded_h, padded_w = h_num * patch_size, w_num * patch_size
152 | ray_padded, rgb_padded, depth_padded, alpha_padded = torch.zeros([rays.size(0), padded_h, padded_w, rays.size(3)],
153 | device=rays.device), \
154 | torch.zeros([rgbs.size(0), padded_h, padded_w, rgbs.size(3)],
155 | device=rgbs.device), \
156 | torch.zeros([depth.size(0), padded_h, padded_w, depth.size(3)],
157 | device=depth.device), \
158 | torch.zeros([alpha.size(0), padded_h, padded_w],
159 | device=alpha.device)
160 | ray_padded[:, :h, :w, :] = rays
161 | rgb_padded[:, :h, :w, :] = rgbs
162 | depth_padded[:, :h, :w, :] = depth
163 | alpha_padded[:, :h, :w] = alpha
164 |
165 | rays = ray_padded.view(int(padded_h / patch_size), patch_size, int(padded_w / patch_size), patch_size, -1).permute(0, 2, 1, 3, 4).reshape(-1, patch_size, patch_size, ray_padded.shape[-1])
166 | rgbs = rgb_padded.view(int(padded_h / patch_size), patch_size, int(padded_w / patch_size), patch_size, -1).permute(0, 2, 1, 3, 4).reshape(-1, patch_size, patch_size, rgb_padded.shape[-1])
167 | ds = depth_padded.view(int(padded_h / patch_size), patch_size, int(padded_w / patch_size), patch_size, -1).permute(0, 2, 1, 3, 4).reshape(-1, patch_size, patch_size, depth_padded.shape[-1])
168 | als = alpha_padded.view(int(padded_h / patch_size), patch_size, int(padded_w / patch_size), patch_size).permute(0, 2, 1, 3).reshape(-1, patch_size, patch_size)
169 |
170 | mask = ds.reshape(ds.shape[0], -1).sum(1) > 0
171 |
172 | return rays[mask], rgbs[mask], ds[mask], als[mask]
173 |
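
A toy sketch (not part of the original file) showing the shapes `ray_sampling` returns for a single camera; the intrinsics and extrinsics below are made-up placeholders, and importing from the `utils` package assumes its heavier dependencies (pytorch3d, svox_t) are installed:

```python
import torch
from utils.ray_sampling import ray_sampling

H, W = 4, 6
K = torch.tensor([[[500., 0., W / 2],
                   [0., 500., H / 2],
                   [0., 0., 1.]]])   # (1, 3, 3) pinhole intrinsics
T = torch.eye(4).unsqueeze(0)        # (1, 4, 4) camera-to-world extrinsics

rays, rgbs, ds = ray_sampling(K, T, image_size=[H, W])
print(rays.shape)  # torch.Size([24, 6]): per-pixel dir(3) + origin(3)
print(rgbs, ds)    # None, None since no images or depth were passed
```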
--------------------------------------------------------------------------------
/code/utils/rendering.py:
--------------------------------------------------------------------------------
1 | from utils import ray_sampling
2 | import collections
3 | import time
4 | import torch
5 | import svox_t
6 | from utils import generate_transformation_matrices
7 | import numpy as np
8 | import math
9 |
10 | Sh_Rays = collections.namedtuple('Sh_Rays', 'origins dirs viewdirs')
11 |
12 |
13 | def render_image(cfg, model, K, T, img_size=(450, 800), tree=None, matrices=None, joint_features=None,
14 | bg=None, skinning_weights=None, joint_index=None):
15 | torch.cuda.synchronize()
16 | s = time.time()
17 | h, w = img_size[0], img_size[1]
18 | rays, _, _ = ray_sampling(K.unsqueeze(0).cuda(), T.unsqueeze(0).cuda(), img_size)
19 |
20 | with torch.no_grad():
21 | joint_features = None if not cfg.MODEL.USE_MOTION else joint_features
22 | matrices = generate_transformation_matrices(matrices=matrices, skinning_weights=skinning_weights,
23 | joint_index=joint_index)
24 |
25 | with torch.cuda.amp.autocast(enabled=False):
26 | features = model.tree_renderer(tree, rays, matrices).reshape(1, h, w, -1).permute(0, 3, 1, 2)
27 |
28 | if cfg.MODEL.USE_MOTION:
29 | motion_feature = model.tree_renderer.motion_feature_render(tree, joint_features, skinning_weights,
30 | joint_index,
31 | rays)
32 |
33 |
34 | motion_feature = motion_feature.reshape(1, h, w, -1).permute(0, 3, 1, 2)
35 | else:
36 | motion_feature = features[:, :9, ...]
37 |
38 | with torch.cuda.amp.autocast(enabled=True):
39 | features_in = features[:, :-1, ...]
40 | if cfg.MODEL.USE_MOTION:
41 | features_in = torch.cat([features[:, :-1, ...], motion_feature], dim=1)
42 | rgba_out = model.render_net(features_in, features[:, -1:, ...])
43 |
44 | rgba_volume = torch.cat([features[:, :3, ...], features[:, -1:, ...]], dim=1)
45 |
46 | rgb = rgba_out[0, :-1, ...]
47 | alpha = rgba_out[0, -1:, ...]
48 | img_volume = rgba_volume[0, :3, ...].permute(1, 2, 0)
49 |
50 | if model.use_render_net:
51 | rgb = torch.nn.Hardtanh()(rgb)
52 | rgb = (rgb + 1) / 2
53 |
54 | alpha = torch.nn.Hardtanh()(alpha)
55 | alpha = (alpha + 1) / 2
56 | alpha = torch.clamp(alpha, min=0, max=1.)
57 |
58 | if bg is not None:
59 | if bg.max() > 1:
60 | bg = bg / 255
61 | comp_img = rgb * alpha + (1 - alpha) * bg
62 | else:
63 | comp_img = rgb * alpha + (1 - alpha)
64 |
65 | img_unet = comp_img.permute(1, 2, 0).float().cpu().numpy()
66 |
67 | return img_unet, alpha.squeeze().float().detach().cpu().numpy(), img_volume.float().detach().cpu().numpy()
68 |
69 |
70 | def build_octree(coordinates, radius=None, center=None, skeleton=None, data_dim=91,
71 | max_depth=10, data_format='SH9'):
72 | coordinates = coordinates.contiguous().cuda().detach()
73 |
74 | maxs_ = torch.max(coordinates, dim=0).values
75 | mins_ = torch.min(coordinates, dim=0).values
76 | radius_ = (maxs_ - mins_) / 2
77 | center_ = mins_ + radius_
78 |
79 | if radius is None:
80 | radius = radius_
81 |
82 | if center is None:
83 | center = center_
84 |
85 | t = svox_t.N3Tree(N=2,
86 | data_dim=data_dim,
87 | depth_limit=max_depth,
88 | init_reserve=30000,
89 | init_refine=0,
90 | # geom_resize_fact=1.5,
91 | radius=list(radius),
92 | center=list(center),
93 | data_format=data_format,
94 | extra_data=skeleton.contiguous(),
95 | map_location=coordinates.device)
96 |
97 | for i in range(max_depth):
98 | t[coordinates].refine()
99 |
100 | t.shrink_to_fit()
101 | t.construct_tree(coordinates)
102 |
103 | return t
104 |
105 |
106 | def warp_build_octree(coords, transformation_matrices, skeleton=None, vb_weights=None, vb_indices=None, radius=None,
107 | max_depth=8):
108 | torch.cuda.synchronize()
109 | s = time.time()
110 | vs, ms = svox_t.warp_vertices(transformation_matrices, coords, vb_weights, vb_indices)
111 | torch.cuda.synchronize()
112 | wtime = time.time() - s
113 |
114 | torch.cuda.synchronize()
115 | s = time.time()
116 | t = build_octree(vs, radius=radius, max_depth=max_depth, skeleton=skeleton)
117 | torch.cuda.synchronize()
118 | btime = time.time() - s
119 |
120 | return t, wtime, btime
121 |
122 | def rodrigues_rotation_matrix(axis, theta):
123 | axis = np.asarray(axis)
124 | theta = np.asarray(theta)
125 | axis = axis/math.sqrt(np.dot(axis, axis))
126 | a = math.cos(theta/2.0)
127 | b, c, d = -axis*math.sin(theta/2.0)
128 | aa, bb, cc, dd = a*a, b*b, c*c, d*d
129 | bc, ad, ac, ab, bd, cd = b*c, a*d, a*c, a*b, b*d, c*d
130 | return np.array([[aa+bb-cc-dd, 2*(bc+ad), 2*(bd-ac)],
131 | [2*(bc-ad), aa+cc-bb-dd, 2*(cd+ab)],
132 | [2*(bd+ac), 2*(cd-ab), aa+dd-bb-cc]])
133 |
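134 | # Illustrative, minimal usage sketch of `rodrigues_rotation_matrix`, which builds
135 | # a 3x3 rotation from an axis/angle pair via the quaternion form of Rodrigues'
136 | # formula (example values are hypothetical):
137 | if __name__ == '__main__':
138 |     R = rodrigues_rotation_matrix([0.0, 0.0, 1.0], np.pi / 2)
139 |     v = np.array([1.0, 0.0, 0.0])
140 |     # Rotating the x-axis by 90 degrees about z should give (approximately) the y-axis.
141 |     print(R @ v)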
--------------------------------------------------------------------------------
/code/utils/spherical_harmonics.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def ComputeSH(dirs):
5 | '''
6 |     dirs: (b, 3) unit direction vectors; returns (b, 25) real SH basis values for degrees 0-4.
7 | '''
8 | x = dirs[..., 0]
9 | y = dirs[..., 1]
10 | z = dirs[..., 2]
11 | xx = dirs[..., 0] ** 2
12 | yy = dirs[..., 1] ** 2
13 | zz = dirs[..., 2] ** 2
14 |
15 | xy = dirs[..., 0] * dirs[..., 1]
16 | yz = dirs[..., 1] * dirs[..., 2]
17 | xz = dirs[..., 0] * dirs[..., 2]
18 |
19 | sh = torch.zeros((dirs.shape[0], 25)).to(dirs.device)
20 |
21 |     sh[:, 0] = 0.282095  # degree l = 0
22 |     # degree l = 1
23 | sh[:, 1] = -0.4886025119029199 * y
24 | sh[:, 2] = 0.4886025119029199 * z
25 | sh[:, 3] = -0.4886025119029199 * x
26 |     # degree l = 2
27 | sh[:, 4] = 1.0925484305920792 * xy
28 | sh[:, 5] = -1.0925484305920792 * yz
29 | sh[:, 6] = 0.31539156525252005 * (2.0 * zz - xx - yy)
30 | sh[:, 7] = -1.0925484305920792 * xz
31 | # sh2p2
32 | sh[:, 8] = 0.5462742152960396 * (xx - yy)
33 |     # degree l = 3
34 | sh[:, 9] = -0.5900435899266435 * y * (3 * xx - yy)
35 | sh[:, 10] = 2.890611442640554 * xy * z
36 | sh[:, 11] = -0.4570457994644658 * y * (4 * zz - xx - yy)
37 | sh[:, 12] = 0.3731763325901154 * z * (2 * zz - 3 * xx - 3 * yy)
38 | sh[:, 13] = -0.4570457994644658 * x * (4 * zz - xx - yy)
39 | sh[:, 14] = 1.445305721320277 * z * (xx - yy)
40 | sh[:, 15] = -0.5900435899266435 * x * (xx - 3 * yy)
41 |     # degree l = 4
42 | sh[:, 16] = 2.5033429417967046 * xy * (xx - yy)
43 | sh[:, 17] = -1.7701307697799304 * yz * (3 * xx - yy)
44 | sh[:, 18] = 0.9461746957575601 * xy * (7 * zz - 1.0)
45 | sh[:, 19] = -0.6690465435572892 * yz * (7 * zz - 3.0)
46 | sh[:, 20] = 0.10578554691520431 * (zz * (35 * zz - 30) + 3)
47 | sh[:, 21] = -0.6690465435572892 * xz * (7 * zz - 3)
48 | sh[:, 22] = 0.47308734787878004 * (xx - yy) * (7 * zz - 1.0)
49 | sh[:, 23] = -1.7701307697799304 * xz * (xx - 3 * yy)
50 | sh[:, 24] = 0.6258357354491761 * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
51 |
52 | return sh
53 |
54 |
55 | def computeRGB(dirs, coeff):
56 | '''
57 |     dirs: (n, 3) unit direction vectors
58 |     coeff: (n, n_components, f) SH coefficients; returns (n, f) colors
59 | '''
60 | n_components = coeff.shape[1]
61 |
62 | return (ComputeSH(dirs)[..., :n_components].unsqueeze(-1) * coeff).sum(1)
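63 | 
64 | 
65 | # Illustrative, minimal usage sketch: `ComputeSH` evaluates the first 25 real
66 | # spherical-harmonic basis functions (degrees 0-4) for a batch of unit directions,
67 | # and `computeRGB` contracts them with per-point SH coefficients (example shapes
68 | # are hypothetical):
69 | if __name__ == '__main__':
70 |     dirs = torch.nn.functional.normalize(torch.randn(8, 3), dim=-1)  # (n, 3) unit directions
71 |     coeff = torch.randn(8, 9, 3)                                     # (n, 9 SH components, RGB)
72 |     rgb = computeRGB(dirs, coeff)                                    # uses the first 9 basis values
73 |     print(rgb.shape)  # torch.Size([8, 3])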
--------------------------------------------------------------------------------
/code/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from pytorch3d.transforms import so3_log_map, so3_rotation_angle, so3_relative_angle, matrix_to_euler_angles, euler_angles_to_matrix
5 |
6 |
7 | def get_local_transformation(transformations, parents):
8 | parents_t = torch.cat([torch.eye(4).unsqueeze(0), transformations[parents[1:]]])
9 | local_t = torch.matmul(torch.inverse(parents_t), transformations)
10 |
11 | return local_t
12 |
13 |
14 | def get_local_eular(transformations, parents, axis='XYZ'):
15 | local_t = get_local_transformation(transformations, parents)
16 | local_t = torch.clamp(local_t[:, :3, :3], max=1.)
17 | local_eular = matrix_to_euler_angles(local_t, axis)
18 |
19 | return local_eular
20 |
21 |
22 | def rotation_matrix_to_eular(rotations, axis='XYZ'):
23 | return matrix_to_euler_angles(rotations[:, :3, :3], axis)
24 |
25 |
26 | def eular_to_rotation_matrix(eular, axis='XYZ'):
27 | return euler_angles_to_matrix(eular, axis)
28 |
29 |
30 | def campose_to_extrinsic(camposes):
31 |     # Each row must hold a 3x4 camera pose flattened into 12 values.
32 |     if camposes.shape[1] != 12:
33 |         raise Exception("Wrong campose data structure: expected an (N, 12) array!")
34 |
35 | res = np.zeros((camposes.shape[0], 4, 4))
36 |
37 | res[:, 0:3, 2] = camposes[:, 0:3]
38 | res[:, 0:3, 0] = camposes[:, 3:6]
39 | res[:, 0:3, 1] = camposes[:, 6:9]
40 | res[:, 0:3, 3] = camposes[:, 9:12]
41 | res[:, 3, 3] = 1.0
42 |
43 | return res
44 |
45 |
46 | def read_intrinsics(fn_instrinsic):
47 | fo = open(fn_instrinsic)
48 | data = fo.readlines()
49 | i = 0
50 | Ks = []
51 | while i < len(data):
52 | if len(data[i]) > 6:
53 | tmp = data[i].split()
54 | tmp = [float(i) for i in tmp]
55 | a = np.array(tmp)
56 | i = i + 1
57 | tmp = data[i].split()
58 | tmp = [float(i) for i in tmp]
59 | b = np.array(tmp)
60 | i = i + 1
61 | tmp = data[i].split()
62 | tmp = [float(i) for i in tmp]
63 | c = np.array(tmp)
64 | res = np.vstack([a, b, c])
65 | Ks.append(res)
66 |
67 | i = i + 1
68 | Ks = np.stack(Ks)
69 | fo.close()
70 |
71 | return Ks
72 |
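73 | # Illustrative, minimal usage sketch of `campose_to_extrinsic`: each input row
74 | # holds the third, first, and second columns of the 3x3 rotation followed by the
75 | # translation, and the output is an (N, 4, 4) extrinsic matrix. An identity pose
76 | # (hypothetical values) maps to the 4x4 identity:
77 | if __name__ == '__main__':
78 |     campose = np.array([[0., 0., 1.,  1., 0., 0.,  0., 1., 0.,  0., 0., 0.]])
79 |     E = campose_to_extrinsic(campose)
80 |     print(E[0])  # expected: the 4x4 identity matrix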
--------------------------------------------------------------------------------
/medias/featured1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HaiminLuo/Artemis/9a0aee4af849c384b49a1a2af354f47549d05c85/medias/featured1.png
--------------------------------------------------------------------------------
/medias/overview_v3-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HaiminLuo/Artemis/9a0aee4af849c384b49a1a2af354f47549d05c85/medias/overview_v3-1.png
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | yacs
2 | numpy
3 | torch
4 | torchvision
5 | pillow
6 | argparse
7 | opencv-python
8 | imageio
9 | imageio-ffmpeg
10 | IPython
11 | tqdm
--------------------------------------------------------------------------------