├── LICENSE
├── README.md
├── criteria
│   ├── __init__.py
│   └── lpips
│       ├── __init__.py
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── docs
│   ├── anime.gif
│   ├── car.gif
│   ├── church.gif
│   ├── ffhq.gif
│   ├── teaser.jpg
│   └── thumb.gif
├── env.yaml
├── expansion
│   ├── __init__.py
│   ├── dataloader
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   └── seqlist.cpython-38.pyc
│   │   ├── chairslist.py
│   │   ├── chairssdlist.py
│   │   ├── depth_transforms.py
│   │   ├── depthloader.py
│   │   ├── flow_transforms.py
│   │   ├── hd1klist.py
│   │   ├── kitti12list.py
│   │   ├── kitti15list.py
│   │   ├── kitti15list_train.py
│   │   ├── kitti15list_train_lidar.py
│   │   ├── kitti15list_val.py
│   │   ├── kitti15list_val_lidar.py
│   │   ├── kitti15list_val_mr.py
│   │   ├── robloader.py
│   │   ├── sceneflowlist.py
│   │   ├── seqlist.py
│   │   ├── sintellist.py
│   │   ├── sintellist_clean.py
│   │   ├── sintellist_final.py
│   │   ├── sintellist_train.py
│   │   ├── sintellist_val.py
│   │   └── thingslist.py
│   ├── models
│   │   ├── VCN_exp.py
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── VCN_exp.cpython-38.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── conv4d.cpython-38.pyc
│   │   │   └── submodule.cpython-38.pyc
│   │   ├── conv4d.py
│   │   └── submodule.py
│   ├── submission.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   ├── flowlib.cpython-38.pyc
│       │   ├── io.cpython-38.pyc
│       │   ├── pfm.cpython-38.pyc
│       │   └── util_flow.cpython-38.pyc
│       ├── flowlib.py
│       ├── io.py
│       ├── logger.py
│       ├── multiscaleloss.py
│       ├── pfm.py
│       ├── readpfm.py
│       ├── sintel_io.py
│       └── util_flow.py
├── interface
│   ├── flask_app.py
│   ├── inference.py
│   └── templates
│       └── index.html
├── licenses
│   ├── LICENSE_gengshan-y_expansion
│   ├── LICENSE_HuangYG123
│   ├── LICENSE_S-aiueo32
│   ├── LICENSE_TreB1eN
│   ├── LICENSE_lessw2020
│   ├── LICENSE_pixel2style2pixel
│   └── LICENSE_rosinality
├── models
│   ├── StyleGANControler.py
│   ├── __init__.py
│   ├── networks
│   │   ├── __init__.py
│   │   └── latent_transformer.py
│   └── stylegan2
│       ├── __init__.py
│       ├── model.py
│       └── op
│           ├── __init__.py
│           ├── fused_act.py
│           ├── fused_bias_act.cpp
│           ├── fused_bias_act_kernel.cu
│           ├── upfirdn2d.cpp
│           ├── upfirdn2d.py
│           └── upfirdn2d_kernel.cu
├── options
│   ├── __init__.py
│   └── train_options.py
├── scripts
│   └── train.py
├── training
│   ├── __init__.py
│   ├── coach.py
│   └── ranger.py
└── utils
    ├── __init__.py
    └── common.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Yuki Endo
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # User-Controllable Latent Transformer for StyleGAN Image Layout Editing
2 |
3 |
4 |
5 |
6 |
7 |
8 | This repository contains our implementation of the following paper:
9 |
10 | Yuki Endo: "User-Controllable Latent Transformer for StyleGAN Image Layout Editing," Computer Graphics Forum (Pacific Graphics 2022) [[Project](http://www.cgg.cs.tsukuba.ac.jp/~endo/projects/UserControllableLT)] [[PDF (preprint)](http://arxiv.org/abs/2208.12408)]
11 |
12 | ## Prerequisites
13 | 1. Python 3.8
14 | 2. PyTorch 1.9.0
15 | 3. Flask
16 | 4. Others (see env.yaml)
17 |
18 | ## Preparation
19 | Download and decompress our pre-trained models.
20 |
21 | ## Inference with our pre-trained models
22 |
23 | We provide an interactive interface based on Flask. The interface can be launched locally with
24 | ```
25 | python interface/flask_app.py --checkpoint_path=pretrained_models/latent_transformer/cat.pt
26 | ```
27 | The interface can be accessed via http://localhost:8000/.
28 |
29 | ## Training
30 | The latent transformer can be trained with
31 | ```
32 | python scripts/train.py --exp_dir=results --stylegan_weights=pretrained_models/stylegan2-cat-config-f.pt
33 | ```
34 | To perform training with your own dataset, you first need to train StyleGAN2 on it using [rosinality's code](https://github.com/rosinality/stylegan2-pytorch) and then run the above script, specifying the trained weights with the `--stylegan_weights` option.
35 |
36 | ## Link
37 | [Gradio demo](https://huggingface.co/spaces/radames/UserControllableLT-Latent-Transformer) by Radamés Ajna
38 |
39 | ## Citation
40 | Please cite our paper if you find the code useful:
41 | ```
42 | @Article{endoPG2022,
43 | Title = {User-Controllable Latent Transformer for StyleGAN Image Layout Editing},
44 | Author = {Yuki Endo},
45 | Journal = {Computer Graphics Forum},
46 | volume = {41},
47 | number = {7},
48 | pages = {395-406},
49 | doi = {10.1111/cgf.14686},
50 | Year = {2022}
51 | }
52 | ```
53 |
54 | ## Acknowledgements
55 | This code heavily borrows from the [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) and [expansion](https://github.com/gengshan-y/expansion) repositories.
56 |
--------------------------------------------------------------------------------
/criteria/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/criteria/__init__.py
--------------------------------------------------------------------------------
/criteria/lpips/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/criteria/lpips/__init__.py
--------------------------------------------------------------------------------
/criteria/lpips/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from criteria.lpips.networks import get_network, LinLayers
5 | from criteria.lpips.utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 | Arguments:
12 | net_type (str): the network type to compare the features:
13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
14 | version (str): the version of LPIPS. Default: 0.1.
15 | """
16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
17 |
18 | assert version in ['0.1'], 'v0.1 is only supported now'
19 |
20 | super(LPIPS, self).__init__()
21 |
22 | # pretrained network
23 | self.net = get_network(net_type).to("cuda")
24 |
25 | # linear layers
26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda")
27 | self.lin.load_state_dict(get_state_dict(net_type, version))
28 |
29 | def forward(self, x: torch.Tensor, y: torch.Tensor):
30 | feat_x, feat_y = self.net(x), self.net(y)
31 |
32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
34 |
35 | return torch.sum(torch.cat(res, 0)) / x.shape[0]
36 |
--------------------------------------------------------------------------------
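
A minimal usage sketch for the LPIPS criterion above, assuming a CUDA device is available (`LPIPS.__init__` moves both the backbone and the linear layers to `"cuda"`) and that `torch.hub` can download the v0.1 linear weights on first use; the tensors are random stand-ins for generator outputs in [-1, 1]:

```
import torch
from criteria.lpips.lpips import LPIPS

# random placeholders for two image batches in [-1, 1]
x = torch.rand(4, 3, 256, 256, device="cuda") * 2 - 1
y = torch.rand(4, 3, 256, 256, device="cuda") * 2 - 1

lpips_loss = LPIPS(net_type="alex")  # downloads the linear-layer weights on first use
with torch.no_grad():
    distance = lpips_loss(x, y)      # scalar: mean perceptual distance over the batch
print(distance.item())
```
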
/criteria/lpips/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from criteria.lpips.utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(True).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
--------------------------------------------------------------------------------
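
A small sketch of how these frozen backbones behave, assuming torchvision can fetch the pretrained AlexNet weights: each forward pass returns one channel-normalized activation per entry in `target_layers`, with channel counts given by `n_channels_list`.

```
import torch
from criteria.lpips.networks import get_network

net = get_network("alex")             # frozen, pretrained AlexNet feature extractor
x = torch.rand(1, 3, 224, 224) * 2 - 1

with torch.no_grad():
    feats = net(x)                    # list of 5 activations, one per target layer

for f, c in zip(feats, net.n_channels_list):
    print(tuple(f.shape), c)          # channel dimension of f matches c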
/criteria/lpips/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
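
For reference, a minimal sketch (with dummy keys, values omitted) of the renaming performed by `get_state_dict` above: the upstream LPIPS weights store each 1x1 convolution under keys such as `lin0.model.1.weight`, which are mapped onto the indices used by `LinLayers` in networks.py.

```
old_keys = ["lin0.model.1.weight", "lin1.model.1.weight"]
new_keys = [k.replace("lin", "").replace("model.", "") for k in old_keys]
print(new_keys)  # ['0.1.weight', '1.1.weight'] -> LinLayers[i][1], i.e. each Conv2d
```
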
/docs/anime.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/anime.gif
--------------------------------------------------------------------------------
/docs/car.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/car.gif
--------------------------------------------------------------------------------
/docs/church.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/church.gif
--------------------------------------------------------------------------------
/docs/ffhq.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/ffhq.gif
--------------------------------------------------------------------------------
/docs/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/teaser.jpg
--------------------------------------------------------------------------------
/docs/thumb.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/thumb.gif
--------------------------------------------------------------------------------
/env.yaml:
--------------------------------------------------------------------------------
1 | name: uclt
2 | channels:
3 | - pytorch
4 | - anaconda
5 | - nvidia
6 | - conda-forge
7 | - defaults
8 | dependencies:
9 | - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0
10 | - _libgcc_mutex=0.1=conda_forge
11 | - _openmp_mutex=4.5=1_llvm
12 | - absl-py=0.13.0=pyhd8ed1ab_0
13 | - aiohttp=3.7.4.post0=py38h497a2fe_0
14 | - albumentations=1.0.3=pyhd8ed1ab_0
15 | - alsa-lib=1.2.3=h516909a_0
16 | - anaconda-client=1.8.0=py38h06a4308_0
17 | - anaconda-navigator=2.0.4=py38_0
18 | - anyio=2.2.0=py38h06a4308_1
19 | - appdirs=1.4.4=pyh9f0ad1d_0
20 | - argon2-cffi=20.1.0=py38h27cfd23_1
21 | - async-timeout=3.0.1=py_1000
22 | - async_generator=1.10=pyhd3eb1b0_0
23 | - attrs=21.2.0=pyhd3eb1b0_0
24 | - babel=2.9.1=pyhd3eb1b0_0
25 | - backcall=0.2.0=pyhd3eb1b0_0
26 | - backports=1.0=pyhd3eb1b0_2
27 | - backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0
28 | - backports.tempfile=1.0=pyhd3eb1b0_1
29 | - backports.weakref=1.0.post1=py_1
30 | - beautifulsoup4=4.9.3=pyha847dfd_0
31 | - blas=1.0=mkl
32 | - bleach=4.0.0=pyhd3eb1b0_0
33 | - blinker=1.4=py_1
34 | - brotli=1.0.9=h7f98852_5
35 | - brotli-bin=1.0.9=h7f98852_5
36 | - brotlipy=0.7.0=py38h27cfd23_1003
37 | - bzip2=1.0.8=h7b6447c_0
38 | - c-ares=1.17.1=h27cfd23_0
39 | - ca-certificates=2021.10.8=ha878542_0
40 | - cachetools=4.2.2=pyhd8ed1ab_0
41 | - cairo=1.16.0=hf32fb01_1
42 | - certifi=2021.10.8=py38h578d9bd_1
43 | - cffi=1.14.6=py38h400218f_0
44 | - chardet=4.0.0=py38h06a4308_1003
45 | - click=8.0.1=pyhd3eb1b0_0
46 | - cloudpickle=1.6.0=py_0
47 | - clyent=1.2.2=py38_1
48 | - conda=4.11.0=py38h578d9bd_0
49 | - conda-build=3.21.4=py38h06a4308_0
50 | - conda-content-trust=0.1.1=pyhd3eb1b0_0
51 | - conda-env=2.6.0=1
52 | - conda-package-handling=1.7.3=py38h27cfd23_1
53 | - conda-repo-cli=1.0.4=pyhd3eb1b0_0
54 | - conda-token=0.3.0=pyhd3eb1b0_0
55 | - conda-verify=3.4.2=py_1
56 | - cryptography=3.4.7=py38hd23ed53_0
57 | - cudatoolkit=11.1.74=h6bb024c_0
58 | - cycler=0.10.0=py_2
59 | - cytoolz=0.11.0=py38h497a2fe_3
60 | - dask-core=2021.8.1=pyhd8ed1ab_0
61 | - dbus=1.13.18=hb2f20db_0
62 | - decorator=5.0.9=pyhd3eb1b0_0
63 | - defusedxml=0.7.1=pyhd3eb1b0_0
64 | - dill=0.3.4=pyhd8ed1ab_0
65 | - dominate=2.6.0=pyhd8ed1ab_0
66 | - entrypoints=0.3=py38_0
67 | - enum34=1.1.10=py38h32f6830_2
68 | - expat=2.4.1=h2531618_2
69 | - ffmpeg=4.3.2=hca11adc_0
70 | - filelock=3.0.12=pyhd3eb1b0_1
71 | - flask=1.1.2=pyh9f0ad1d_0
72 | - flask-httpauth=4.4.0=pyhd8ed1ab_0
73 | - fontconfig=2.13.1=h6c09931_0
74 | - fonttools=4.25.0=pyhd3eb1b0_0
75 | - freetype=2.10.4=h5ab3b9f_0
76 | - fsspec=2021.7.0=pyhd8ed1ab_0
77 | - ftfy=6.0.3=pyhd8ed1ab_0
78 | - func_timeout=4.3.5=py_0
79 | - future=0.18.2=py38_1
80 | - gdown=4.2.0=pyhd8ed1ab_0
81 | - geos=3.10.0=h9c3ff4c_0
82 | - gettext=0.19.8.1=h0b5b191_1005
83 | - git=2.23.0=pl526hacde149_0
84 | - glib=2.68.4=h9c3ff4c_0
85 | - glib-tools=2.68.4=h9c3ff4c_0
86 | - glob2=0.7=pyhd3eb1b0_0
87 | - gmp=6.2.1=h58526e2_0
88 | - gnutls=3.6.13=h85f3911_1
89 | - google-auth=1.35.0=pyh6c4a22f_0
90 | - google-auth-oauthlib=0.4.5=pyhd8ed1ab_0
91 | - gputil=1.4.0=pyh9f0ad1d_0
92 | - graphite2=1.3.13=h58526e2_1001
93 | - gst-plugins-base=1.18.4=hf529b03_2
94 | - gstreamer=1.18.4=h76c114f_2
95 | - harfbuzz=2.9.0=h83ec7ef_0
96 | - hdf5=1.10.6=nompi_h6a2412b_1114
97 | - icu=68.1=h58526e2_0
98 | - idna=2.10=pyhd3eb1b0_0
99 | - imagecodecs-lite=2019.12.3=py38h5c078b8_3
100 | - imageio=2.9.0=py_0
101 | - imageio-ffmpeg=0.4.5=pyhd8ed1ab_0
102 | - imgaug=0.4.0=py_1
103 | - importlib-metadata=3.10.0=py38h06a4308_0
104 | - importlib_metadata=3.10.0=hd3eb1b0_0
105 | - intel-openmp=2021.3.0=h06a4308_3350
106 | - ipykernel=5.3.4=py38h5ca1d4c_0
107 | - ipympl=0.8.2=pyhd8ed1ab_0
108 | - ipython=7.26.0=py38hb070fc8_0
109 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
110 | - ipywidgets=7.6.3=pyhd3eb1b0_1
111 | - itsdangerous=2.0.1=pyhd3eb1b0_0
112 | - jasper=1.900.1=h07fcdf6_1006
113 | - jedi=0.18.0=py38h06a4308_1
114 | - jinja2=2.11.3=pyhd3eb1b0_0
115 | - joblib=1.1.0=pyhd8ed1ab_0
116 | - jpeg=9d=h36c2ea0_0
117 | - json5=0.9.6=pyhd3eb1b0_0
118 | - jsonnet=0.17.0=py38hadf7658_0
119 | - jsonschema=3.2.0=py_2
120 | - jupyter_client=6.1.12=pyhd3eb1b0_0
121 | - jupyter_core=4.7.1=py38h06a4308_0
122 | - jupyter_server=1.4.1=py38h06a4308_0
123 | - jupyterlab=3.1.7=pyhd3eb1b0_0
124 | - jupyterlab_pygments=0.1.2=py_0
125 | - jupyterlab_server=2.7.1=pyhd3eb1b0_0
126 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
127 | - kiwisolver=1.3.1=py38h1fd1430_1
128 | - krb5=1.19.2=hcc1bbae_0
129 | - lame=3.100=h7f98852_1001
130 | - lcms2=2.12=h3be6417_0
131 | - ld_impl_linux-64=2.35.1=h7274673_9
132 | - libarchive=3.4.2=h62408e4_0
133 | - libblas=3.9.0=11_linux64_mkl
134 | - libbrotlicommon=1.0.9=h7f98852_5
135 | - libbrotlidec=1.0.9=h7f98852_5
136 | - libbrotlienc=1.0.9=h7f98852_5
137 | - libcblas=3.9.0=11_linux64_mkl
138 | - libcurl=7.78.0=h2574ce0_0
139 | - libedit=3.1.20191231=he28a2e2_2
140 | - libev=4.33=h516909a_1
141 | - libevent=2.1.10=hcdb4288_3
142 | - libffi=3.3=he6710b0_2
143 | - libgcc-ng=11.1.0=hc902ee8_8
144 | - libgfortran-ng=11.1.0=h69a702a_8
145 | - libgfortran5=11.1.0=h6c583b3_8
146 | - libglib=2.68.4=h3e27bee_0
147 | - libiconv=1.16=h516909a_0
148 | - liblapack=3.9.0=11_linux64_mkl
149 | - liblapacke=3.9.0=11_linux64_mkl
150 | - liblief=0.10.1=he6710b0_0
151 | - libllvm11=11.1.0=hf817b99_2
152 | - libnghttp2=1.43.0=h812cca2_0
153 | - libogg=1.3.4=h7f98852_1
154 | - libopencv=4.5.2=py38hcdf9bf1_0
155 | - libopus=1.3.1=h7f98852_1
156 | - libpng=1.6.37=hbc83047_0
157 | - libpq=13.3=hd57d9b9_0
158 | - libprotobuf=3.15.8=h780b84a_0
159 | - libsodium=1.0.18=h7b6447c_0
160 | - libssh2=1.9.0=ha56f1ee_6
161 | - libstdcxx-ng=11.1.0=h56837e0_8
162 | - libtiff=4.2.0=h85742a9_0
163 | - libuuid=1.0.3=h1bed415_2
164 | - libuv=1.40.0=h7b6447c_0
165 | - libvorbis=1.3.7=h9c3ff4c_0
166 | - libwebp-base=1.2.0=h27cfd23_0
167 | - libxcb=1.14=h7b6447c_0
168 | - libxkbcommon=1.0.3=he3ba5ed_0
169 | - libxml2=2.9.12=h72842e0_0
170 | - llvm-openmp=12.0.1=h4bd325d_1
171 | - locket=0.2.0=py_2
172 | - lz4-c=1.9.3=h295c915_1
173 | - markdown=3.3.4=pyhd8ed1ab_0
174 | - markupsafe=2.0.1=py38h27cfd23_0
175 | - matplotlib=3.4.2=py38h578d9bd_0
176 | - matplotlib-base=3.4.2=py38hab158f2_0
177 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2
178 | - mistune=0.8.4=py38h7b6447c_1000
179 | - mkl=2021.3.0=h06a4308_520
180 | - mkl-service=2.4.0=py38h7f8727e_0
181 | - mkl_fft=1.3.0=py38h42c9631_2
182 | - mkl_random=1.2.2=py38h51133e4_0
183 | - multidict=5.1.0=py38h497a2fe_1
184 | - munkres=1.1.4=pyh9f0ad1d_0
185 | - mysql-common=8.0.25=ha770c72_0
186 | - mysql-libs=8.0.25=h935591d_0
187 | - navigator-updater=0.2.1=py38_0
188 | - nbclassic=0.2.6=pyhd3eb1b0_0
189 | - nbclient=0.5.3=pyhd3eb1b0_0
190 | - nbconvert=6.1.0=py38h06a4308_0
191 | - nbformat=5.1.3=pyhd3eb1b0_0
192 | - ncurses=6.2=he6710b0_1
193 | - nest-asyncio=1.5.1=pyhd3eb1b0_0
194 | - nettle=3.6=he412f7d_0
195 | - networkx=2.3=py_0
196 | - ninja=1.10.2=hff7bd54_1
197 | - notebook=6.4.3=py38h06a4308_0
198 | - nspr=4.30=h9c3ff4c_0
199 | - nss=3.69=hb5efdd6_0
200 | - numpy=1.20.3=py38hf144106_0
201 | - numpy-base=1.20.3=py38h74d4b33_0
202 | - oauthlib=3.1.1=pyhd8ed1ab_0
203 | - olefile=0.46=py_0
204 | - opencv=4.5.2=py38h578d9bd_0
205 | - openh264=2.1.1=h780b84a_0
206 | - openjpeg=2.3.0=h05c96fa_1
207 | - openssl=1.1.1l=h7f98852_0
208 | - packaging=21.0=pyhd3eb1b0_0
209 | - pandas=1.3.2=py38h43a58ef_0
210 | - pandocfilters=1.4.3=py38h06a4308_1
211 | - parso=0.8.2=pyhd3eb1b0_0
212 | - partd=1.2.0=pyhd8ed1ab_0
213 | - patchelf=0.12=h2531618_1
214 | - pathlib=1.0.1=py38h578d9bd_4
215 | - patsy=0.5.1=py_0
216 | - pcre=8.45=h295c915_0
217 | - perl=5.26.2=h14c3975_0
218 | - pexpect=4.8.0=pyhd3eb1b0_3
219 | - pickleshare=0.7.5=pyhd3eb1b0_1003
220 | - pillow=8.3.1=py38h2c7a002_0
221 | - pip=21.2.2=py38h06a4308_0
222 | - pixman=0.40.0=h36c2ea0_0
223 | - pkginfo=1.7.1=py38h06a4308_0
224 | - pooch=1.5.1=pyhd8ed1ab_0
225 | - portalocker=1.7.0=py38h578d9bd_1
226 | - prometheus_client=0.11.0=pyhd3eb1b0_0
227 | - prompt-toolkit=3.0.17=pyh06a4308_0
228 | - protobuf=3.15.8=py38h709712a_0
229 | - psutil=5.8.0=py38h27cfd23_1
230 | - ptyprocess=0.7.0=pyhd3eb1b0_2
231 | - py-lief=0.10.1=py38h403a769_0
232 | - py-opencv=4.5.2=py38hd0cf306_0
233 | - pyasn1=0.4.8=py_0
234 | - pyasn1-modules=0.2.7=py_0
235 | - pycosat=0.6.3=py38h7b6447c_1
236 | - pycparser=2.20=py_2
237 | - pygments=2.10.0=pyhd3eb1b0_0
238 | - pyjwt=2.1.0=pyhd8ed1ab_0
239 | - pyopenssl=20.0.1=pyhd3eb1b0_1
240 | - pyparsing=2.4.7=pyhd3eb1b0_0
241 | - pypng=0.0.20=py_0
242 | - pyqt=5.12.3=py38h578d9bd_7
243 | - pyqt-impl=5.12.3=py38h7400c14_7
244 | - pyqt5-sip=4.19.18=py38h709712a_7
245 | - pyqtchart=5.12=py38h7400c14_7
246 | - pyqtwebengine=5.12.1=py38h7400c14_7
247 | - pyrsistent=0.17.3=py38h7b6447c_0
248 | - pysocks=1.7.1=py38h06a4308_0
249 | - python=3.8.10=h12debd9_8
250 | - python-dateutil=2.8.2=pyhd3eb1b0_0
251 | - python-libarchive-c=2.9=pyhd3eb1b0_1
252 | - python-lmdb=0.99=py38h709712a_0
253 | - python_abi=3.8=2_cp38
254 | - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0
255 | - pytz=2021.1=pyhd3eb1b0_0
256 | - pyu2f=0.1.5=pyhd8ed1ab_0
257 | - pywavelets=1.1.1=py38h5c078b8_3
258 | - pyyaml=5.4.1=py38h27cfd23_1
259 | - pyzmq=22.2.1=py38h295c915_1
260 | - qt=5.12.9=hda022c4_4
261 | - qtpy=1.9.0=py_0
262 | - readline=8.1=h27cfd23_0
263 | - regex=2021.8.28=py38h497a2fe_0
264 | - requests=2.25.1=pyhd3eb1b0_0
265 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0
266 | - ripgrep=12.1.1=0
267 | - rsa=4.7.2=pyh44b312d_0
268 | - ruamel_yaml=0.15.100=py38h27cfd23_0
269 | - scikit-image=0.18.3=py38h43a58ef_0
270 | - scikit-learn=1.0=py38hacb3eff_1
271 | - scipy=1.7.1=py38h56a6a73_0
272 | - seaborn=0.11.2=hd8ed1ab_0
273 | - seaborn-base=0.11.2=pyhd8ed1ab_0
274 | - send2trash=1.5.0=pyhd3eb1b0_1
275 | - setuptools=52.0.0=py38h06a4308_0
276 | - shapely=1.8.0=py38hf7953bd_1
277 | - sip=4.19.13=py38he6710b0_0
278 | - sniffio=1.2.0=py38h06a4308_1
279 | - soupsieve=2.2.1=pyhd3eb1b0_0
280 | - sqlite=3.36.0=hc218d9a_0
281 | - statsmodels=0.12.2=py38h5c078b8_0
282 | - tensorboard=2.6.0=pyhd8ed1ab_1
283 | - tensorboard-data-server=0.6.0=py38h2b97feb_0
284 | - tensorboard-plugin-wit=1.8.0=pyh44b312d_0
285 | - tensorboardx=2.4=pyhd8ed1ab_0
286 | - terminado=0.9.4=py38h06a4308_0
287 | - testpath=0.5.0=pyhd3eb1b0_0
288 | - threadpoolctl=3.0.0=pyh8a188c0_0
289 | - tifffile=2019.7.26.2=py38_0
290 | - tk=8.6.10=hbc83047_0
291 | - toolz=0.11.1=py_0
292 | - torchfile=0.1.0=py_0
293 | - tornado=6.1=py38h27cfd23_0
294 | - tqdm=4.62.1=pyhd3eb1b0_1
295 | - traitlets=5.0.5=pyhd3eb1b0_0
296 | - typing_extensions=3.10.0.0=pyh06a4308_0
297 | - urllib3=1.26.6=pyhd3eb1b0_1
298 | - wcwidth=0.2.5=py_0
299 | - webencodings=0.5.1=py38_1
300 | - werkzeug=1.0.1=pyhd3eb1b0_0
301 | - wheel=0.37.0=pyhd3eb1b0_0
302 | - widgetsnbextension=3.5.1=py38_0
303 | - x264=1!161.3030=h7f98852_1
304 | - xmltodict=0.12.0=py_0
305 | - xz=5.2.5=h7b6447c_0
306 | - yacs=0.1.6=py_0
307 | - yaml=0.2.5=h7b6447c_0
308 | - yarl=1.6.3=py38h497a2fe_2
309 | - zeromq=4.3.4=h2531618_0
310 | - zipp=3.5.0=pyhd3eb1b0_0
311 | - zlib=1.2.11=h7b6447c_3
312 | - zstd=1.4.9=haebb681_0
313 | - pip:
314 | - addict==2.4.0
315 | - altair==4.2.0
316 | - astor==0.8.1
317 | - astunparse==1.6.3
318 | - backports-zoneinfo==0.2.1
319 | - base58==2.1.1
320 | - basicsr==1.3.4.1
321 | - boto3==1.18.33
322 | - botocore==1.21.33
323 | - clang==5.0
324 | - clean-fid==0.1.22
325 | - clip==1.0
326 | - colorama==0.4.4
327 | - commonmark==0.9.1
328 | - cython==0.29.30
329 | - einops==0.3.2
330 | - enum-compat==0.0.3
331 | - facexlib==0.2.0.3
332 | - filterpy==1.4.5
333 | - flatbuffers==1.12
334 | - gast==0.4.0
335 | - google-pasta==0.2.0
336 | - grpcio==1.39.0
337 | - h5py==3.1.0
338 | - ipdb==0.13.9
339 | - jacinle==1.0.0
340 | - jmespath==0.10.0
341 | - jsonpickle==2.2.0
342 | - keras==2.7.0
343 | - keras-preprocessing==1.1.2
344 | - libclang==12.0.0
345 | - llvmlite==0.37.0
346 | - lpips==0.1.4
347 | - numba==0.54.0
348 | - opencv-python==4.5.3.56
349 | - opt-einsum==3.3.0
350 | - pkgconfig==1.5.5
351 | - pyarrow==8.0.0
352 | - pydantic==1.8.2
353 | - pydeck==0.7.1
354 | - pyhocon==0.3.58
355 | - pytz-deprecation-shim==0.1.0.post0
356 | - pyvis==0.2.1
357 | - realesrgan==0.2.2.3
358 | - rich==10.9.0
359 | - s3transfer==0.5.0
360 | - six==1.15.0
361 | - sklearn==0.0
362 | - streamlit==0.64.0
363 | - tabulate==0.8.9
364 | - tb-nightly==2.7.0a20210827
365 | - tensorflow-estimator==2.7.0
366 | - tensorflow-gpu==2.7.0
367 | - tensorflow-io-gcs-filesystem==0.21.0
368 | - tensorfn==0.1.19
369 | - termcolor==1.1.0
370 | - toml==0.10.2
371 | - torchsample==0.1.3
372 | - torchvision==0.10.0+cu111
373 | - typing-extensions==3.7.4.3
374 | - tzdata==2022.1
375 | - tzlocal==4.2
376 | - validators==0.19.0
377 | - vit-pytorch==0.24.3
378 | - watchdog==2.1.8
379 | - wrapt==1.12.1
380 | - yapf==0.31.0
--------------------------------------------------------------------------------
/expansion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/__init__.py
--------------------------------------------------------------------------------
/expansion/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/dataloader/__init__.py
--------------------------------------------------------------------------------
/expansion/dataloader/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/dataloader/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/expansion/dataloader/__pycache__/seqlist.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/dataloader/__pycache__/seqlist.cpython-38.pyc
--------------------------------------------------------------------------------
/expansion/dataloader/chairslist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import glob
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 | l0_train = []
20 | l1_train = []
21 | flow_train = []
22 | for flow_map in sorted(glob.glob(os.path.join(filepath,'*_flow.flo'))):
23 | root_filename = flow_map[:-9]
24 | img1 = root_filename+'_img1.ppm'
25 | img2 = root_filename+'_img2.ppm'
26 | if not (os.path.isfile(os.path.join(filepath,img1)) and os.path.isfile(os.path.join(filepath,img2))):
27 | continue
28 |
29 | l0_train.append(img1)
30 | l1_train.append(img2)
31 | flow_train.append(flow_map)
32 |
33 | return l0_train, l1_train, flow_train
34 |
--------------------------------------------------------------------------------
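
The list modules in this folder share one convention: `dataloader(filepath)` returns three aligned lists of file paths (first frames, second frames, flow ground truth), which are later consumed by the dataset classes in robloader.py and depthloader.py. A hedged sketch, assuming the repository root is on PYTHONPATH and using a hypothetical FlyingChairs-style directory:

```
from expansion.dataloader import chairslist

# hypothetical dataset root; adjust to the local FlyingChairs layout
l0, l1, flows = chairslist.dataloader("/path/to/FlyingChairs_release/data")

for img1, img2, flow in zip(l0, l1, flows):
    print(img1, img2, flow)  # each entry is a file path
    break
```
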
/expansion/dataloader/chairssdlist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import glob
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 | l0_train = []
20 | l1_train = []
21 | flow_train = []
22 | for flow_map in sorted(glob.glob('%s/flow/*.pfm'%filepath)):
23 | img1 = flow_map.replace('flow','t0').replace('.pfm','.png')
24 | img2 = flow_map.replace('flow','t1').replace('.pfm','.png')
25 |
26 | l0_train.append(img1)
27 | l1_train.append(img2)
28 | flow_train.append(flow_map)
29 |
30 | return l0_train, l1_train, flow_train
31 |
--------------------------------------------------------------------------------
/expansion/dataloader/depthloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numbers
3 | import torch
4 | import torch.utils.data as data
5 | import torch
6 | import torchvision.transforms as transforms
7 | import random
8 | from PIL import Image, ImageOps
9 | import numpy as np
10 | import torchvision
11 | from . import depth_transforms as flow_transforms
12 | import pdb
13 | import cv2
14 | from utils.flowlib import read_flow
15 | from utils.util_flow import readPFM, load_calib_cam_to_cam
16 |
17 | def default_loader(path):
18 | return Image.open(path).convert('RGB')
19 |
20 | def flow_loader(path):
21 | if '.pfm' in path:
22 | data = readPFM(path)[0]
23 | data[:,:,2] = 1
24 | return data
25 | else:
26 | return read_flow(path)
27 |
28 | def load_exts(cam_file):
29 | with open(cam_file, 'r') as f:
30 | lines = f.readlines()
31 |
32 | l_exts = []
33 | r_exts = []
34 | for l in lines:
35 | if 'L ' in l:
36 | l_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4))
37 | if 'R ' in l:
38 | r_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4))
39 | return l_exts,r_exts
40 |
41 | def disparity_loader(path):
42 | if '.png' in path:
43 | data = Image.open(path)
44 | data = np.ascontiguousarray(data,dtype=np.float32)/256
45 | return data
46 | else:
47 | return readPFM(path)[0]
48 |
49 | # triangulation
50 | def triangulation(disp, xcoord, ycoord, bl=1, fl = 450, cx = 479.5, cy = 269.5):
51 | depth = bl*fl / disp # 450px->15mm focal length
52 | X = (xcoord - cx) * depth / fl
53 | Y = (ycoord - cy) * depth / fl
54 | Z = depth
55 | P = np.concatenate((X[np.newaxis],Y[np.newaxis],Z[np.newaxis]),0).reshape(3,-1)
56 | P = np.concatenate((P,np.ones((1,P.shape[-1]))),0)
57 | return P
58 |
59 | class myImageFloder(data.Dataset):
60 | def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1.,sc=False,disp0=None,disp1=None,calib=None ):
61 | self.iml0 = iml0
62 | self.iml1 = iml1
63 | self.flowl0 = flowl0
64 | self.loader = loader
65 | self.dploader = dploader
66 | self.scale=scale
67 | self.shape=shape
68 | self.order=order
69 | self.noise = noise
70 | self.pca_augmentor = pca_augmentor
71 | self.prob = prob
72 | self.sc = sc
73 | self.disp0 = disp0
74 | self.disp1 = disp1
75 | self.calib = calib
76 |
77 | def __getitem__(self, index):
78 | iml0 = self.iml0[index]
79 | iml1 = self.iml1[index]
80 | flowl0= self.flowl0[index]
81 | th, tw = self.shape
82 |
83 | iml0 = self.loader(iml0)
84 | iml1 = self.loader(iml1)
85 |
86 | # get disparity
87 | if self.sc:
88 | flowl0 = self.dploader(flowl0)
89 | flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32)
90 | flowl0[np.isnan(flowl0)] = 1e6 # set to max
91 | if 'camera_data.txt' in self.calib[index]:
92 | bl=1
93 | if '15mm_' in self.calib[index]:
94 | fl=450 # 450
95 | else:
96 | fl=1050
97 | cx = 479.5
98 | cy = 269.5
99 | # negative disp
100 | d1 = np.abs(disparity_loader(self.disp0[index]))
101 | d2 = np.abs(disparity_loader(self.disp1[index]) + d1)
102 | elif 'Sintel' in self.calib[index]:
103 | fl = 1000
104 | bl = 1
105 | cx = 511.5
106 | cy = 217.5
107 | d1 = np.zeros(flowl0.shape[:2])
108 | d2 = np.zeros(flowl0.shape[:2])
109 | else:
110 | ints = load_calib_cam_to_cam(self.calib[index])
111 | fl = ints['K_cam2'][0,0]
112 | cx = ints['K_cam2'][0,2]
113 | cy = ints['K_cam2'][1,2]
114 | bl = ints['b20']-ints['b30']
115 | d1 = disparity_loader(self.disp0[index])
116 | d2 = disparity_loader(self.disp1[index])
117 | #flowl0[:,:,2] = (flowl0[:,:,2]==1).astype(float)
118 | flowl0[:,:,2] = np.logical_and(np.logical_and(flowl0[:,:,2]==1, d1!=0), d2!=0).astype(float)
119 |
120 | shape = d1.shape
121 | mesh = np.meshgrid(range(shape[1]),range(shape[0]))
122 | xcoord = mesh[0].astype(float)
123 | ycoord = mesh[1].astype(float)
124 |
125 | # triangulation in two frames
126 | P0 = triangulation(d1, xcoord, ycoord, bl=bl, fl = fl, cx = cx, cy = cy)
127 | P1 = triangulation(d2, xcoord + flowl0[:,:,0], ycoord + flowl0[:,:,1], bl=bl, fl = fl, cx = cx, cy = cy)
128 | dis0 = P0[2]
129 | dis1 = P1[2]
130 |
131 | change_size = dis0.reshape(shape).astype(np.float32)
132 | flow3d = (P1-P0)[:3].reshape((3,)+shape).transpose((1,2,0))
133 |
134 | gt_normal = np.concatenate((d1[:,:,np.newaxis],d2[:,:,np.newaxis],d2[:,:,np.newaxis]),-1)
135 | change_size = np.concatenate((change_size[:,:,np.newaxis],gt_normal,flow3d),2)
136 | else:
137 | shape = iml0.size
138 | shape=[shape[1],shape[0]]
139 | flowl0 = np.zeros((shape[0],shape[1],3))
140 | change_size = np.zeros((shape[0],shape[1],7))
141 | depth = disparity_loader(self.iml1[index].replace('camera','groundtruth'))
142 | change_size[:,:,0] = depth
143 |
144 | seqid = self.iml0[index].split('/')[-5].rsplit('_',3)[0]
145 | ints = load_calib_cam_to_cam('/data/gengshay/KITTI/%s/calib_cam_to_cam.txt'%seqid)
146 | fl = ints['K_cam2'][0,0]
147 | cx = ints['K_cam2'][0,2]
148 | cy = ints['K_cam2'][1,2]
149 | bl = ints['b20']-ints['b30']
150 |
151 |
152 | iml1 = np.asarray(iml1)/255.
153 | iml0 = np.asarray(iml0)/255.
154 | iml0 = iml0[:,:,::-1].copy()
155 | iml1 = iml1[:,:,::-1].copy()
156 |
157 | ## following data augmentation procedure in PWCNet
158 | ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
159 | import __main__ # a workaround for "discount_coeff"
160 | try:
161 | with open('/scratch/gengshay/iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f:
162 | iter_counts = int(f.readline())
163 | except:
164 | iter_counts = 0
165 | schedule = [0.5, 1., 50000.] # initial coeff, final_coeff, half life
166 | schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \
167 | (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1)
168 |
169 | if self.pca_augmentor:
170 | pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff)
171 | else:
172 | pca_augmentor = flow_transforms.Scale(1., order=0)
173 |
174 | if np.random.binomial(1,self.prob):
175 | co_transform1 = flow_transforms.Compose([
176 | flow_transforms.SpatialAug([th,tw],
177 | scale=[0.2,0.,0.1],
178 | rot=[0.4,0.],
179 | trans=[0.4,0.],
180 | squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order),
181 | ])
182 | else:
183 | co_transform1 = flow_transforms.Compose([
184 | flow_transforms.RandomCrop([th,tw]),
185 | ])
186 |
187 | co_transform2 = flow_transforms.Compose([
188 | flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff),
189 | #flow_transforms.PCAAug(schedule_coeff=schedule_coeff),
190 | flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise),
191 | ])
192 |
193 | flowl0 = np.concatenate([flowl0,change_size],-1)
194 | augmented,flowl0,intr = co_transform1([iml0, iml1], flowl0, [fl,cx,cy,bl])
195 | imol0 = augmented[0]
196 | imol1 = augmented[1]
197 | augmented,flowl0,intr = co_transform2(augmented, flowl0, intr)
198 |
199 | iml0 = augmented[0]
200 | iml1 = augmented[1]
201 | flowl0 = flowl0.astype(np.float32)
202 | change_size = flowl0[:,:,3:]
203 | flowl0 = flowl0[:,:,:3]
204 |
205 | # randomly cover a region
206 | sx=0;sy=0;cx=0;cy=0
207 | if np.random.binomial(1,0.5):
208 | sx = int(np.random.uniform(25,100))
209 | sy = int(np.random.uniform(25,100))
210 | #sx = int(np.random.uniform(50,150))
211 | #sy = int(np.random.uniform(50,150))
212 | cx = int(np.random.uniform(sx,iml1.shape[0]-sx))
213 | cy = int(np.random.uniform(sy,iml1.shape[1]-sy))
214 | iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis]
215 |
216 | iml0 = torch.Tensor(np.transpose(iml0,(2,0,1)))
217 | iml1 = torch.Tensor(np.transpose(iml1,(2,0,1)))
218 |
219 | return iml0, iml1, flowl0, change_size, intr, imol0, imol1, np.asarray([cx-sx,cx+sx,cy-sy,cy+sy])
220 |
221 | def __len__(self):
222 | return len(self.iml0)
223 |
--------------------------------------------------------------------------------
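
A small numeric sketch of the triangulation helper defined above: disparity is turned into depth via depth = baseline * focal / disparity, and pixel coordinates are back-projected through the pinhole model with principal point (cx, cy). The numbers are illustrative only.

```
import numpy as np

def triangulate(disp, xcoord, ycoord, bl=1.0, fl=450.0, cx=479.5, cy=269.5):
    # mirrors the triangulation() helper in depthloader.py (without the homogeneous row)
    depth = bl * fl / disp
    X = (xcoord - cx) * depth / fl
    Y = (ycoord - cy) * depth / fl
    return np.stack([X, Y, depth], 0)

# a pixel at the principal point with disparity 45 -> depth of 10 units on the optical axis
print(triangulate(np.array([45.0]), np.array([479.5]), np.array([269.5])).ravel())  # [ 0.  0. 10.]
```
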
/expansion/dataloader/hd1klist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('HD1K2018') > -1]
22 | train = sorted(train)
23 |
24 | l0_train = [filepath+left_fold+img for img in train]
25 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
26 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
27 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
28 |
29 | return l0_train, l1_train, flow_train
30 |
--------------------------------------------------------------------------------
/expansion/dataloader/kitti12list.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'colored_0/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23 |
24 | l0_train = [filepath+left_fold+img for img in train]
25 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
26 | flow_train = [filepath+flow_noc+img for img in train]
27 |
28 |
29 | return l0_train, l1_train, flow_train
30 |
--------------------------------------------------------------------------------
/expansion/dataloader/kitti15list.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23 |
24 | l0_train = [filepath+left_fold+img for img in train]
25 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
26 | flow_train = [filepath+flow_noc+img for img in train]
27 |
28 |
29 | return sorted(l0_train), sorted(l1_train), sorted(flow_train)
30 |
--------------------------------------------------------------------------------
/expansion/dataloader/kitti15list_train.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23 |
24 | train = [i for i in train if int(i.split('_')[0])%5!=0]
25 |
26 | l0_train = [filepath+left_fold+img for img in train]
27 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
28 | flow_train = [filepath+flow_noc+img for img in train]
29 |
30 |
31 | return sorted(l0_train), sorted(l1_train), sorted(flow_train)
32 |
--------------------------------------------------------------------------------
/expansion/dataloader/kitti15list_train_lidar.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23 |
24 | # train = [i for i in train if int(i.split('_')[0])%5!=0]
25 | with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f:
26 | flags = [True if len(i)>1 else False for i in f.readlines()]
27 | train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it] ][:100]
28 |
29 | l0_train = [filepath+left_fold+img for img in train]
30 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
31 | flow_train = [filepath+flow_noc+img for img in train]
32 |
33 |
34 | return sorted(l0_train), sorted(l1_train), sorted(flow_train)
35 |
--------------------------------------------------------------------------------
/expansion/dataloader/kitti15list_val.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23 |
24 | train = [i for i in train if int(i.split('_')[0])%5==0]
25 |
26 | l0_train = [filepath+left_fold+img for img in train]
27 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
28 | flow_train = [filepath+flow_noc+img for img in train]
29 |
30 |
31 | return sorted(l0_train), sorted(l1_train), sorted(flow_train)
32 |
--------------------------------------------------------------------------------
/expansion/dataloader/kitti15list_val_lidar.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1]
23 |
24 | # train = [i for i in train if int(i.split('_')[0])%5!=0]
25 | with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f:
26 | flags = [True if len(i)>1 else False for i in f.readlines()]
27 | train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it] ][100:]
28 |
29 | l0_train = [filepath+left_fold+img for img in train]
30 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
31 | flow_train = [filepath+flow_noc+img for img in train]
32 |
33 |
34 | return sorted(l0_train), sorted(l1_train), sorted(flow_train)
35 |
--------------------------------------------------------------------------------
/expansion/dataloader/kitti15list_val_mr.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 |
19 | left_fold = 'image_2/'
20 | flow_noc = 'flow_occ/'
21 |
22 | train = [img for img in os.listdir(filepath+left_fold) if 'Kitti' in img and img.find('_10') > -1]
23 |
24 | # train = [i for i in train if int(i.split('_')[1])%5==0]
25 | # import pdb; pdb.set_trace()  # debugging breakpoint; disabled so the loader runs non-interactively
26 | train = sorted([i for i in train if int(i.split('_')[1])%5==0])[0:1]
27 |
28 | l0_train = [filepath+left_fold+img for img in train]
29 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train]
30 | flow_train = [filepath+flow_noc+img for img in train]
31 |
32 | l0_train += [filepath+left_fold+img.replace('_10','_09') for img in train]
33 | l1_train += [filepath+left_fold+img for img in train]
34 | flow_train += flow_train
35 |
36 | tmp = l0_train
37 | l0_train = l0_train+ [i.replace('rob_flow', 'kitti_scene').replace('Kitti2015_','') for i in l1_train]
38 | l1_train = l1_train+tmp
39 | flow_train += flow_train
40 |
41 | return l0_train, l1_train, flow_train
42 |
--------------------------------------------------------------------------------
/expansion/dataloader/robloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numbers
3 | import torch
4 | import torch.utils.data as data
5 | import torch
6 | import torchvision.transforms as transforms
7 | import random
8 | from PIL import Image, ImageOps
9 | import numpy as np
10 | import torchvision
11 | from . import flow_transforms
12 | import pdb
13 | import cv2
14 | from utils.flowlib import read_flow
15 | from utils.util_flow import readPFM
16 |
17 |
18 | def default_loader(path):
19 | return Image.open(path).convert('RGB')
20 |
21 | def flow_loader(path):
22 | if '.pfm' in path:
23 | data = readPFM(path)[0]
24 | data[:,:,2] = 1
25 | return data
26 | else:
27 | return read_flow(path)
28 |
29 |
30 | def disparity_loader(path):
31 | if '.png' in path:
32 | data = Image.open(path)
33 | data = np.ascontiguousarray(data,dtype=np.float32)/256
34 | return data
35 | else:
36 | return readPFM(path)[0]
37 |
38 | class myImageFloder(data.Dataset):
39 | def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1., cover=False, black=False, scale_aug=[0.4,0.2]):
40 | self.iml0 = iml0
41 | self.iml1 = iml1
42 | self.flowl0 = flowl0
43 | self.loader = loader
44 | self.dploader = dploader
45 | self.scale=scale
46 | self.shape=shape
47 | self.order=order
48 | self.noise = noise
49 | self.pca_augmentor = pca_augmentor
50 | self.prob = prob
51 | self.cover = cover
52 | self.black = black
53 | self.scale_aug = scale_aug
54 |
55 | def __getitem__(self, index):
56 | iml0 = self.iml0[index]
57 | iml1 = self.iml1[index]
58 | flowl0= self.flowl0[index]
59 | th, tw = self.shape
60 |
61 | iml0 = self.loader(iml0)
62 | iml1 = self.loader(iml1)
63 | iml1 = np.asarray(iml1)/255.
64 | iml0 = np.asarray(iml0)/255.
65 | iml0 = iml0[:,:,::-1].copy()
66 | iml1 = iml1[:,:,::-1].copy()
67 | flowl0 = self.dploader(flowl0)
68 | #flowl0[:,:,-1][flowl0[:,:,0]==np.inf]=0 # for gtav window pfm files
69 | #flowl0[:,:,0][~flowl0[:,:,2].astype(bool)]=0
70 | #flowl0[:,:,1][~flowl0[:,:,2].astype(bool)]=0 # avoid nan in grad
71 | flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32)
72 | flowl0[np.isnan(flowl0)] = 1e6 # set to max
73 |
74 | ## following data augmentation procedure in PWCNet
75 | ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu
76 | import __main__ # a workaround for "discount_coeff"
77 | try:
78 | with open('iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f:
79 | iter_counts = int(f.readline())
80 | except:
81 | iter_counts = 0
82 | schedule = [0.5, 1., 50000.] # initial coeff, final_coeff, half life
83 | schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \
84 | (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1)
85 |
86 | if self.pca_augmentor:
87 | pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff)
88 | else:
89 | pca_augmentor = flow_transforms.Scale(1., order=0)
90 |
91 | if np.random.binomial(1,self.prob):
92 | co_transform = flow_transforms.Compose([
93 | flow_transforms.Scale(self.scale, order=self.order),
94 | #flow_transforms.SpatialAug([th,tw], trans=[0.2,0.03], order=self.order, black=self.black),
95 | flow_transforms.SpatialAug([th,tw],scale=[self.scale_aug[0],0.03,self.scale_aug[1]],
96 | rot=[0.4,0.03],
97 | trans=[0.4,0.03],
98 | squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order, black=self.black),
99 | #flow_transforms.pseudoPCAAug(schedule_coeff=schedule_coeff),
100 | flow_transforms.PCAAug(schedule_coeff=schedule_coeff),
101 | flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise),
102 | ])
103 | else:
104 | co_transform = flow_transforms.Compose([
105 | flow_transforms.Scale(self.scale, order=self.order),
106 | flow_transforms.SpatialAug([th,tw], trans=[0.4,0.03], order=self.order, black=self.black)
107 | ])
108 |
109 | augmented,flowl0 = co_transform([iml0, iml1], flowl0)
110 | iml0 = augmented[0]
111 | iml1 = augmented[1]
112 |
113 | if self.cover:
114 | ## randomly cover a region
115 | # following sec. 3.2 of http://openaccess.thecvf.com/content_CVPR_2019/html/Yang_Hierarchical_Deep_Stereo_Matching_on_High-Resolution_Images_CVPR_2019_paper.html
116 | if np.random.binomial(1,0.5):
117 | #sx = int(np.random.uniform(25,100))
118 | #sy = int(np.random.uniform(25,100))
119 | sx = int(np.random.uniform(50,125))
120 | sy = int(np.random.uniform(50,125))
121 | #sx = int(np.random.uniform(50,150))
122 | #sy = int(np.random.uniform(50,150))
123 | cx = int(np.random.uniform(sx,iml1.shape[0]-sx))
124 | cy = int(np.random.uniform(sy,iml1.shape[1]-sy))
125 | iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis]
126 |
127 | iml0 = torch.Tensor(np.transpose(iml0,(2,0,1)))
128 | iml1 = torch.Tensor(np.transpose(iml1,(2,0,1)))
129 |
130 | return iml0, iml1, flowl0
131 |
132 | def __len__(self):
133 | return len(self.iml0)
134 |
--------------------------------------------------------------------------------
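
Both depthloader.py and robloader.py ramp their augmentation strength with the same schedule: a coefficient that moves from an initial to a final value along a saturating curve whose half-life is given in iterations (the constant 1.0986 is ln 3, so exactly half of the gap is covered after `half_life` iterations). A minimal sketch of that formula:

```
import numpy as np

def schedule_coeff(iter_counts, initial=0.5, final=1.0, half_life=50000.0):
    # mirrors the schedule used in depthloader.py / robloader.py
    return initial + (final - initial) * (2.0 / (1.0 + np.exp(-1.0986 * iter_counts / half_life)) - 1.0)

for it in [0, 50000, 200000]:
    print(it, round(float(schedule_coeff(it)), 3))  # 0.5 at start, ~0.75 at the half-life, -> 1.0 later
```
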
/expansion/dataloader/sceneflowlist.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import glob
4 |
5 | def dataloader(filepath, level=6):
6 | iml0 = []
7 | iml1 = []
8 | flowl0 = []
9 | disp0 = []
10 | dispc = []
11 | calib = []
12 | level_stars = '/*'*level
13 | candidate_pool = glob.glob('%s/optical_flow%s'%(filepath,level_stars))
14 | for flow_path in sorted(candidate_pool):
15 | if 'TEST' in flow_path: continue
16 | if 'flower_storm_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path:
17 | continue
18 | if 'flower_storm_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path:
19 | continue
20 | if 'flower_storm_augmented0_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path:
21 | continue
22 | if 'flower_storm_augmented0_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path:
23 | continue
24 | if 'FlyingThings' in flow_path and '_0014_' in flow_path:
25 | continue
26 | if 'FlyingThings' in flow_path and '_0015_' in flow_path:
27 | continue
28 | idd = flow_path.split('/')[-1].split('_')[-2]
29 | if 'into_future' in flow_path:
30 | idd_p1 = '%04d'%(int(idd)+1)
31 | else:
32 | idd_p1 = '%04d'%(int(idd)-1)
33 | if os.path.exists(flow_path.replace(idd,idd_p1)):
34 | d0_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','disparity')
35 | d0_path = '%s/%s.pfm'%(d0_path.rsplit('/',1)[0],idd)
36 | dc_path = flow_path.replace('optical_flow','disparity_change')
37 | dc_path = '%s/%s.pfm'%(dc_path.rsplit('/',1)[0],idd)
38 | im_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','frames_cleanpass')
39 | im0_path = '%s/%s.png'%(im_path.rsplit('/',1)[0],idd)
40 | im1_path = '%s/%s.png'%(im_path.rsplit('/',1)[0],idd_p1)
41 | #with open('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0]),'r') as f:
42 | # if 'FlyingThings' in flow_path and len(f.readlines())!=40:
43 | # print(flow_path)
44 | # continue
45 | iml0.append(im0_path)
46 | iml1.append(im1_path)
47 | flowl0.append(flow_path)
48 | disp0.append(d0_path)
49 | dispc.append(dc_path)
50 | calib.append('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0]))
51 | return iml0, iml1, flowl0, disp0, dispc, calib
52 |
--------------------------------------------------------------------------------
/expansion/dataloader/seqlist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import glob
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | train = [img for img in sorted(glob.glob('%s/*'%filepath))]
21 |
22 | l0_train = train[:-1]
23 | l1_train = train[1:]
24 |
25 |
26 | return sorted(l0_train), sorted(l1_train), sorted(l0_train)
27 |
--------------------------------------------------------------------------------
/expansion/dataloader/sintellist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
22 |
23 | l0_train = [filepath+left_fold+img for img in train]
24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25 |
26 | #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
27 |
28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30 |
31 |
32 | return l0_train, l1_train, flow_train
33 |
--------------------------------------------------------------------------------
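The sintellist_* loaders all rely on the same string expression to map a frame filename to its successor. A worked example (hypothetical filename; not part of the repository):

img = 'Sintel_clean_alley_1_07.png'
nxt = '%s_%s.png' % (img.rsplit('_', 1)[0], '%02d' % (1 + int(img.split('.')[0].split('_')[-1])))
print(nxt)  # Sintel_clean_alley_1_08.png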
/expansion/dataloader/sintellist_clean.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_clean') > -1]
22 |
23 | l0_train = [filepath+left_fold+img for img in train]
24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25 |
26 | #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
27 |
28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30 |
31 | return l0_train, l1_train, flow_train
32 |
--------------------------------------------------------------------------------
/expansion/dataloader/sintellist_final.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_final') > -1]
22 |
23 | l0_train = [filepath+left_fold+img for img in train]
24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25 |
26 | #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val
27 |
28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30 |
31 |     #pdb.set_trace()
32 | return l0_train, l1_train, flow_train
33 |
--------------------------------------------------------------------------------
/expansion/dataloader/sintellist_train.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
22 |
23 | l0_train = [filepath+left_fold+img for img in train]
24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25 |
26 |     l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # drop the held-out *_2 scenes (all except alley_2, bandage_2, sleeping_2), which serve as the validation split
27 |
28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
30 |
31 |
32 | return l0_train, l1_train, flow_train
33 |
--------------------------------------------------------------------------------
/expansion/dataloader/sintellist_val.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 | import pdb
8 |
9 | IMG_EXTENSIONS = [
10 | '.jpg', '.JPG', '.jpeg', '.JPEG',
11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
12 | ]
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
17 |
18 | def dataloader(filepath):
19 |
20 | left_fold = 'image_2/'
21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1]
22 |
23 | l0_train = [filepath+left_fold+img for img in train]
24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ]
25 |
26 |     l0_train = [i for i in l0_train if ('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i)] # keep only the held-out *_2 scenes (except alley_2, bandage_2, sleeping_2) as the validation split
27 | #l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # remove 10 as val
28 |
29 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train]
30 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train]
31 |
32 |
33 | return sorted(l0_train)[::3], sorted(l1_train)[::3], sorted(flow_train)[::3]
34 | # return sorted(l0_train)[::10], sorted(l1_train)[::10], sorted(flow_train)[::10]
35 |
--------------------------------------------------------------------------------
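sintellist_train.py and sintellist_val.py split the same frame list with complementary predicates; a sketch of the shared predicate (the helper name is hypothetical; not part of the repository):

def is_val_frame(path):
    # *_2 scenes are held out for validation, except alley_2, bandage_2 and sleeping_2
    return ('_2_' in path) and ('alley' not in path) and ('bandage' not in path) and ('sleeping' not in path)

# train keeps frames where not is_val_frame(p); val keeps is_val_frame(p) and then takes every 3rd frame.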
/expansion/dataloader/thingslist.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 |
3 | from PIL import Image
4 | import os
5 | import os.path
6 | import numpy as np
7 |
8 | IMG_EXTENSIONS = [
9 | '.jpg', '.JPG', '.jpeg', '.JPEG',
10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
16 |
17 | def dataloader(filepath):
18 | exc_list = [
19 | '0004117.flo',
20 | '0003149.flo',
21 | '0001203.flo',
22 | '0003147.flo',
23 | '0003666.flo',
24 | '0006337.flo',
25 | '0006336.flo',
26 | '0007126.flo',
27 | '0004118.flo',
28 | ]
29 |
30 | left_fold = 'image_clean/left/'
31 | flow_noc = 'flow/left/into_future/'
32 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
33 |
34 | l0_trainlf = [filepath+left_fold+img.replace('flo','png') for img in train]
35 | l1_trainlf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlf]
36 | flow_trainlf = [filepath+flow_noc+img for img in train]
37 |
38 |
39 | exc_list = [
40 | '0003148.flo',
41 | '0004117.flo',
42 | '0002890.flo',
43 | '0003149.flo',
44 | '0001203.flo',
45 | '0003666.flo',
46 | '0006337.flo',
47 | '0006336.flo',
48 | '0004118.flo',
49 | ]
50 |
51 | left_fold = 'image_clean/right/'
52 | flow_noc = 'flow/right/into_future/'
53 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
54 |
55 | l0_trainrf = [filepath+left_fold+img.replace('flo','png') for img in train]
56 | l1_trainrf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrf]
57 | flow_trainrf = [filepath+flow_noc+img for img in train]
58 |
59 |
60 | exc_list = [
61 | '0004237.flo',
62 | '0004705.flo',
63 | '0004045.flo',
64 | '0004346.flo',
65 | '0000161.flo',
66 | '0000931.flo',
67 | '0000121.flo',
68 | '0010822.flo',
69 | '0004117.flo',
70 | '0006023.flo',
71 | '0005034.flo',
72 | '0005054.flo',
73 | '0000162.flo',
74 | '0000053.flo',
75 | '0005055.flo',
76 | '0003147.flo',
77 | '0004876.flo',
78 | '0000163.flo',
79 | '0006878.flo',
80 | ]
81 |
82 | left_fold = 'image_clean/left/'
83 | flow_noc = 'flow/left/into_past/'
84 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
85 |
86 | l0_trainlp = [filepath+left_fold+img.replace('flo','png') for img in train]
87 | l1_trainlp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlp]
88 | flow_trainlp = [filepath+flow_noc+img for img in train]
89 |
90 | exc_list = [
91 | '0003148.flo',
92 | '0004705.flo',
93 | '0000161.flo',
94 | '0000121.flo',
95 | '0004117.flo',
96 | '0000160.flo',
97 | '0005034.flo',
98 | '0005054.flo',
99 | '0000162.flo',
100 | '0000053.flo',
101 | '0005055.flo',
102 | '0003147.flo',
103 | '0001549.flo',
104 | '0000163.flo',
105 | '0006336.flo',
106 | '0001648.flo',
107 | '0006878.flo',
108 | ]
109 |
110 | left_fold = 'image_clean/right/'
111 | flow_noc = 'flow/right/into_past/'
112 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0]
113 |
114 | l0_trainrp = [filepath+left_fold+img.replace('flo','png') for img in train]
115 | l1_trainrp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrp]
116 | flow_trainrp = [filepath+flow_noc+img for img in train]
117 |
118 |
119 | l0_train = l0_trainlf + l0_trainrf + l0_trainlp + l0_trainrp
120 | l1_train = l1_trainlf + l1_trainrf + l1_trainlp + l1_trainrp
121 | flow_train = flow_trainlf + flow_trainrf + flow_trainlp + flow_trainrp
122 | return l0_train, l1_train, flow_train
123 |
--------------------------------------------------------------------------------
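thingslist.py builds four FlyingThings splits (left/right camera, into_future/into_past) and pairs each flow file with the neighbouring frame via a zero-padded index. A worked example of that expression (hypothetical path; not part of the repository):

img = '/data/FlyingThings3D/image_clean/left/0006000.png'
nxt = '%s/%s.png' % (img.rsplit('/', 1)[0], '%07d' % (1 + int(img.split('.')[0].split('/')[-1])))
print(nxt)  # /data/FlyingThings3D/image_clean/left/0006001.png (into_past uses -1 instead of +1)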
/expansion/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/models/__init__.py
--------------------------------------------------------------------------------
/expansion/models/__pycache__/VCN_exp.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/models/__pycache__/VCN_exp.cpython-38.pyc
--------------------------------------------------------------------------------
/expansion/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/expansion/models/__pycache__/conv4d.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/models/__pycache__/conv4d.cpython-38.pyc
--------------------------------------------------------------------------------
/expansion/models/__pycache__/submodule.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/models/__pycache__/submodule.cpython-38.pyc
--------------------------------------------------------------------------------
/expansion/models/conv4d.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import torch.nn as nn
3 | import math
4 | import torch
5 | from torch.nn.parameter import Parameter
6 | import torch.nn.functional as F
7 | from torch.nn import Module
8 | from torch.nn.modules.conv import _ConvNd
9 | from torch.nn.modules.utils import _quadruple
10 | from torch.autograd import Variable
11 | from torch.nn import Conv2d
12 |
13 | def conv4d(data,filters,bias=None,permute_filters=True,use_half=False):
14 | """
15 | This is done by stacking results of multiple 3D convolutions, and is very slow.
16 | Taken from https://github.com/ignacio-rocco/ncnet
17 | """
18 | b,c,h,w,d,t=data.size()
19 |
20 | data=data.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop
21 |
22 | # Same permutation is done with filters, unless already provided with permutation
23 | if permute_filters:
24 | filters=filters.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop
25 |
26 | c_out=filters.size(1)
27 | if use_half:
28 | output = Variable(torch.HalfTensor(h,b,c_out,w,d,t),requires_grad=data.requires_grad)
29 | else:
30 | output = Variable(torch.zeros(h,b,c_out,w,d,t),requires_grad=data.requires_grad)
31 |
32 | padding=filters.size(0)//2
33 | if use_half:
34 | Z=Variable(torch.zeros(padding,b,c,w,d,t).half())
35 | else:
36 | Z=Variable(torch.zeros(padding,b,c,w,d,t))
37 |
38 | if data.is_cuda:
39 | Z=Z.cuda(data.get_device())
40 | output=output.cuda(data.get_device())
41 |
42 | data_padded = torch.cat((Z,data,Z),0)
43 |
44 |
45 | for i in range(output.size(0)): # loop on first feature dimension
46 | # convolve with center channel of filter (at position=padding)
47 | output[i,:,:,:,:,:]=F.conv3d(data_padded[i+padding,:,:,:,:,:],
48 | filters[padding,:,:,:,:,:], bias=bias, stride=1, padding=padding)
49 | # convolve with upper/lower channels of filter (at postions [:padding] [padding+1:])
50 | for p in range(1,padding+1):
51 | output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding-p,:,:,:,:,:],
52 | filters[padding-p,:,:,:,:,:], bias=None, stride=1, padding=padding)
53 | output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding+p,:,:,:,:,:],
54 | filters[padding+p,:,:,:,:,:], bias=None, stride=1, padding=padding)
55 |
56 | output=output.permute(1,2,0,3,4,5).contiguous()
57 | return output
58 |
59 | class Conv4d(_ConvNd):
60 | """Applies a 4D convolution over an input signal composed of several input
61 | planes.
62 | """
63 |
64 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
65 | # stride, dilation and groups !=1 functionality not tested
66 | stride=1
67 | dilation=1
68 | groups=1
69 | # zero padding is added automatically in conv4d function to preserve tensor size
70 | padding = 0
71 | kernel_size = _quadruple(kernel_size)
72 | stride = _quadruple(stride)
73 | padding = _quadruple(padding)
74 | dilation = _quadruple(dilation)
75 | super(Conv4d, self).__init__(
76 | in_channels, out_channels, kernel_size, stride, padding, dilation,
77 | False, _quadruple(0), groups, bias)
78 | # weights will be sliced along one dimension during convolution loop
79 | # make the looping dimension to be the first one in the tensor,
80 | # so that we don't need to call contiguous() inside the loop
81 | self.pre_permuted_filters=pre_permuted_filters
82 | if self.pre_permuted_filters:
83 | self.weight.data=self.weight.data.permute(2,0,1,3,4,5).contiguous()
84 | self.use_half=False
85 | # self.isbias = bias
86 | # if not self.isbias:
87 | # self.bn = torch.nn.BatchNorm1d(out_channels)
88 |
89 |
90 | def forward(self, input):
91 | out = conv4d(input, self.weight, bias=self.bias,permute_filters=not self.pre_permuted_filters,use_half=self.use_half) # filters pre-permuted in constructor
92 | # if not self.isbias:
93 | # b,c,u,v,h,w = out.shape
94 | # out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w)
95 | return out
96 |
97 | class fullConv4d(torch.nn.Module):
98 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True):
99 | super(fullConv4d, self).__init__()
100 | self.conv = Conv4d(in_channels, out_channels, kernel_size, bias=bias, pre_permuted_filters=pre_permuted_filters)
101 | self.isbias = bias
102 | if not self.isbias:
103 | self.bn = torch.nn.BatchNorm1d(out_channels)
104 |
105 | def forward(self, input):
106 | out = self.conv(input)
107 | if not self.isbias:
108 | b,c,u,v,h,w = out.shape
109 | out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w)
110 | return out
111 |
112 | class butterfly4D(torch.nn.Module):
113 | '''
114 | butterfly 4d
115 | '''
116 | def __init__(self, fdima, fdimb, withbn=True, full=True,groups=1):
117 | super(butterfly4D, self).__init__()
118 | self.proj = nn.Sequential(projfeat4d(fdima, fdimb, 1, with_bn=withbn,groups=groups),
119 | nn.ReLU(inplace=True),)
120 | self.conva1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups)
121 | self.conva2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups)
122 | self.convb3 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
123 | self.convb2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
124 | self.convb1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups)
125 |
126 | #@profile
127 | def forward(self,x):
128 | out = self.proj(x)
129 | b,c,u,v,h,w = out.shape # 9x9
130 |
131 | out1 = self.conva1(out) # 5x5, 3
132 | _,c1,u1,v1,h1,w1 = out1.shape
133 |
134 | out2 = self.conva2(out1) # 3x3, 9
135 | _,c2,u2,v2,h2,w2 = out2.shape
136 |
137 | out2 = self.convb3(out2) # 3x3, 9
138 |
139 | tout1 = F.upsample(out2.view(b,c,u2,v2,-1),(u1,v1,h2*w2),mode='trilinear').view(b,c,u1,v1,h2,w2) # 5x5
140 | tout1 = F.upsample(tout1.view(b,c,-1,h2,w2),(u1*v1,h1,w1),mode='trilinear').view(b,c,u1,v1,h1,w1) # 5x5
141 | out1 = tout1 + out1
142 | out1 = self.convb2(out1)
143 |
144 | tout = F.upsample(out1.view(b,c,u1,v1,-1),(u,v,h1*w1),mode='trilinear').view(b,c,u,v,h1,w1)
145 | tout = F.upsample(tout.view(b,c,-1,h1,w1),(u*v,h,w),mode='trilinear').view(b,c,u,v,h,w)
146 | out = tout + out
147 | out = self.convb1(out)
148 |
149 | return out
150 |
151 |
152 |
153 | class projfeat4d(torch.nn.Module):
154 | '''
155 | Turn 3d projection into 2d projection
156 | '''
157 | def __init__(self, in_planes, out_planes, stride, with_bn=True,groups=1):
158 | super(projfeat4d, self).__init__()
159 | self.with_bn = with_bn
160 | self.stride = stride
161 | self.conv1 = nn.Conv3d(in_planes, out_planes, 1, (stride,stride,1), padding=0,bias=not with_bn,groups=groups)
162 | self.bn = nn.BatchNorm3d(out_planes)
163 |
164 | def forward(self,x):
165 | b,c,u,v,h,w = x.size()
166 | x = self.conv1(x.view(b,c,u,v,h*w))
167 | if self.with_bn:
168 | x = self.bn(x)
169 | _,c,u,v,_ = x.shape
170 | x = x.view(b,c,u,v,h,w)
171 | return x
172 |
173 | class sepConv4d(torch.nn.Module):
174 | '''
175 | Separable 4d convolution block as 2 3D convolutions
176 | '''
177 | def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, ksize=3, full=True,groups=1):
178 | super(sepConv4d, self).__init__()
179 | bias = not with_bn
180 | self.isproj = False
181 | self.stride = stride[0]
182 | expand = 1
183 |
184 | if with_bn:
185 | if in_planes != out_planes:
186 | self.isproj = True
187 | self.proj = nn.Sequential(nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups),
188 | nn.BatchNorm2d(out_planes))
189 | if full:
190 | self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups),
191 | nn.BatchNorm3d(in_planes))
192 | else:
193 | self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups),
194 | nn.BatchNorm3d(in_planes))
195 | self.conv2 = nn.Sequential(nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups),
196 | nn.BatchNorm3d(in_planes*expand))
197 | else:
198 | if in_planes != out_planes:
199 | self.isproj = True
200 | self.proj = nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups)
201 | if full:
202 | self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups)
203 | else:
204 | self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups)
205 | self.conv2 = nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups)
206 | self.relu = nn.ReLU(inplace=True)
207 |
208 | #@profile
209 | def forward(self,x):
210 | b,c,u,v,h,w = x.shape
211 | x = self.conv2(x.view(b,c,u,v,-1))
212 | b,c,u,v,_ = x.shape
213 | x = self.relu(x)
214 | x = self.conv1(x.view(b,c,-1,h,w))
215 | b,c,_,h,w = x.shape
216 |
217 | if self.isproj:
218 | x = self.proj(x.view(b,c,-1,w))
219 | x = x.view(b,-1,u,v,h,w)
220 | return x
221 |
222 |
223 | class sepConv4dBlock(torch.nn.Module):
224 | '''
225 |     Residual block of two separable 4D convolutions (each a pair of 3D convolutions)
226 |     with an optional projection/downsampling shortcut
227 | '''
228 | def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, full=True,groups=1):
229 | super(sepConv4dBlock, self).__init__()
230 | if in_planes == out_planes and stride==(1,1,1):
231 | self.downsample = None
232 | else:
233 | if full:
234 | self.downsample = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn,ksize=1, full=full,groups=groups)
235 | else:
236 | self.downsample = projfeat4d(in_planes, out_planes,stride[0], with_bn=with_bn,groups=groups)
237 | self.conv1 = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn, full=full ,groups=groups)
238 | self.conv2 = sepConv4d(out_planes, out_planes,(1,1,1), with_bn=with_bn, full=full,groups=groups)
239 | self.relu1 = nn.ReLU(inplace=True)
240 | self.relu2 = nn.ReLU(inplace=True)
241 |
242 | #@profile
243 | def forward(self,x):
244 | out = self.relu1(self.conv1(x))
245 | if self.downsample:
246 | x = self.downsample(x)
247 | out = self.relu2(x + self.conv2(out))
248 | return out
249 |
250 |
251 | ##import torch.backends.cudnn as cudnn
252 | ##cudnn.benchmark = True
253 | #import time
254 | ##im = torch.randn(9,64,9,160,224).cuda()
255 | ##net = torch.nn.Conv3d(64, 64, 3).cuda()
256 | ##net = Conv4d(1,1,3,bias=True,pre_permuted_filters=True).cuda()
257 | ##net = sepConv4dBlock(2,2,stride=(1,1,1)).cuda()
258 | #
259 | ##im = torch.randn(1,16,9,9,96,320).cuda()
260 | ##net = sepConv4d(16,16,with_bn=False).cuda()
261 | #
262 | ##im = torch.randn(1,16,81,96,320).cuda()
263 | ##net = torch.nn.Conv3d(16,16,(1,3,3),padding=(0,1,1)).cuda()
264 | #
265 | ##im = torch.randn(1,16,9,9,96*320).cuda()
266 | ##net = torch.nn.Conv3d(16,16,(3,3,1),padding=(1,1,0)).cuda()
267 | #
268 | ##im = torch.randn(10000,10,9,9).cuda()
269 | ##net = torch.nn.Conv2d(10,10,3,padding=1).cuda()
270 | #
271 | ##im = torch.randn(81,16,96,320).cuda()
272 | ##net = torch.nn.Conv2d(16,16,3,padding=1).cuda()
273 | #c= int(16 *1)
274 | #cp = int(16 *1)
275 | #h=int(96 *4)
276 | #w=int(320 *4)
277 | #k=3
278 | #im = torch.randn(1,c,h,w).cuda()
279 | #net = torch.nn.Conv2d(c,cp,k,padding=k//2).cuda()
280 | #
281 | #im2 = torch.randn(cp,k*k*c).cuda()
282 | #im1 = F.unfold(im, (k,k), padding=k//2)[0]
283 | #
284 | #
285 | #net(im)
286 | #net(im)
287 | #torch.mm(im2,im1)
288 | #torch.mm(im2,im1)
289 | #torch.cuda.synchronize()
290 | #beg = time.time()
291 | #for i in range(100):
292 | # net(im)
293 | # #im1 = F.unfold(im, (k,k), padding=k//2)[0]
294 | # torch.mm(im2,im1)
295 | #torch.cuda.synchronize()
296 | #print('%f'%((time.time()-beg)*10.))
297 |
--------------------------------------------------------------------------------
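A small shape-check sketch for the separable 4D blocks in conv4d.py above (CPU, tiny sizes, BatchNorm disabled; not part of the repository):

import torch

vol = torch.randn(1, 4, 9, 9, 8, 8)   # (batch, feat, u, v, h, w) correlation volume
net = sepConv4dBlock(4, 4, stride=(1, 1, 1), with_bn=False, full=True)
out = net(vol)
print(out.shape)                       # torch.Size([1, 4, 9, 9, 8, 8])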
/expansion/submission.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import sys
3 | import cv2
4 | import argparse
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.backends.cudnn as cudnn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | cudnn.benchmark = False
12 |
13 | class Expansion():
14 |
15 | def __init__(self, loadmodel = 'pretrained_models/optical_expansion/robust.pth', testres = 1, maxdisp = 256, fac = 1):
16 |
17 | maxw,maxh = [int(testres*1280), int(testres*384)]
18 |
19 | max_h = int(maxh // 64 * 64)
20 | max_w = int(maxw // 64 * 64)
21 | if max_h < maxh: max_h += 64
22 | if max_w < maxw: max_w += 64
23 | maxh = max_h
24 | maxw = max_w
25 |
26 | mean_L = [[0.33,0.33,0.33]]
27 | mean_R = [[0.33,0.33,0.33]]
28 |
29 | # construct model, VCN-expansion
30 | from expansion.models.VCN_exp import VCN
31 | model = VCN([1, maxw, maxh], md=[int(4*(maxdisp/256)),4,4,4,4], fac=fac,
32 | exp_unc=('robust' in loadmodel)) # expansion uncertainty only in the new model
33 | model = nn.DataParallel(model, device_ids=[0])
34 | model.cuda()
35 |
36 | if loadmodel is not None:
37 | pretrained_dict = torch.load(loadmodel)
38 | mean_L=pretrained_dict['mean_L']
39 | mean_R=pretrained_dict['mean_R']
40 | pretrained_dict['state_dict'] = {k:v for k,v in pretrained_dict['state_dict'].items()}
41 | model.load_state_dict(pretrained_dict['state_dict'],strict=False)
42 | else:
43 | print('dry run')
44 |
45 | model.eval()
46 | # resize
47 | maxh = 256
48 | maxw = 256
49 | max_h = int(maxh // 64 * 64)
50 | max_w = int(maxw // 64 * 64)
51 | if max_h < maxh: max_h += 64
52 | if max_w < maxw: max_w += 64
53 |
54 | # modify module according to inputs
55 | from expansion.models.VCN_exp import WarpModule, flow_reg
56 | for i in range(len(model.module.reg_modules)):
57 | model.module.reg_modules[i] = flow_reg([1,max_w//(2**(6-i)), max_h//(2**(6-i))],
58 | ent=getattr(model.module, 'flow_reg%d'%2**(6-i)).ent,\
59 | maxdisp=getattr(model.module, 'flow_reg%d'%2**(6-i)).md,\
60 | fac=getattr(model.module, 'flow_reg%d'%2**(6-i)).fac).cuda()
61 | for i in range(len(model.module.warp_modules)):
62 | model.module.warp_modules[i] = WarpModule([1,max_w//(2**(6-i)), max_h//(2**(6-i))]).cuda()
63 |
64 | mean_L = torch.from_numpy(np.asarray(mean_L).astype(np.float32).mean(0)[np.newaxis,:,np.newaxis,np.newaxis]).cuda()
65 | mean_R = torch.from_numpy(np.asarray(mean_R).astype(np.float32).mean(0)[np.newaxis,:,np.newaxis,np.newaxis]).cuda()
66 |
67 | self.max_h = max_h
68 | self.max_w = max_w
69 | self.model = model
70 | self.mean_L = mean_L
71 | self.mean_R = mean_R
72 |
73 | def run(self, imgL_o, imgR_o):
74 | model = self.model
75 | mean_L = self.mean_L
76 | mean_R = self.mean_R
77 |
78 | imgL_o[imgL_o<-1] = -1
79 | imgL_o[imgL_o>1] = 1
80 | imgR_o[imgR_o<-1] = -1
81 | imgR_o[imgR_o>1] = 1
82 | imgL = (imgL_o+1.)*0.5-mean_L
83 |         imgR = (imgR_o+1.)*0.5-mean_R # map from [-1,1] to [0,1] before mean subtraction, matching imgL above
84 |
85 | with torch.no_grad():
86 | imgLR = torch.cat([imgL,imgR],0)
87 | model.eval()
88 | torch.cuda.synchronize()
89 | rts = model(imgLR)
90 | torch.cuda.synchronize()
91 | flow, occ, logmid, logexp = rts
92 |
93 | torch.cuda.empty_cache()
94 |
95 | return flow, logexp
96 |
--------------------------------------------------------------------------------
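A hedged usage sketch for the Expansion wrapper above. It requires a CUDA device and the pretrained checkpoint passed to the constructor; the input names are hypothetical and the tensors are assumed to be 1x3x256x256 CUDA tensors in [-1, 1], matching the resize block in __init__ (not part of the repository):

exp = Expansion(loadmodel='pretrained_models/optical_expansion/robust.pth')
flow, logexp = exp.run(img0, img1)   # img0, img1: torch.cuda.FloatTensor, 1x3x256x256, values in [-1, 1]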
/expansion/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__init__.py
--------------------------------------------------------------------------------
/expansion/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/expansion/utils/__pycache__/flowlib.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__pycache__/flowlib.cpython-38.pyc
--------------------------------------------------------------------------------
/expansion/utils/__pycache__/io.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__pycache__/io.cpython-38.pyc
--------------------------------------------------------------------------------
/expansion/utils/__pycache__/pfm.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__pycache__/pfm.cpython-38.pyc
--------------------------------------------------------------------------------
/expansion/utils/__pycache__/util_flow.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__pycache__/util_flow.cpython-38.pyc
--------------------------------------------------------------------------------
/expansion/utils/io.py:
--------------------------------------------------------------------------------
1 | import errno
2 | import os
3 | import shutil
4 | import sys
5 | import traceback
6 | import zipfile
7 |
8 | if sys.version_info[0] == 2:
9 | import urllib2
10 | else:
11 | import urllib.request
12 |
13 |
14 | # Converts a string to bytes (for writing the string into a file). Provided for
15 | # compatibility with Python 2 and 3.
16 | def StrToBytes(text):
17 | if sys.version_info[0] == 2:
18 | return text
19 | else:
20 | return bytes(text, 'UTF-8')
21 |
22 |
23 | # Outputs the given text and lets the user input a response (submitted by
24 | # pressing the return key). Provided for compatibility with Python 2 and 3.
25 | def GetUserInput(text):
26 | if sys.version_info[0] == 2:
27 | return raw_input(text)
28 | else:
29 | return input(text)
30 |
31 |
32 | # Creates the given directory (hierarchy), which may already exist. Provided for
33 | # compatibility with Python 2 and 3.
34 | def MakeDirsExistOk(directory_path):
35 | try:
36 | os.makedirs(directory_path)
37 | except OSError as exception:
38 | if exception.errno != errno.EEXIST:
39 | raise
40 |
41 |
42 | # Deletes all files and folders within the given folder.
43 | def DeleteFolderContents(folder_path):
44 | for file_name in os.listdir(folder_path):
45 | file_path = os.path.join(folder_path, file_name)
46 | try:
47 | if os.path.isfile(file_path):
48 | os.unlink(file_path)
49 | else: #if os.path.isdir(file_path):
50 | shutil.rmtree(file_path)
51 | except Exception as e:
52 | print('Exception in DeleteFolderContents():')
53 | print(e)
54 | print('Stack trace:')
55 | print(traceback.format_exc())
56 |
57 |
58 | # Creates the given directory, respectively deletes all content of the directory
59 | # in case it already exists.
60 | def MakeCleanDirectory(folder_path):
61 | if os.path.isdir(folder_path):
62 | DeleteFolderContents(folder_path)
63 | else:
64 | MakeDirsExistOk(folder_path)
65 |
66 |
67 | # Downloads the given URL to a file in the given directory. Returns the
68 | # path to the downloaded file.
69 | # In part adapted from: https://stackoverflow.com/questions/22676
70 | def DownloadFile(url, dest_dir_path):
71 | file_name = url.split('/')[-1]
72 | dest_file_path = os.path.join(dest_dir_path, file_name)
73 |
74 | if os.path.isfile(dest_file_path):
75 | print('The following file already exists:')
76 | print(dest_file_path)
77 | print('Please choose whether to re-download and overwrite the file [o] or to skip downloading this file [s] by entering o or s.')
78 | while True:
79 | response = GetUserInput("> ")
80 | if response == 's':
81 | return dest_file_path
82 | elif response == 'o':
83 | break
84 | else:
85 | print('Please enter o or s.')
86 |
87 | url_object = None
88 | if sys.version_info[0] == 2:
89 | url_object = urllib2.urlopen(url)
90 | else:
91 | url_object = urllib.request.urlopen(url)
92 |
93 | with open(dest_file_path, 'wb') as outfile:
94 | meta = url_object.info()
95 | file_size = 0
96 | if sys.version_info[0] == 2:
97 | file_size = int(meta.getheaders("Content-Length")[0])
98 | else:
99 | file_size = int(meta["Content-Length"])
100 | print("Downloading: %s (size [bytes]: %s)" % (url, file_size))
101 |
102 | file_size_downloaded = 0
103 | block_size = 8192
104 | while True:
105 | buffer = url_object.read(block_size)
106 | if not buffer:
107 | break
108 |
109 | file_size_downloaded += len(buffer)
110 | outfile.write(buffer)
111 |
112 | sys.stdout.write("%d / %d (%3f%%)\r" % (file_size_downloaded, file_size, file_size_downloaded * 100. / file_size))
113 | sys.stdout.flush()
114 |
115 | return dest_file_path
116 |
117 |
118 | # Unzips the given zip file into the given directory.
119 | def UnzipFile(file_path, unzip_dir_path, overwrite=True):
120 | zip_ref = zipfile.ZipFile(open(file_path, 'rb'))
121 |
122 | if not overwrite:
123 | for f in zip_ref.namelist():
124 | if not os.path.isfile(os.path.join(unzip_dir_path, f)):
125 | zip_ref.extract(f, path=unzip_dir_path)
126 | else:
127 | print('Not overwriting {}'.format(f))
128 | else:
129 | zip_ref.extractall(unzip_dir_path)
130 | zip_ref.close()
131 |
132 |
133 | # Creates a zip file with the contents of the given directory.
134 | # The archive_base_path must not include the extension .zip. The full, final
135 | # path of the archive is returned by the function.
136 | def ZipDirectory(archive_base_path, root_dir_path):
137 | # return shutil.make_archive(archive_base_path, 'zip', root_dir_path) # THIS WILL ALWAYS HAVE ./ FOLDER INCLUDED
138 | with zipfile.ZipFile(archive_base_path+'.zip', "w", compression=zipfile.ZIP_DEFLATED) as zf:
139 | base_path = os.path.normpath(root_dir_path)
140 | for dirpath, dirnames, filenames in os.walk(root_dir_path):
141 | for name in sorted(dirnames):
142 | path = os.path.normpath(os.path.join(dirpath, name))
143 | zf.write(path, os.path.relpath(path, base_path))
144 | for name in filenames:
145 | path = os.path.normpath(os.path.join(dirpath, name))
146 | if os.path.isfile(path):
147 | zf.write(path, os.path.relpath(path, base_path))
148 |
149 | return archive_base_path+'.zip'
150 |
151 |
152 | # Downloads a zip file and directly unzips it.
153 | def DownloadAndUnzipFile(url, archive_dir_path, unzip_dir_path, overwrite=True):
154 | archive_path = DownloadFile(url, archive_dir_path)
155 | UnzipFile(archive_path, unzip_dir_path, overwrite=overwrite)
156 |
157 | def mkdir_p(path):
158 | try:
159 | os.makedirs(path)
160 | except OSError as exc: # Python >2.5
161 | if exc.errno == errno.EEXIST and os.path.isdir(path):
162 | pass
163 | else:
164 | raise
165 |
--------------------------------------------------------------------------------
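A minimal usage sketch for the download helpers in io.py above (URL and paths are placeholders; not part of the repository):

mkdir_p('downloads')
mkdir_p('data')
DownloadAndUnzipFile('https://example.com/dataset.zip', 'downloads', 'data', overwrite=False)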
/expansion/utils/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | File: logger.py
3 | Modified by: Senthil Purushwalkam
4 | Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514
5 | Email: spurushw<at>andrew<dot>cmu<dot>edu
6 | Github: https://github.com/senthilps8
7 | Description:
8 | """
9 | import torch  # needed by to_var below
10 | import tensorflow as tf
11 | from torch.autograd import Variable
12 | import numpy as np
13 | import scipy.misc
14 | import os
15 | try:
16 | from StringIO import StringIO # Python 2.7
17 | except ImportError:
18 | from io import BytesIO # Python 3.x
19 |
20 |
21 | class Logger(object):
22 |
23 | def __init__(self, log_dir, name=None):
24 | """Create a summary writer logging to log_dir."""
25 | if name is None:
26 | name = 'temp'
27 | self.name = name
28 | if name is not None:
29 | try:
30 | os.makedirs(os.path.join(log_dir, name))
31 | except:
32 | pass
33 | self.writer = tf.summary.FileWriter(os.path.join(log_dir, name),
34 | filename_suffix=name)
35 | else:
36 | self.writer = tf.summary.FileWriter(log_dir, filename_suffix=name)
37 |
38 | def scalar_summary(self, tag, value, step):
39 | """Log a scalar variable."""
40 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
41 | self.writer.add_summary(summary, step)
42 |
43 | def image_summary(self, tag, images, step):
44 | """Log a list of images."""
45 |
46 | img_summaries = []
47 | for i, img in enumerate(images):
48 | # Write the image to a string
49 | try:
50 | s = StringIO()
51 | except:
52 | s = BytesIO()
53 | scipy.misc.toimage(img).save(s, format="png")
54 |
55 | # Create an Image object
56 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
57 | height=img.shape[0],
58 | width=img.shape[1])
59 | # Create a Summary value
60 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))
61 |
62 | # Create and write Summary
63 | summary = tf.Summary(value=img_summaries)
64 | self.writer.add_summary(summary, step)
65 |
66 | def histo_summary(self, tag, values, step, bins=1000):
67 | """Log a histogram of the tensor of values."""
68 |
69 | # Create a histogram using numpy
70 | counts, bin_edges = np.histogram(values, bins=bins)
71 |
72 | # Fill the fields of the histogram proto
73 | hist = tf.HistogramProto()
74 | hist.min = float(np.min(values))
75 | hist.max = float(np.max(values))
76 | hist.num = int(np.prod(values.shape))
77 | hist.sum = float(np.sum(values))
78 | hist.sum_squares = float(np.sum(values**2))
79 |
80 | # Drop the start of the first bin
81 | bin_edges = bin_edges[1:]
82 |
83 | # Add bin edges and counts
84 | for edge in bin_edges:
85 | hist.bucket_limit.append(edge)
86 | for c in counts:
87 | hist.bucket.append(c)
88 |
89 | # Create and write Summary
90 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
91 | self.writer.add_summary(summary, step)
92 | self.writer.flush()
93 |
94 | def to_np(self, x):
95 | return x.data.cpu().numpy()
96 |
97 | def to_var(self, x):
98 | if torch.cuda.is_available():
99 | x = x.cuda()
100 | return Variable(x)
101 |
102 | def model_param_histo_summary(self, model, step):
103 | """log histogram summary of model's parameters
104 | and parameter gradients
105 | """
106 | for tag, value in model.named_parameters():
107 | if value.grad is None:
108 | continue
109 | tag = tag.replace('.', '/')
110 | tag = self.name+'/'+tag
111 | self.histo_summary(tag, self.to_np(value), step)
112 | self.histo_summary(tag+'/grad', self.to_np(value.grad), step)
113 |
114 |
--------------------------------------------------------------------------------
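A minimal usage sketch for the TensorBoard Logger above. It assumes TensorFlow 1.x, whose tf.summary.FileWriter API the class relies on; the log directory and tag are placeholders (not part of the repository):

logger = Logger('logs', name='exp1')
for step in range(100):
    loss = 1.0 / (step + 1)   # stand-in for a real training loss
    logger.scalar_summary('train/loss', loss, step)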
/expansion/utils/multiscaleloss.py:
--------------------------------------------------------------------------------
1 | """
2 | Taken from https://github.com/ClementPinard/FlowNetPytorch
3 | """
4 | import pdb
5 | import torch
6 | import torch.nn.functional as F
7 |
8 |
9 | def EPE(input_flow, target_flow, mask, sparse=False, mean=True):
10 | #mask = target_flow[:,2]>0
11 | target_flow = target_flow[:,:2]
12 | EPE_map = torch.norm(target_flow-input_flow,2,1)
13 | batch_size = EPE_map.size(0)
14 | if sparse:
15 | # invalid flow is defined with both flow coordinates to be exactly 0
16 | mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0)
17 |
18 | EPE_map = EPE_map[~mask]
19 | if mean:
20 | return EPE_map[mask].mean()
21 | else:
22 | return EPE_map[mask].sum()/batch_size
23 |
24 | def rob_EPE(input_flow, target_flow, mask, sparse=False, mean=True):
25 | #mask = target_flow[:,2]>0
26 | target_flow = target_flow[:,:2]
27 | #TODO
28 | # EPE_map = torch.norm(target_flow-input_flow,2,1)
29 | EPE_map = (torch.norm(target_flow-input_flow,1,1)+0.01).pow(0.4)
30 | batch_size = EPE_map.size(0)
31 | if sparse:
32 | # invalid flow is defined with both flow coordinates to be exactly 0
33 | mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0)
34 |
35 | EPE_map = EPE_map[~mask]
36 | if mean:
37 | return EPE_map[mask].mean()
38 | else:
39 | return EPE_map[mask].sum()/batch_size
40 |
41 | def sparse_max_pool(input, size):
42 | '''Downsample the input by considering 0 values as invalid.
43 |
44 | Unfortunately, no generic interpolation mode can resize a sparse map correctly,
45 | the strategy here is to use max pooling for positive values and "min pooling"
46 | for negative values, the two results are then summed.
47 |     This technique allows sparsity to be minimized, contrary to nearest interpolation,
48 | which could potentially lose information for isolated data points.'''
49 |
50 | positive = (input > 0).float()
51 | negative = (input < 0).float()
52 | output = F.adaptive_max_pool2d(input * positive, size) - F.adaptive_max_pool2d(-input * negative, size)
53 | return output
54 |
55 |
56 | def multiscaleEPE(network_output, target_flow, mask, weights=None, sparse=False, rob_loss = False):
57 | def one_scale(output, target, mask, sparse):
58 |
59 | b, _, h, w = output.size()
60 |
61 | if sparse:
62 | target_scaled = sparse_max_pool(target, (h, w))
63 | else:
64 | target_scaled = F.interpolate(target, (h, w), mode='area')
65 | mask = F.interpolate(mask.float().unsqueeze(1), (h, w), mode='bilinear').squeeze(1)==1
66 | if rob_loss:
67 | return rob_EPE(output, target_scaled, mask, sparse, mean=False)
68 | else:
69 | return EPE(output, target_scaled, mask, sparse, mean=False)
70 |
71 | if type(network_output) not in [tuple, list]:
72 | network_output = [network_output]
73 | if weights is None:
74 | weights = [0.005, 0.01, 0.02, 0.08, 0.32] # as in original article
75 | assert(len(weights) == len(network_output))
76 |
77 | loss = 0
78 | for output, weight in zip(network_output, weights):
79 | loss += weight * one_scale(output, target_flow, mask, sparse)
80 | return loss
81 |
82 |
83 | def realEPE(output, target, mask, sparse=False):
84 | b, _, h, w = target.size()
85 | upsampled_output = F.interpolate(output, (h,w), mode='bilinear', align_corners=False)
86 | return EPE(upsampled_output, target,mask, sparse, mean=True)
87 |
--------------------------------------------------------------------------------
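A minimal sketch of how multiscaleEPE above is called, using random tensors as stand-ins for a 5-scale network output and a dense (non-sparse) target (not part of the repository):

import torch

target = torch.randn(2, 3, 64, 64)   # channels: (u, v, valid)
mask = target[:, 2] > 0              # validity mask
outputs = [torch.randn(2, 2, 64 // s, 64 // s) for s in (16, 8, 4, 2, 1)]
loss = multiscaleEPE(outputs, target, mask,
                     weights=[0.005, 0.01, 0.02, 0.08, 0.32], sparse=False)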
/expansion/utils/pfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 |
5 | def readPFM(file):
6 | file = open(file, 'rb')
7 |
8 | color = None
9 | width = None
10 | height = None
11 | scale = None
12 | endian = None
13 |
14 | header = file.readline().rstrip()
15 | if (sys.version[0]) == '3':
16 | header = header.decode('utf-8')
17 | if header == 'PF':
18 | color = True
19 | elif header == 'Pf':
20 | color = False
21 | else:
22 | raise Exception('Not a PFM file.')
23 |
24 | if (sys.version[0]) == '3':
25 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
26 | else:
27 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
28 | if dim_match:
29 | width, height = map(int, dim_match.groups())
30 | else:
31 | raise Exception('Malformed PFM header.')
32 |
33 | if (sys.version[0]) == '3':
34 | scale = float(file.readline().rstrip().decode('utf-8'))
35 | else:
36 | scale = float(file.readline().rstrip())
37 |
38 | if scale < 0: # little-endian
39 | endian = '<'
40 | scale = -scale
41 | else:
42 | endian = '>' # big-endian
43 |
44 | data = np.fromfile(file, endian + 'f')
45 | shape = (height, width, 3) if color else (height, width)
46 |
47 | data = np.reshape(data, shape)
48 | data = np.flipud(data)
49 | return data, scale
50 |
51 |
52 | def writePFM(file, image, scale=1):
53 | file = open(file, 'wb')
54 |
55 | color = None
56 |
57 | if image.dtype.name != 'float32':
58 | raise Exception('Image dtype must be float32.')
59 |
60 | image = np.flipud(image)
61 |
62 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
63 | color = True
64 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
65 | color = False
66 | else:
67 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
68 |
69 |     file.write(b'PF\n' if color else b'Pf\n')  # file is opened in binary mode, so write bytes
70 |     file.write(b'%d %d\n' % (image.shape[1], image.shape[0]))
71 |
72 | endian = image.dtype.byteorder
73 |
74 | if endian == '<' or endian == '=' and sys.byteorder == 'little':
75 | scale = -scale
76 |
77 |     file.write(b'%f\n' % scale)
78 |
79 | image.tofile(file)
80 |
--------------------------------------------------------------------------------
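A minimal usage sketch for readPFM above (the path is a placeholder; the shape is whatever the file contains; not part of the repository):

disp, scale = readPFM('/path/to/disparity/0006.pfm')
print(disp.shape, disp.dtype, scale)   # e.g. (540, 960) float32 1.0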
/expansion/utils/readpfm.py:
--------------------------------------------------------------------------------
1 | import re
2 | import numpy as np
3 | import sys
4 |
5 |
6 | def readPFM(file):
7 | file = open(file, 'rb')
8 |
9 | color = None
10 | width = None
11 | height = None
12 | scale = None
13 | endian = None
14 |
15 | header = file.readline().rstrip()
16 | if (sys.version[0]) == '3':
17 | header = header.decode('utf-8')
18 | if header == 'PF':
19 | color = True
20 | elif header == 'Pf':
21 | color = False
22 | else:
23 | raise Exception('Not a PFM file.')
24 |
25 | if (sys.version[0]) == '3':
26 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
27 | else:
28 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
29 | if dim_match:
30 | width, height = map(int, dim_match.groups())
31 | else:
32 | raise Exception('Malformed PFM header.')
33 |
34 | if (sys.version[0]) == '3':
35 | scale = float(file.readline().rstrip().decode('utf-8'))
36 | else:
37 | scale = float(file.readline().rstrip())
38 |
39 | if scale < 0: # little-endian
40 | endian = '<'
41 | scale = -scale
42 | else:
43 | endian = '>' # big-endian
44 |
45 | data = np.fromfile(file, endian + 'f')
46 | shape = (height, width, 3) if color else (height, width)
47 |
48 | data = np.reshape(data, shape)
49 | data = np.flipud(data)
50 | return data, scale
51 |
52 |
--------------------------------------------------------------------------------
/expansion/utils/sintel_io.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python2
2 |
3 | """
4 | I/O script to save and load the data coming with the MPI-Sintel low-level
5 | computer vision benchmark.
6 |
7 | For more details about the benchmark, please visit www.mpi-sintel.de
8 |
9 | CHANGELOG:
10 | v1.0 (2015/02/03): First release
11 |
12 | Copyright (c) 2015 Jonas Wulff
13 | Max Planck Institute for Intelligent Systems, Tuebingen, Germany
14 |
15 | """
16 |
17 | # Requirements: Numpy as PIL/Pillow
18 | import numpy as np
19 | from PIL import Image
20 |
21 | # Check for endianness, based on Daniel Scharstein's optical flow code.
22 | # Using little-endian architecture, these two should be equal.
23 | TAG_FLOAT = 202021.25
24 | TAG_CHAR = 'PIEH'
25 |
26 | def flow_read(filename):
27 | """ Read optical flow from file, return (U,V) tuple.
28 |
29 | Original code by Deqing Sun, adapted from Daniel Scharstein.
30 | """
31 | f = open(filename,'rb')
32 | check = np.fromfile(f,dtype=np.float32,count=1)[0]
33 | assert check == TAG_FLOAT, ' flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
34 | width = np.fromfile(f,dtype=np.int32,count=1)[0]
35 | height = np.fromfile(f,dtype=np.int32,count=1)[0]
36 | size = width*height
37 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' flow_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height)
38 | tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2))
39 | u = tmp[:,np.arange(width)*2]
40 | v = tmp[:,np.arange(width)*2 + 1]
41 | return u,v
42 |
43 | def flow_write(filename,uv,v=None):
44 | """ Write optical flow to file.
45 |
46 | If v is None, uv is assumed to contain both u and v channels,
47 | stacked in depth.
48 |
49 | Original code by Deqing Sun, adapted from Daniel Scharstein.
50 | """
51 | nBands = 2
52 |
53 | if v is None:
54 | assert(uv.ndim == 3)
55 | assert(uv.shape[2] == 2)
56 | u = uv[:,:,0]
57 | v = uv[:,:,1]
58 | else:
59 | u = uv
60 |
61 | assert(u.shape == v.shape)
62 | height,width = u.shape
63 | f = open(filename,'wb')
64 | # write the header
65 | f.write(TAG_CHAR)
66 | np.array(width).astype(np.int32).tofile(f)
67 | np.array(height).astype(np.int32).tofile(f)
68 | # arrange into matrix form
69 | tmp = np.zeros((height, width*nBands))
70 | tmp[:,np.arange(width)*2] = u
71 | tmp[:,np.arange(width)*2 + 1] = v
72 | tmp.astype(np.float32).tofile(f)
73 | f.close()
74 |
75 |
76 | def depth_read(filename):
77 | """ Read depth data from file, return as numpy array. """
78 | f = open(filename,'rb')
79 | check = np.fromfile(f,dtype=np.float32,count=1)[0]
80 | assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
81 | width = np.fromfile(f,dtype=np.int32,count=1)[0]
82 | height = np.fromfile(f,dtype=np.int32,count=1)[0]
83 | size = width*height
84 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height)
85 | depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width))
86 | return depth
87 |
88 | def depth_write(filename, depth):
89 | """ Write depth to file. """
90 | height,width = depth.shape[:2]
91 | f = open(filename,'wb')
92 | # write the header
93 | f.write(TAG_CHAR)
94 | np.array(width).astype(np.int32).tofile(f)
95 | np.array(height).astype(np.int32).tofile(f)
96 |
97 | depth.astype(np.float32).tofile(f)
98 | f.close()
99 |
100 |
101 | def disparity_write(filename,disparity,bitdepth=16):
102 | """ Write disparity to file.
103 |
104 | bitdepth can be either 16 (default) or 32.
105 |
106 | The maximum disparity is 1024, since the image width in Sintel
107 | is 1024.
108 | """
109 | d = disparity.copy()
110 |
111 | # Clip disparity.
112 | d[d>1024] = 1024
113 | d[d<0] = 0
114 |
115 | d_r = (d / 4.0).astype('uint8')
116 | d_g = ((d * (2.0**6)) % 256).astype('uint8')
117 |
118 | out = np.zeros((d.shape[0],d.shape[1],3),dtype='uint8')
119 | out[:,:,0] = d_r
120 | out[:,:,1] = d_g
121 |
122 | if bitdepth > 16:
123 | d_b = (d * (2**14) % 256).astype('uint8')
124 | out[:,:,2] = d_b
125 |
126 | Image.fromarray(out,'RGB').save(filename,'PNG')
127 |
128 |
129 | def disparity_read(filename):
130 | """ Return disparity read from filename. """
131 | f_in = np.array(Image.open(filename))
132 | d_r = f_in[:,:,0].astype('float64')
133 | d_g = f_in[:,:,1].astype('float64')
134 | d_b = f_in[:,:,2].astype('float64')
135 |
136 | depth = d_r * 4 + d_g / (2**6) + d_b / (2**14)
137 | return depth
138 |
139 |
140 | #def cam_read(filename):
141 | # """ Read camera data, return (M,N) tuple.
142 | #
143 | # M is the intrinsic matrix, N is the extrinsic matrix, so that
144 | #
145 | # x = M*N*X,
146 | # with x being a point in homogeneous image pixel coordinates, X being a
147 | # point in homogeneous world coordinates.
148 | # """
149 | # txtdata = np.loadtxt(filename)
150 | # intrinsic = txtdata[0,:9].reshape((3,3))
151 | # extrinsic = textdata[1,:12].reshape((3,4))
152 | # return intrinsic,extrinsic
153 | #
154 | #
155 | #def cam_write(filename,M,N):
156 | # """ Write intrinsic matrix M and extrinsic matrix N to file. """
157 | # Z = np.zeros((2,12))
158 | # Z[0,:9] = M.ravel()
159 | # Z[1,:12] = N.ravel()
160 | # np.savetxt(filename,Z)
161 |
162 | def cam_read(filename):
163 | """ Read camera data, return (M,N) tuple.
164 |
165 | M is the intrinsic matrix, N is the extrinsic matrix, so that
166 |
167 | x = M*N*X,
168 | with x being a point in homogeneous image pixel coordinates, X being a
169 | point in homogeneous world coordinates.
170 | """
171 | f = open(filename,'rb')
172 | check = np.fromfile(f,dtype=np.float32,count=1)[0]
173 | assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
174 | M = np.fromfile(f,dtype='float64',count=9).reshape((3,3))
175 | N = np.fromfile(f,dtype='float64',count=12).reshape((3,4))
176 | return M,N
177 |
178 | def cam_write(filename, M, N):
179 | """ Write intrinsic matrix M and extrinsic matrix N to file. """
180 | f = open(filename,'wb')
181 | # write the header
182 | f.write(TAG_CHAR)
183 | M.astype('float64').tofile(f)
184 | N.astype('float64').tofile(f)
185 | f.close()
186 |
187 |
188 | def segmentation_write(filename,segmentation):
189 | """ Write segmentation to file. """
190 |
191 | segmentation_ = segmentation.astype('int32')
192 | seg_r = np.floor(segmentation_ / (256**2)).astype('uint8')
193 | seg_g = np.floor((segmentation_ % (256**2)) / 256).astype('uint8')
194 | seg_b = np.floor(segmentation_ % 256).astype('uint8')
195 |
196 | out = np.zeros((segmentation.shape[0],segmentation.shape[1],3),dtype='uint8')
197 | out[:,:,0] = seg_r
198 | out[:,:,1] = seg_g
199 | out[:,:,2] = seg_b
200 |
201 | Image.fromarray(out,'RGB').save(filename,'PNG')
202 |
203 |
204 | def segmentation_read(filename):
205 | """ Return disparity read from filename. """
206 | f_in = np.array(Image.open(filename))
207 | seg_r = f_in[:,:,0].astype('int32')
208 | seg_g = f_in[:,:,1].astype('int32')
209 | seg_b = f_in[:,:,2].astype('int32')
210 |
211 | segmentation = (seg_r * 256 + seg_g) * 256 + seg_b
212 | return segmentation
213 |
214 |
215 |
--------------------------------------------------------------------------------
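A worked example of the Sintel disparity PNG encoding used by disparity_write/disparity_read above: a value d is stored as R = d/4, G = (d*2**6) % 256, B = (d*2**14) % 256 and recovered as R*4 + G/2**6 + B/2**14 (not part of the repository):

d = 37.625
r = int(d / 4.0)
g = int(d * 2**6) % 256
b = int(d * 2**14) % 256
print(r * 4 + g / 2**6 + b / 2**14)   # 37.625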
/expansion/utils/util_flow.py:
--------------------------------------------------------------------------------
1 | import math
2 | import png
3 | import struct
4 | import array
5 | import numpy as np
6 | import cv2
7 | import pdb
8 |
9 | from io import *
10 |
11 | UNKNOWN_FLOW_THRESH = 1e9;
12 | UNKNOWN_FLOW = 1e10;
13 |
14 | # Middlebury checks
15 | TAG_STRING = 'PIEH' # use this when WRITING the file
16 | TAG_FLOAT = 202021.25 # check for this when READING the file
17 |
18 | def readPFM(file):
19 | import re
20 | file = open(file, 'rb')
21 |
22 | color = None
23 | width = None
24 | height = None
25 | scale = None
26 | endian = None
27 |
28 | header = file.readline().rstrip()
29 | if header == b'PF':
30 | color = True
31 | elif header == b'Pf':
32 | color = False
33 | else:
34 | raise Exception('Not a PFM file.')
35 |
36 | dim_match = re.match(b'^(\d+)\s(\d+)\s$', file.readline())
37 | if dim_match:
38 | width, height = map(int, dim_match.groups())
39 | else:
40 | raise Exception('Malformed PFM header.')
41 |
42 | scale = float(file.readline().rstrip())
43 | if scale < 0: # little-endian
44 | endian = '<'
45 | scale = -scale
46 | else:
47 | endian = '>' # big-endian
48 |
49 | data = np.fromfile(file, endian + 'f')
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 | return data, scale
55 |
56 |
57 | def save_pfm(file, image, scale = 1):
58 | import sys
59 | color = None
60 |
61 | if image.dtype.name != 'float32':
62 | raise Exception('Image dtype must be float32.')
63 |
64 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
65 | color = True
66 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
67 | color = False
68 | else:
69 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
70 |
71 |     file.write(b'PF\n' if color else b'Pf\n')  # expects a binary-mode file handle under Python 3
72 |     file.write(b'%d %d\n' % (image.shape[1], image.shape[0]))
73 |
74 | endian = image.dtype.byteorder
75 |
76 | if endian == '<' or endian == '=' and sys.byteorder == 'little':
77 | scale = -scale
78 |
79 |     file.write(b'%f\n' % scale)
80 |
81 | image.tofile(file)
82 |
83 |
84 | def ReadMiddleburyFloFile(path):
85 | """ Read .FLO file as specified by Middlebury.
86 |
87 | Returns tuple (width, height, u, v, mask), where u, v, mask are flat
88 | arrays of values.
89 | """
90 |
91 | with open(path, 'rb') as fil:
92 | tag = struct.unpack('f', fil.read(4))[0]
93 | width = struct.unpack('i', fil.read(4))[0]
94 | height = struct.unpack('i', fil.read(4))[0]
95 |
96 | assert tag == TAG_FLOAT
97 |
98 | #data = np.fromfile(path, dtype=np.float, count=-1)
99 | #data = data[3:]
100 |
101 | fmt = 'f' * width*height*2
102 | data = struct.unpack(fmt, fil.read(4*width*height*2))
103 |
104 | u = data[::2]
105 | v = data[1::2]
106 |
107 |         mask = map(lambda x,y: abs(x)<UNKNOWN_FLOW_THRESH and abs(y)<UNKNOWN_FLOW_THRESH, u, v)
144 | # print(u[ind], v[ind], mask[ind], row[3*x], row[3*x+1], row[3*x+2])
145 |
146 | #png_reader.close()
147 |
148 | return (width, height, u, v, mask)
149 |
150 |
151 | def WriteMiddleburyFloFile(path, width, height, u, v, mask=None):
152 | """ Write .FLO file as specified by Middlebury.
153 | """
154 |
155 | if mask is not None:
156 | u_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, u, mask)
157 | v_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, v, mask)
158 | else:
159 | u_masked = u
160 | v_masked = v
161 |
162 | fmt = 'f' * width*height*2
163 | # Interleave lists
164 | data = [x for t in zip(u_masked,v_masked) for x in t]
165 |
166 | with open(path, 'wb') as fil:
167 | fil.write(str.encode(TAG_STRING))
168 | fil.write(struct.pack('i', width))
169 | fil.write(struct.pack('i', height))
170 | fil.write(struct.pack(fmt, *data))
171 |
172 |
173 | def write_flow(path,flow):
174 |
175 | invalid_idx = (flow[:, :, 2] == 0)
176 | flow[:, :, 0:2] = flow[:, :, 0:2]*64.+ 2 ** 15
177 | flow[invalid_idx, 0] = 0
178 | flow[invalid_idx, 1] = 0
179 |
180 | flow = flow.astype(np.uint16)
181 | flow = cv2.imwrite(path, flow[:,:,::-1])
182 |
183 | #WriteKittiPngFile(path,
184 | # flow.shape[1], flow.shape[0], flow[:,:,0].flatten(),
185 | # flow[:,:,1].flatten(), flow[:,:,2].flatten())
186 |
187 |
188 |
189 | def WriteKittiPngFile(path, width, height, u, v, mask=None):
190 | """ Write 16-bit .PNG file as specified by KITTI-2015 (flow).
191 |
192 | u, v are lists of float values
193 | mask is a list of floats, denoting the *valid* pixels.
194 | """
195 |
196 | data = array.array('H',[0])*width*height*3
197 |
198 | for i,(u_,v_,mask_) in enumerate(zip(u,v,mask)):
199 | data[3*i] = int(u_*64.0+2**15)
200 | data[3*i+1] = int(v_*64.0+2**15)
201 | data[3*i+2] = int(mask_)
202 |
203 | # if mask_ > 0:
204 | # print(data[3*i], data[3*i+1],data[3*i+2])
205 |
206 | with open(path, 'wb') as png_file:
207 | png_writer = png.Writer(width=width, height=height, bitdepth=16, compression=3, greyscale=False)
208 | png_writer.write_array(png_file, data)
209 |
210 |
211 | def ConvertMiddleburyFloToKittiPng(src_path, dest_path):
212 | width, height, u, v, mask = ReadMiddleburyFloFile(src_path)
213 | WriteKittiPngFile(dest_path, width, height, u, v, mask=mask)
214 |
215 | def ConvertKittiPngToMiddleburyFlo(src_path, dest_path):
216 | width, height, u, v, mask = ReadKittiPngFile(src_path)
217 | WriteMiddleburyFloFile(dest_path, width, height, u, v, mask=mask)
218 |
219 |
220 | def ParseFilenameKitti(filename):
221 | # Parse kitti filename (seq_frameno.xx),
222 | # return seq, frameno, ext.
223 | # Be aware that seq might contain the dataset name (if contained as prefix)
224 | ext = filename[filename.rfind('.'):]
225 | frameno = filename[filename.rfind('_')+1:filename.rfind('.')]
226 | frameno = int(frameno)
227 | seq = filename[:filename.rfind('_')]
228 | return seq, frameno, ext
229 |
230 |
231 | def read_calib_file(filepath):
232 | """Read in a calibration file and parse into a dictionary."""
233 | data = {}
234 |
235 | with open(filepath, 'r') as f:
236 | for line in f.readlines():
237 | key, value = line.split(':', 1)
238 | # The only non-float values in these files are dates, which
239 | # we don't care about anyway
240 | try:
241 | data[key] = np.array([float(x) for x in value.split()])
242 | except ValueError:
243 | pass
244 |
245 | return data
246 |
247 | def load_calib_cam_to_cam(cam_to_cam_file):
248 | # We'll return the camera calibration as a dictionary
249 | data = {}
250 |
251 | # Load and parse the cam-to-cam calibration data
252 | filedata = read_calib_file(cam_to_cam_file)
253 |
254 | # Create 3x4 projection matrices
255 | P_rect_00 = np.reshape(filedata['P_rect_00'], (3, 4))
256 | P_rect_10 = np.reshape(filedata['P_rect_01'], (3, 4))
257 | P_rect_20 = np.reshape(filedata['P_rect_02'], (3, 4))
258 | P_rect_30 = np.reshape(filedata['P_rect_03'], (3, 4))
259 |
260 | # Compute the camera intrinsics
261 | data['K_cam0'] = P_rect_00[0:3, 0:3]
262 | data['K_cam1'] = P_rect_10[0:3, 0:3]
263 | data['K_cam2'] = P_rect_20[0:3, 0:3]
264 | data['K_cam3'] = P_rect_30[0:3, 0:3]
265 |
266 | data['b00'] = P_rect_00[0, 3] / P_rect_00[0, 0]
267 | data['b10'] = P_rect_10[0, 3] / P_rect_10[0, 0]
268 | data['b20'] = P_rect_20[0, 3] / P_rect_20[0, 0]
269 | data['b30'] = P_rect_30[0, 3] / P_rect_30[0, 0]
270 |
271 | return data
272 |
273 |
--------------------------------------------------------------------------------
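
The two flow formats handled above are symmetric: Middlebury .flo stores the PIEH tag, width, height, and interleaved float (u, v) pairs, while the KITTI 16-bit PNG stores each component as value*64 + 2**15 with the third channel marking valid pixels, so flow is quantized to 1/64 px. A minimal round-trip sketch using the helpers above; the import path assumes the expansion.utils package layout of this repository, and the paths and synthetic flow field are illustrative only:

    import numpy as np
    from expansion.utils.util_flow import (WriteMiddleburyFloFile,
                                           ConvertMiddleburyFloToKittiPng,
                                           ReadKittiPngFile)

    h, w = 4, 6                                              # tiny synthetic flow field
    u = np.random.randn(h * w).astype(np.float32)            # horizontal component, flat
    v = np.random.randn(h * w).astype(np.float32)            # vertical component, flat

    WriteMiddleburyFloFile('/tmp/example.flo', w, h, u, v)   # raw floats behind the PIEH tag
    ConvertMiddleburyFloToKittiPng('/tmp/example.flo', '/tmp/example_kitti.png')
    width, height, u2, v2, mask = ReadKittiPngFile('/tmp/example_kitti.png')
    # u2/v2 match u/v only up to the 1/64 px quantization of the KITTI encoding.
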
/interface/flask_app.py:
--------------------------------------------------------------------------------
1 | from flask import Flask, render_template, request, redirect, url_for, abort
2 | import json
3 |
4 | app = Flask(__name__)
5 |
6 | import sys
7 | sys.path.append(".")
8 | sys.path.append("..")
9 |
10 | import argparse
11 | from PIL import Image, ImageOps
12 | import numpy as np
13 | import base64
14 | import cv2
15 | from inference import demo
16 |
17 | def Base64ToNdarry(img_base64):
18 | img_data = base64.b64decode(img_base64)
19 |     img_np = np.frombuffer(img_data, np.uint8)
20 | src = cv2.imdecode(img_np, cv2.IMREAD_ANYCOLOR)
21 |
22 | return src
23 |
24 | def NdarrayToBase64(dst):
25 | result, dst_data = cv2.imencode('.png', dst)
26 | dst_base64 = base64.b64encode(dst_data)
27 |
28 | return dst_base64
29 |
30 | parser = argparse.ArgumentParser(description='User controllable latent transformer')
31 | parser.add_argument('--checkpoint_path', default='pretrained_models/latent_transformer/cat.pt')
32 | args = parser.parse_args()
33 |
34 | demo = demo(args.checkpoint_path)
35 |
36 | @app.route("/", methods=["GET", "POST"])
37 | #@auth.login_required
38 | def init():
39 | if request.method == "GET":
40 | input_img = demo.run()
41 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
42 | return render_template("index.html", filepath1=input_base64, canvas_img=input_base64, result=True)
43 | if request.method == "POST":
44 | if 'zi' in request.form.keys():
45 | input_img = demo.move(z=-0.05)
46 | elif 'zo' in request.form.keys():
47 | input_img = demo.move(z=0.05)
48 | elif 'u' in request.form.keys():
49 | input_img = demo.move(y=-0.5, z=-0.0)
50 | elif 'd' in request.form.keys():
51 | input_img = demo.move(y=0.5, z=-0.0)
52 | elif 'l' in request.form.keys():
53 | input_img = demo.move(x=-0.5, z=-0.0)
54 | elif 'r' in request.form.keys():
55 | input_img = demo.move(x=0.5, z=-0.0)
56 | else:
57 | input_img = demo.run()
58 |
59 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
60 | return render_template("index.html", filepath1=input_base64, canvas_img=input_base64, result=True)
61 |
62 | @app.route('/zoom', methods=["POST"])
63 | def zoom_func():
64 |
65 | dz = json.loads(request.form['dz'])
66 | sx = json.loads(request.form['sx'])
67 | sy = json.loads(request.form['sy'])
68 | stop_points = json.loads(request.form['stop_points'])
69 |
70 | input_img = demo.zoom(dz,sxsy=[sx,sy],stop_points=stop_points)
71 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
72 | res = {'img':input_base64}
73 | return json.dumps(res)
74 |
75 | @app.route('/translate', methods=["POST"])
76 | def translate_func():
77 |
78 | dx = json.loads(request.form['dx'])
79 | dy = json.loads(request.form['dy'])
80 | dz = json.loads(request.form['dz'])
81 | sx = json.loads(request.form['sx'])
82 | sy = json.loads(request.form['sy'])
83 | stop_points = json.loads(request.form['stop_points'])
84 | zi = json.loads(request.form['zi'])
85 | zo = json.loads(request.form['zo'])
86 |
87 | input_img = demo.translate([dx,dy],sxsy=[sx,sy],stop_points=stop_points,zoom_in=zi,zoom_out=zo)
88 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
89 | res = {'img':input_base64}
90 | return json.dumps(res)
91 |
92 | @app.route('/changestyle', methods=["POST"])
93 | def changestyle_func():
94 | input_img = demo.change_style()
95 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
96 | res = {'img':input_base64}
97 | return json.dumps(res)
98 |
99 | @app.route('/reset', methods=["POST"])
100 | def reset_func():
101 | input_img = demo.reset()
102 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode()
103 | res = {'img':input_base64}
104 | return json.dumps(res)
105 |
106 | if __name__ == "__main__":
107 | app.run(debug=False, host='0.0.0.0', port=8000)
--------------------------------------------------------------------------------
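
For reference, the routes above read form fields whose values are JSON-encoded strings. A minimal client sketch for the /translate endpoint; the requests library, the server address, and the example coordinates are assumptions for illustration, not part of this repository:

    import json
    import requests

    payload = {
        'dx': json.dumps(40.0), 'dy': json.dumps(0.0), 'dz': json.dumps(0.0),
        'sx': json.dumps(128), 'sy': json.dumps(128),      # anchor point on the 256x256 image
        'stop_points': json.dumps([]),                     # optional list of [x, y] points to keep fixed
        'zi': json.dumps(False), 'zo': json.dumps(False),  # zoom-in / zoom-out flags
    }
    resp = requests.post('http://localhost:8000/translate', data=payload)
    img_data_uri = json.loads(resp.text)['img']            # "data:image/png;base64,..." string
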
/interface/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import Namespace
3 | import numpy as np
4 | import torch
5 | import sys
6 |
7 | sys.path.append(".")
8 | sys.path.append("..")
9 |
10 | from models.StyleGANControler import StyleGANControler
11 |
12 | class demo():
13 |
14 | def __init__(self, checkpoint_path, truncation = 0.5, use_average_code_as_input = False):
15 | self.truncation = truncation
16 | self.use_average_code_as_input = use_average_code_as_input
17 | ckpt = torch.load(checkpoint_path, map_location='cpu')
18 | opts = ckpt['opts']
19 | opts['checkpoint_path'] = checkpoint_path
20 | self.opts = Namespace(**ckpt['opts'])
21 |
22 | self.net = StyleGANControler(self.opts)
23 | self.net.eval()
24 | self.net.cuda()
25 | self.target_layers = [0,1,2,3,4,5]
26 |
27 | self.w1 = None
28 | self.w1_after = None
29 | self.f1 = None
30 |
31 | def run(self):
32 | z1 = torch.randn(1,512).to("cuda")
33 | x1, self.w1, self.f1 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_feature_map=True,return_latents=True,truncation=self.truncation, truncation_latent=self.net.latent_avg[0])
34 | self.w1_after = self.w1.clone()
35 | x1 = self.net.face_pool(x1)
36 | result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
37 | return result
38 |
39 | def translate(self, dxy, sxsy=[0,0], stop_points=[], zoom_in=False, zoom_out=False):
40 | dz = -5. if zoom_in else 0.
41 | dz = 5. if zoom_out else dz
42 |
43 | dxyz = np.array([dxy[0],dxy[1],dz], dtype=np.float32)
44 | dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
45 | dxyz[:2] = dxyz[:2]/dxy_norm
46 | vec_num = dxy_norm/10
47 |
48 | x = torch.from_numpy(np.array([[dxyz]],dtype=np.float32)).cuda()
49 | f1 = torch.nn.functional.interpolate(self.f1, (256,256))
50 | y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0)
51 |
52 | if len(stop_points)>0:
53 | x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1)
54 | tmp = []
55 | for sp in stop_points:
56 | tmp.append(f1[:,:,sp[1],sp[0]].unsqueeze(1))
57 | y = torch.cat([y,torch.cat(tmp, dim=1)],dim=1)
58 |
59 | if not self.use_average_code_as_input:
60 | w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
61 | w1 = self.w1.clone()
62 | w1[:,self.target_layers] = w_hat
63 | else:
64 | w_hat = self.net.encoder(self.net.latent_avg.unsqueeze(0)[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
65 | w1 = self.w1.clone()
66 | w1[:,self.target_layers] = self.w1.clone()[:,self.target_layers] + w_hat - self.net.latent_avg.unsqueeze(0)[:,self.target_layers]
67 |
68 | x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
69 |
70 | self.w1_after = w1.clone()
71 | x1 = self.net.face_pool(x1)
72 | result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
73 | return result
74 |
75 | def zoom(self, dz, sxsy=[0,0], stop_points=[]):
76 | vec_num = abs(dz)/5
77 | dz = 100*np.sign(dz)
78 | x = torch.from_numpy(np.array([[[1.,0,dz]]],dtype=np.float32)).cuda()
79 | f1 = torch.nn.functional.interpolate(self.f1, (256,256))
80 | y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0)
81 |
82 | if len(stop_points)>0:
83 | x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1)
84 | tmp = []
85 | for sp in stop_points:
86 | tmp.append(f1[:,:,sp[1],sp[0]].unsqueeze(1))
87 | y = torch.cat([y,torch.cat(tmp, dim=1)],dim=1)
88 |
89 | if not self.use_average_code_as_input:
90 | w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
91 | w1 = self.w1.clone()
92 | w1[:,self.target_layers] = w_hat
93 | else:
94 | w_hat = self.net.encoder(self.net.latent_avg.unsqueeze(0)[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
95 | w1 = self.w1.clone()
96 | w1[:,self.target_layers] = self.w1.clone()[:,self.target_layers] + w_hat - self.net.latent_avg.unsqueeze(0)[:,self.target_layers]
97 |
98 |
99 | x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
100 |
101 | x1 = self.net.face_pool(x1)
102 | result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
103 | return result
104 |
105 | def change_style(self):
106 | z1 = torch.randn(1,512).to("cuda")
107 | x1, w2 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_latents=True, truncation=self.truncation, truncation_latent=self.net.latent_avg[0])
108 | self.w1_after[:,6:] = w2.detach()[:,0]
109 | x1, _ = self.net.decoder([self.w1_after], input_is_latent=True, randomize_noise=False, return_latents=False)
110 | result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
111 | return result
112 |
113 | def reset(self):
114 | x1, _ = self.net.decoder([self.w1], input_is_latent=True, randomize_noise=False, return_latents=False)
115 | result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
116 | return result
117 |
--------------------------------------------------------------------------------
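
The demo class can also be driven without the Flask front end. A minimal sketch, assuming a CUDA GPU, the cat checkpoint used as the flask_app default, and that the script is run from the interface/ directory so that "from inference import demo" resolves as it does in flask_app.py:

    import cv2
    from inference import demo

    d = demo('pretrained_models/latent_transformer/cat.pt')
    start = d.run()                                      # sample w1/f1 and render the start image (256x256x3, BGR)
    dragged = d.translate([40.0, 0.0], sxsy=[128, 128])  # drag the point at (128, 128) 40 px to the right
    zoomed = d.zoom(2.0, sxsy=[128, 128])                # zoom in around the same anchor point
    styled = d.change_style()                            # resample the fine styles (layers 6 and up)
    cv2.imwrite('result.png', dragged.astype('uint8'))
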
/interface/templates/index.html:
--------------------------------------------------------------------------------
(The HTML markup and inline JavaScript of index.html are not preserved in this dump; the only text content that survives is the control-help table shown on the page:)

    Mouse drag:                   Translation
    Middle mouse button:          Set anchor point
    Mouse wheel:                  Zoom in & out
    'i' or 'o' key + mouse drag:  Translation with zooming in & out
    's' key:                      Style mixing
--------------------------------------------------------------------------------
/licenses/LICENSE_ gengshan-y_expansion:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Carnegie Mellon University
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/licenses/LICENSE_HuangYG123:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 HuangYG123
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/licenses/LICENSE_S-aiueo32:
--------------------------------------------------------------------------------
1 | BSD 2-Clause License
2 |
3 | Copyright (c) 2020, Sou Uchida
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/licenses/LICENSE_TreB1eN:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 TreB1eN
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/licenses/LICENSE_lessw2020:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/licenses/LICENSE_pixel2style2pixel:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Elad Richardson, Yuval Alaluf
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/licenses/LICENSE_rosinality:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Kim Seonghyeon
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/models/StyleGANControler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from models.networks import latent_transformer
4 | from models.stylegan2.model import Generator
5 | import numpy as np
6 |
7 | def get_keys(d, name):
8 | if 'state_dict' in d:
9 | d = d['state_dict']
10 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name}
11 | return d_filt
12 |
13 |
14 | class StyleGANControler(nn.Module):
15 |
16 | def __init__(self, opts):
17 | super(StyleGANControler, self).__init__()
18 | self.set_opts(opts)
19 | # Define architecture
20 |
21 | if 'ffhq' in self.opts.stylegan_weights:
22 | self.style_num = 18
23 | elif 'car' in self.opts.stylegan_weights:
24 | self.style_num = 16
25 | elif 'cat' in self.opts.stylegan_weights:
26 | self.style_num = 14
27 | elif 'church' in self.opts.stylegan_weights:
28 | self.style_num = 14
29 | elif 'anime' in self.opts.stylegan_weights:
30 | self.style_num = 16
31 | else:
32 | self.style_num = 18 #Please modify to adjust network architecture to your pre-trained StyleGAN2
33 |
34 | self.encoder = self.set_encoder()
35 | if self.style_num==18:
36 | self.decoder = Generator(1024, 512, 8, channel_multiplier=2)
37 | elif self.style_num==16:
38 | self.decoder = Generator(512, 512, 8, channel_multiplier=2)
39 | elif self.style_num==14:
40 | self.decoder = Generator(256, 512, 8, channel_multiplier=2)
41 |
42 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))
43 |
44 | # Load weights if needed
45 | self.load_weights()
46 |
47 | def set_encoder(self):
48 | encoder = latent_transformer.Network(self.opts)
49 | return encoder
50 |
51 | def load_weights(self):
52 | if self.opts.checkpoint_path is not None:
53 | print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path))
54 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu')
55 | self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True)
56 | self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True)
57 | self.__load_latent_avg(ckpt)
58 | else:
59 | print('Loading decoder weights from pretrained!')
60 | ckpt = torch.load(self.opts.stylegan_weights)
61 | self.decoder.load_state_dict(ckpt['g_ema'], strict=True)
62 | self.__load_latent_avg(ckpt, repeat=self.opts.style_num)
63 |
64 | def set_opts(self, opts):
65 | self.opts = opts
66 |
67 | def __load_latent_avg(self, ckpt, repeat=None):
68 | if 'latent_avg' in ckpt:
69 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device)
70 | if repeat is not None:
71 | self.latent_avg = self.latent_avg.repeat(repeat, 1)
72 | else:
73 | self.latent_avg = None
74 |
--------------------------------------------------------------------------------
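
The style_num chosen above fixes the decoder resolution through StyleGAN2's n_latent = 2*log2(size) - 2 rule (18 styles for 1024px, 16 for 512px, 14 for 256px). A minimal construction sketch; the option names are the fields read by this class, while the weight path and the other values are placeholders:

    import torch
    from argparse import Namespace
    from models.StyleGANControler import StyleGANControler

    opts = Namespace(
        checkpoint_path=None,                                   # None -> load raw StyleGAN2 weights below
        stylegan_weights='pretrained_models/stylegan2-cat.pt',  # placeholder path; 'cat' selects 14 style layers
        style_num=14,                                           # repeats latent_avg once per style layer
        device='cuda',
    )
    net = StyleGANControler(opts).eval().cuda()
    with torch.no_grad():
        z = torch.randn(1, 512, device='cuda')
        img, _ = net.decoder([z], input_is_latent=False, randomize_noise=False)  # 1x3x256x256 for 14 styles
        img = net.face_pool(img)
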
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/models/__init__.py
--------------------------------------------------------------------------------
/models/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/models/networks/__init__.py
--------------------------------------------------------------------------------
/models/networks/latent_transformer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn, einsum
5 | from einops import rearrange
6 |
7 | # classes
8 | class PreNorm(nn.Module):
9 | def __init__(self, dim, fn):
10 | super().__init__()
11 | self.norm = nn.LayerNorm(dim)
12 | self.fn = fn
13 | def forward(self, x, **kwargs):
14 | return self.fn(self.norm(x), **kwargs)
15 |
16 | class FeedForward(nn.Module):
17 | def __init__(self, dim, hidden_dim, dropout = 0.):
18 | super().__init__()
19 | self.net = nn.Sequential(
20 | nn.Linear(dim, hidden_dim),
21 | nn.GELU(),
22 | nn.Dropout(dropout),
23 | nn.Linear(hidden_dim, dim),
24 | nn.Dropout(dropout)
25 | )
26 | def forward(self, x):
27 | return self.net(x)
28 |
29 | class Attention(nn.Module):
30 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
31 | super().__init__()
32 | inner_dim = dim_head * heads
33 | project_out = not (heads == 1 and dim_head == dim)
34 |
35 | self.heads = heads
36 | self.scale = dim_head ** -0.5
37 |
38 | self.attend = nn.Softmax(dim = -1)
39 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
40 |
41 | self.to_out = nn.Sequential(
42 | nn.Linear(inner_dim, dim),
43 | nn.Dropout(dropout)
44 | ) if project_out else nn.Identity()
45 |
46 | def forward(self, x):
47 | qkv = self.to_qkv(x).chunk(3, dim = -1)
48 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
49 |
50 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
51 |
52 | attn = self.attend(dots)
53 |
54 | out = torch.matmul(attn, v)
55 | out = rearrange(out, 'b h n d -> b n (h d)')
56 | return self.to_out(out)
57 |
58 | class CrossAttention(nn.Module):
59 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
60 | super().__init__()
61 | inner_dim = dim_head * heads
62 | project_out = not (heads == 1 and dim_head == dim)
63 |
64 | self.heads = heads
65 | self.scale = dim_head ** -0.5
66 |
67 | self.to_k = nn.Linear(dim, inner_dim , bias=False)
68 | self.to_v = nn.Linear(dim, inner_dim , bias = False)
69 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
70 |
71 | self.to_out = nn.Sequential(
72 | nn.Linear(inner_dim, dim),
73 | nn.Dropout(dropout)
74 | ) if project_out else nn.Identity()
75 |
76 | def forward(self, x_qkv, query_length=1):
77 | h = self.heads
78 |
79 | k = self.to_k(x_qkv)[:, query_length:]
80 | k = rearrange(k, 'b n (h d) -> b h n d', h = h)
81 |
82 | v = self.to_v(x_qkv)[:, query_length:]
83 | v = rearrange(v, 'b n (h d) -> b h n d', h = h)
84 |
85 | q = self.to_q(x_qkv)[:, :query_length]
86 | q = rearrange(q, 'b n (h d) -> b h n d', h = h)
87 |
88 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
89 |
90 | attn = dots.softmax(dim=-1)
91 |
92 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
93 | out = rearrange(out, 'b h n d -> b n (h d)')
94 | out = self.to_out(out)
95 |
96 | return out
97 |
98 | class TransformerEncoder(nn.Module):
99 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
100 | super().__init__()
101 | self.layers = nn.ModuleList([])
102 | for _ in range(depth):
103 | self.layers.append(nn.ModuleList([
104 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
105 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
106 | ]))
107 | def forward(self, x):
108 | for attn, ff in self.layers:
109 | x = attn(x) + x
110 | x = ff(x) + x
111 | return x
112 |
113 | class TransformerDecoder(nn.Module):
114 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
115 | super().__init__()
116 | self.pos_embedding = nn.Parameter(torch.randn(1, 6, dim))
117 | self.layers = nn.ModuleList([])
118 | for _ in range(depth):
119 | self.layers.append(nn.ModuleList([
120 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
121 | PreNorm(dim, CrossAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
122 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
123 | ]))
124 | def forward(self, x, y):
125 | x = x + self.pos_embedding[:, :x.shape[1]]
126 | for sattn, cattn, ff in self.layers:
127 | x = sattn(x) + x
128 | xy = torch.cat((x,y), dim=1)
129 | x = cattn(xy, query_length=x.shape[1]) + x
130 | x = ff(x) + x
131 | return x
132 |
133 | class Network(nn.Module):
134 | def __init__(self, opts):
135 | super(Network, self).__init__()
136 |
137 | self.transformer_encoder = TransformerEncoder(dim=512, depth=6, heads=8, dim_head=64, mlp_dim=512, dropout=0)
138 | self.transformer_decoder = TransformerDecoder(dim=512, depth=6, heads=8, dim_head=64, mlp_dim=512, dropout=0)
139 | self.layer1 = nn.Linear(3, 256)
140 | self.layer2 = nn.Linear(512, 256)
141 | self.layer3 = nn.Linear(512, 512)
142 | self.layer4 = nn.Linear(512, 512)
143 | self.mlp_head = nn.Sequential(
144 | nn.Linear(512, 512)
145 | )
146 |
147 | def forward(self, w, x, y, alpha=1.):
148 | #w: latent vectors
149 | #x: flow vectors
150 | #y: StyleGAN features
151 | xh = F.relu(self.layer1(x))
152 | yh = F.relu(self.layer2(y))
153 | xyh = torch.cat([xh,yh], dim=2)
154 | xyh = F.relu(self.layer3(xyh))
155 | xyh = self.transformer_encoder(xyh)
156 |
157 | wh = F.relu(self.layer4(w))
158 |
159 | h = self.transformer_decoder(wh, xyh)
160 | h = self.mlp_head(h)
161 | w_hat = w+alpha*h
162 | return w_hat
163 |
--------------------------------------------------------------------------------
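
The transformer above combines three inputs: the first six W+ latent codes, one row of user flow vectors (dx, dy, dz), and the StyleGAN feature sampled at each corresponding point. A quick shape sketch; the batch size and number of stop points are arbitrary, and opts is unused by Network.__init__ above:

    import torch
    from argparse import Namespace
    from models.networks.latent_transformer import Network

    net = Network(Namespace())
    w = torch.randn(2, 6, 512)        # latent codes for the 6 coarse StyleGAN layers
    x = torch.randn(2, 3, 3)          # one drag vector plus two stop points, each (dx, dy, dz)
    y = torch.randn(2, 3, 512)        # StyleGAN feature vectors at those points
    w_hat = net(w, x, y, alpha=1.0)
    print(w_hat.shape)                # torch.Size([2, 6, 512]): w plus the predicted residual
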
/models/stylegan2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/models/stylegan2/__init__.py
--------------------------------------------------------------------------------
/models/stylegan2/op/__init__.py:
--------------------------------------------------------------------------------
1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu
2 | from .upfirdn2d import upfirdn2d
3 |
--------------------------------------------------------------------------------
/models/stylegan2/op/fused_act.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Function
6 | from torch.utils.cpp_extension import load
7 |
8 | module_path = os.path.dirname(__file__)
9 | fused = load(
10 | 'fused',
11 | sources=[
12 | os.path.join(module_path, 'fused_bias_act.cpp'),
13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'),
14 | ],
15 | )
16 |
17 |
18 | class FusedLeakyReLUFunctionBackward(Function):
19 | @staticmethod
20 | def forward(ctx, grad_output, out, negative_slope, scale):
21 | ctx.save_for_backward(out)
22 | ctx.negative_slope = negative_slope
23 | ctx.scale = scale
24 |
25 | empty = grad_output.new_empty(0)
26 |
27 | grad_input = fused.fused_bias_act(
28 | grad_output, empty, out, 3, 1, negative_slope, scale
29 | )
30 |
31 | dim = [0]
32 |
33 | if grad_input.ndim > 2:
34 | dim += list(range(2, grad_input.ndim))
35 |
36 | grad_bias = grad_input.sum(dim).detach()
37 |
38 | return grad_input, grad_bias
39 |
40 | @staticmethod
41 | def backward(ctx, gradgrad_input, gradgrad_bias):
42 | out, = ctx.saved_tensors
43 | gradgrad_out = fused.fused_bias_act(
44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
45 | )
46 |
47 | return gradgrad_out, None, None, None
48 |
49 |
50 | class FusedLeakyReLUFunction(Function):
51 | @staticmethod
52 | def forward(ctx, input, bias, negative_slope, scale):
53 | empty = input.new_empty(0)
54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale)
55 | ctx.save_for_backward(out)
56 | ctx.negative_slope = negative_slope
57 | ctx.scale = scale
58 |
59 | return out
60 |
61 | @staticmethod
62 | def backward(ctx, grad_output):
63 | out, = ctx.saved_tensors
64 |
65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
66 | grad_output, out, ctx.negative_slope, ctx.scale
67 | )
68 |
69 | return grad_input, grad_bias, None, None
70 |
71 |
72 | class FusedLeakyReLU(nn.Module):
73 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
74 | super().__init__()
75 |
76 | self.bias = nn.Parameter(torch.zeros(channel))
77 | self.negative_slope = negative_slope
78 | self.scale = scale
79 |
80 | def forward(self, input):
81 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
82 |
83 |
84 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
85 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
86 |
--------------------------------------------------------------------------------
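
Functionally, the compiled op above adds the per-channel bias, applies a leaky ReLU, and multiplies by a constant gain; the CUDA kernel exists for speed and for its custom backward, not for new semantics. A plain-PyTorch sketch of the same forward computation:

    import torch
    from torch.nn import functional as F

    def fused_leaky_relu_reference(input, bias, negative_slope=0.2, scale=2 ** 0.5):
        # Broadcast the per-channel bias over dim 1, then leaky ReLU and the sqrt(2) gain.
        bias = bias.view(1, -1, *([1] * (input.ndim - 2)))
        return F.leaky_relu(input + bias, negative_slope) * scale
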
/models/stylegan2/op/fused_bias_act.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
5 | int act, int grad, float alpha, float scale);
6 |
7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
10 |
11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
12 | int act, int grad, float alpha, float scale) {
13 | CHECK_CUDA(input);
14 | CHECK_CUDA(bias);
15 |
16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale);
17 | }
18 |
19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)");
21 | }
--------------------------------------------------------------------------------
/models/stylegan2/op/fused_bias_act_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 | 
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 | 
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | template <typename scalar_t>
19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref,
20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) {
21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x;
22 |
23 | scalar_t zero = 0.0;
24 |
25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) {
26 | scalar_t x = p_x[xi];
27 |
28 | if (use_bias) {
29 | x += p_b[(xi / step_b) % size_b];
30 | }
31 |
32 | scalar_t ref = use_ref ? p_ref[xi] : zero;
33 |
34 | scalar_t y;
35 |
36 | switch (act * 10 + grad) {
37 | default:
38 | case 10: y = x; break;
39 | case 11: y = x; break;
40 | case 12: y = 0.0; break;
41 |
42 | case 30: y = (x > 0.0) ? x : x * alpha; break;
43 | case 31: y = (ref > 0.0) ? x : x * alpha; break;
44 | case 32: y = 0.0; break;
45 | }
46 |
47 | out[xi] = y * scale;
48 | }
49 | }
50 |
51 |
52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer,
53 | int act, int grad, float alpha, float scale) {
54 | int curDevice = -1;
55 | cudaGetDevice(&curDevice);
56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
57 |
58 | auto x = input.contiguous();
59 | auto b = bias.contiguous();
60 | auto ref = refer.contiguous();
61 |
62 | int use_bias = b.numel() ? 1 : 0;
63 | int use_ref = ref.numel() ? 1 : 0;
64 |
65 | int size_x = x.numel();
66 | int size_b = b.numel();
67 | int step_b = 1;
68 |
69 | for (int i = 1 + 1; i < x.dim(); i++) {
70 | step_b *= x.size(i);
71 | }
72 |
73 | int loop_x = 4;
74 | int block_size = 4 * 32;
75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1;
76 |
77 | auto y = torch::empty_like(x);
78 |
79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] {
80 |         fused_bias_act_kernel<scalar_t><<<grid_size, block_size, 0, stream>>>(
81 |             y.data_ptr<scalar_t>(),
82 |             x.data_ptr<scalar_t>(),
83 |             b.data_ptr<scalar_t>(),
84 |             ref.data_ptr<scalar_t>(),
85 | act,
86 | grad,
87 | alpha,
88 | scale,
89 | loop_x,
90 | size_x,
91 | step_b,
92 | size_b,
93 | use_bias,
94 | use_ref
95 | );
96 | });
97 |
98 | return y;
99 | }
--------------------------------------------------------------------------------
/models/stylegan2/op/upfirdn2d.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 |
4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
5 | int up_x, int up_y, int down_x, int down_y,
6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1);
7 |
8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
11 |
12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel,
13 | int up_x, int up_y, int down_x, int down_y,
14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
15 | CHECK_CUDA(input);
16 | CHECK_CUDA(kernel);
17 |
18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1);
19 | }
20 |
21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)");
23 | }
--------------------------------------------------------------------------------
/models/stylegan2/op/upfirdn2d.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 | from torch.nn import functional as F  # needed by upfirdn2d_native below
5 | from torch.autograd import Function
6 | from torch.utils.cpp_extension import load
6 |
7 | module_path = os.path.dirname(__file__)
8 | upfirdn2d_op = load(
9 | 'upfirdn2d',
10 | sources=[
11 | os.path.join(module_path, 'upfirdn2d.cpp'),
12 | os.path.join(module_path, 'upfirdn2d_kernel.cu'),
13 | ],
14 | )
15 |
16 |
17 | class UpFirDn2dBackward(Function):
18 | @staticmethod
19 | def forward(
20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
21 | ):
22 | up_x, up_y = up
23 | down_x, down_y = down
24 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
25 |
26 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
27 |
28 | grad_input = upfirdn2d_op.upfirdn2d(
29 | grad_output,
30 | grad_kernel,
31 | down_x,
32 | down_y,
33 | up_x,
34 | up_y,
35 | g_pad_x0,
36 | g_pad_x1,
37 | g_pad_y0,
38 | g_pad_y1,
39 | )
40 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
41 |
42 | ctx.save_for_backward(kernel)
43 |
44 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
45 |
46 | ctx.up_x = up_x
47 | ctx.up_y = up_y
48 | ctx.down_x = down_x
49 | ctx.down_y = down_y
50 | ctx.pad_x0 = pad_x0
51 | ctx.pad_x1 = pad_x1
52 | ctx.pad_y0 = pad_y0
53 | ctx.pad_y1 = pad_y1
54 | ctx.in_size = in_size
55 | ctx.out_size = out_size
56 |
57 | return grad_input
58 |
59 | @staticmethod
60 | def backward(ctx, gradgrad_input):
61 | kernel, = ctx.saved_tensors
62 |
63 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
64 |
65 | gradgrad_out = upfirdn2d_op.upfirdn2d(
66 | gradgrad_input,
67 | kernel,
68 | ctx.up_x,
69 | ctx.up_y,
70 | ctx.down_x,
71 | ctx.down_y,
72 | ctx.pad_x0,
73 | ctx.pad_x1,
74 | ctx.pad_y0,
75 | ctx.pad_y1,
76 | )
77 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3])
78 | gradgrad_out = gradgrad_out.view(
79 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
80 | )
81 |
82 | return gradgrad_out, None, None, None, None, None, None, None, None
83 |
84 |
85 | class UpFirDn2d(Function):
86 | @staticmethod
87 | def forward(ctx, input, kernel, up, down, pad):
88 | up_x, up_y = up
89 | down_x, down_y = down
90 | pad_x0, pad_x1, pad_y0, pad_y1 = pad
91 |
92 | kernel_h, kernel_w = kernel.shape
93 | batch, channel, in_h, in_w = input.shape
94 | ctx.in_size = input.shape
95 |
96 | input = input.reshape(-1, in_h, in_w, 1)
97 |
98 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
99 |
100 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
101 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
102 | ctx.out_size = (out_h, out_w)
103 |
104 | ctx.up = (up_x, up_y)
105 | ctx.down = (down_x, down_y)
106 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
107 |
108 | g_pad_x0 = kernel_w - pad_x0 - 1
109 | g_pad_y0 = kernel_h - pad_y0 - 1
110 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
111 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
112 |
113 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
114 |
115 | out = upfirdn2d_op.upfirdn2d(
116 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
117 | )
118 | # out = out.view(major, out_h, out_w, minor)
119 | out = out.view(-1, channel, out_h, out_w)
120 |
121 | return out
122 |
123 | @staticmethod
124 | def backward(ctx, grad_output):
125 | kernel, grad_kernel = ctx.saved_tensors
126 |
127 | grad_input = UpFirDn2dBackward.apply(
128 | grad_output,
129 | kernel,
130 | grad_kernel,
131 | ctx.up,
132 | ctx.down,
133 | ctx.pad,
134 | ctx.g_pad,
135 | ctx.in_size,
136 | ctx.out_size,
137 | )
138 |
139 | return grad_input, None, None, None, None
140 |
141 |
142 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
143 | out = UpFirDn2d.apply(
144 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
145 | )
146 |
147 | return out
148 |
149 |
150 | def upfirdn2d_native(
151 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
152 | ):
153 | _, in_h, in_w, minor = input.shape
154 | kernel_h, kernel_w = kernel.shape
155 |
156 | out = input.view(-1, in_h, 1, in_w, 1, minor)
157 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
158 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
159 |
160 | out = F.pad(
161 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
162 | )
163 | out = out[
164 | :,
165 | max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0),
166 | max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0),
167 | :,
168 | ]
169 |
170 | out = out.permute(0, 3, 1, 2)
171 | out = out.reshape(
172 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
173 | )
174 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
175 | out = F.conv2d(out, w)
176 | out = out.reshape(
177 | -1,
178 | minor,
179 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
180 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
181 | )
182 | out = out.permute(0, 2, 3, 1)
183 |
184 | return out[:, ::down_y, ::down_x, :]
185 |
--------------------------------------------------------------------------------
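
upfirdn2d is the classic upsample / pad / FIR-filter / downsample primitive; with up = down = 1 it reduces to a padded correlation with the flipped kernel. A small sanity sketch against the pure-PyTorch fallback above (CPU tensors and a 2x2 box kernel, chosen for illustration; importing the module builds the CUDA extension, and the compiled path should produce the same result on GPU):

    import torch
    from models.stylegan2.op.upfirdn2d import upfirdn2d_native

    inp = torch.arange(16.).view(1, 4, 4, 1)   # (batch*channel, h, w, minor) layout used by the native path
    k = torch.ones(2, 2) / 4.0                 # 2x2 box (average) filter
    out = upfirdn2d_native(inp, k, up_x=1, up_y=1, down_x=1, down_y=1,
                           pad_x0=0, pad_x1=1, pad_y0=0, pad_y1=1)
    print(out.shape)                           # torch.Size([1, 4, 4, 1]); asymmetric padding keeps the size
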
/models/stylegan2/op/upfirdn2d_kernel.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | //
3 | // This work is made available under the Nvidia Source Code License-NC.
4 | // To view a copy of this license, visit
5 | // https://nvlabs.github.io/stylegan2/license.html
6 |
7 | #include <torch/types.h>
8 | 
9 | #include <ATen/ATen.h>
10 | #include <ATen/AccumulateType.h>
11 | #include <ATen/cuda/CUDAContext.h>
12 | #include <ATen/cuda/CUDAApplyUtils.cuh>
13 | 
14 | #include <cuda.h>
15 | #include <cuda_runtime.h>
16 |
17 |
18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) {
19 | int c = a / b;
20 |
21 | if (c * b > a) {
22 | c--;
23 | }
24 |
25 | return c;
26 | }
27 |
28 |
29 | struct UpFirDn2DKernelParams {
30 | int up_x;
31 | int up_y;
32 | int down_x;
33 | int down_y;
34 | int pad_x0;
35 | int pad_x1;
36 | int pad_y0;
37 | int pad_y1;
38 |
39 | int major_dim;
40 | int in_h;
41 | int in_w;
42 | int minor_dim;
43 | int kernel_h;
44 | int kernel_w;
45 | int out_h;
46 | int out_w;
47 | int loop_major;
48 | int loop_x;
49 | };
50 |
51 |
52 | template <typename scalar_t, int up_x, int up_y, int down_x, int down_y, int kernel_h, int kernel_w, int tile_out_h, int tile_out_w>
53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) {
54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1;
55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1;
56 |
57 | __shared__ volatile float sk[kernel_h][kernel_w];
58 | __shared__ volatile float sx[tile_in_h][tile_in_w];
59 |
60 | int minor_idx = blockIdx.x;
61 | int tile_out_y = minor_idx / p.minor_dim;
62 | minor_idx -= tile_out_y * p.minor_dim;
63 | tile_out_y *= tile_out_h;
64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w;
65 | int major_idx_base = blockIdx.z * p.loop_major;
66 |
67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) {
68 | return;
69 | }
70 |
71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) {
72 | int ky = tap_idx / kernel_w;
73 | int kx = tap_idx - ky * kernel_w;
74 | scalar_t v = 0.0;
75 |
76 | if (kx < p.kernel_w & ky < p.kernel_h) {
77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)];
78 | }
79 |
80 | sk[ky][kx] = v;
81 | }
82 |
83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) {
84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) {
85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0;
86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0;
87 | int tile_in_x = floor_div(tile_mid_x, up_x);
88 | int tile_in_y = floor_div(tile_mid_y, up_y);
89 |
90 | __syncthreads();
91 |
92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) {
93 | int rel_in_y = in_idx / tile_in_w;
94 | int rel_in_x = in_idx - rel_in_y * tile_in_w;
95 | int in_x = rel_in_x + tile_in_x;
96 | int in_y = rel_in_y + tile_in_y;
97 |
98 | scalar_t v = 0.0;
99 |
100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) {
101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx];
102 | }
103 |
104 | sx[rel_in_y][rel_in_x] = v;
105 | }
106 |
107 | __syncthreads();
108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) {
109 | int rel_out_y = out_idx / tile_out_w;
110 | int rel_out_x = out_idx - rel_out_y * tile_out_w;
111 | int out_x = rel_out_x + tile_out_x;
112 | int out_y = rel_out_y + tile_out_y;
113 |
114 | int mid_x = tile_mid_x + rel_out_x * down_x;
115 | int mid_y = tile_mid_y + rel_out_y * down_y;
116 | int in_x = floor_div(mid_x, up_x);
117 | int in_y = floor_div(mid_y, up_y);
118 | int rel_in_x = in_x - tile_in_x;
119 | int rel_in_y = in_y - tile_in_y;
120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1;
121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1;
122 |
123 | scalar_t v = 0.0;
124 |
125 | #pragma unroll
126 | for (int y = 0; y < kernel_h / up_y; y++)
127 | #pragma unroll
128 | for (int x = 0; x < kernel_w / up_x; x++)
129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x];
130 |
131 | if (out_x < p.out_w & out_y < p.out_h) {
132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v;
133 | }
134 | }
135 | }
136 | }
137 | }
138 |
139 |
140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel,
141 | int up_x, int up_y, int down_x, int down_y,
142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) {
143 | int curDevice = -1;
144 | cudaGetDevice(&curDevice);
145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice);
146 |
147 | UpFirDn2DKernelParams p;
148 |
149 | auto x = input.contiguous();
150 | auto k = kernel.contiguous();
151 |
152 | p.major_dim = x.size(0);
153 | p.in_h = x.size(1);
154 | p.in_w = x.size(2);
155 | p.minor_dim = x.size(3);
156 | p.kernel_h = k.size(0);
157 | p.kernel_w = k.size(1);
158 | p.up_x = up_x;
159 | p.up_y = up_y;
160 | p.down_x = down_x;
161 | p.down_y = down_y;
162 | p.pad_x0 = pad_x0;
163 | p.pad_x1 = pad_x1;
164 | p.pad_y0 = pad_y0;
165 | p.pad_y1 = pad_y1;
166 |
167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y;
168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x;
169 |
170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options());
171 |
172 | int mode = -1;
173 |
174 | int tile_out_h;
175 | int tile_out_w;
176 |
177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
178 | mode = 1;
179 | tile_out_h = 16;
180 | tile_out_w = 64;
181 | }
182 |
183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) {
184 | mode = 2;
185 | tile_out_h = 16;
186 | tile_out_w = 64;
187 | }
188 |
189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) {
190 | mode = 3;
191 | tile_out_h = 16;
192 | tile_out_w = 64;
193 | }
194 |
195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) {
196 | mode = 4;
197 | tile_out_h = 16;
198 | tile_out_w = 64;
199 | }
200 |
201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) {
202 | mode = 5;
203 | tile_out_h = 8;
204 | tile_out_w = 32;
205 | }
206 |
207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) {
208 | mode = 6;
209 | tile_out_h = 8;
210 | tile_out_w = 32;
211 | }
212 |
213 | dim3 block_size;
214 | dim3 grid_size;
215 |
216 | if (tile_out_h > 0 && tile_out_w > 0) {
217 | p.loop_major = (p.major_dim - 1) / 16384 + 1;
218 | p.loop_x = 1;
219 | block_size = dim3(32 * 8, 1, 1);
220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim,
221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1,
222 | (p.major_dim - 1) / p.loop_major + 1);
223 | }
224 |
225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] {
226 | switch (mode) {
227 | case 1:
228 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
229 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
230 | );
231 |
232 | break;
233 |
234 | case 2:
235 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>(
236 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
237 | );
238 |
239 | break;
240 |
241 | case 3:
242 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>(
243 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
244 | );
245 |
246 | break;
247 |
248 | case 4:
249 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>(
250 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
251 | );
252 |
253 | break;
254 |
255 | case 5:
256 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>(
257 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
258 | );
259 |
260 | break;
261 |
262 | case 6:
263 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>(
264 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p
265 | );
266 |
267 | break;
268 | }
269 | });
270 |
271 | return out;
272 | }
--------------------------------------------------------------------------------
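The CUDA path sizes its output with integer division, `(in*up + pad0 + pad1 - kernel + down) / down`, whereas `upfirdn2d.py` uses `(in*up + pad0 + pad1 - kernel) // down + 1`. For the non-negative numerators that occur here the two agree; a throwaway check:

```python
# Throwaway check (not repo code): both output-size formulas coincide when the
# padded, upsampled extent is at least as large as the kernel (numerator >= 0).
for n in range(0, 64):        # n = in*up + pad0 + pad1 - kernel
    for d in (1, 2, 4):       # down factor
        assert (n + d) // d == n // d + 1
```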
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/options/__init__.py
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 | class TrainOptions:
4 |
5 | def __init__(self):
6 | self.parser = ArgumentParser()
7 | self.initialize()
8 |
9 | def initialize(self):
10 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory')
11 |
12 | self.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for training')
13 | self.parser.add_argument('--learning_rate', default=0.001, type=float, help='Optimizer learning rate')
14 | self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use')
15 | self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model')
16 |
17 | self.parser.add_argument('--lpips_lambda', default=0., type=float, help='LPIPS loss multiplier factor')
18 | self.parser.add_argument('--l2_lambda', default=0, type=float, help='L2 loss multiplier factor')
19 | self.parser.add_argument('--l2latent_lambda', default=1.0, type=float, help='Latent-space L2 loss multiplier factor')
20 |
21 | self.parser.add_argument('--stylegan_weights', default='pretrained_models/stylegan2-cat-config-f.pt', type=str, help='Path to StyleGAN model weights')
22 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint')
23 |
24 | self.parser.add_argument('--max_steps', default=60100, type=int, help='Maximum number of training steps')
25 | self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training')
26 | self.parser.add_argument('--save_interval', default=10000, type=int, help='Model checkpoint interval')
27 |
28 | self.parser.add_argument('--style_num', default=14, type=int, help='Number of StyleGAN layers that receive latent codes')
29 | self.parser.add_argument('--channel_multiplier', default=2, type=int, help='StyleGAN parameter')
30 |
31 | def parse(self):
32 | opts = self.parser.parse_args()
33 | return opts
--------------------------------------------------------------------------------
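`scripts/train.py` below simply parses these options and hands them to the Coach. Note that `--train_decoder` uses argparse's `type=bool`, so any non-empty string (including `False`) parses as `True`; leave the flag off to keep the default. A sketch of driving the parser programmatically, with a hypothetical experiment directory:

```python
# Illustrative only: parse TrainOptions with an explicit argv instead of the shell.
# 'experiments/example_run' is a hypothetical output directory, not part of the repo.
import sys
from options.train_options import TrainOptions

sys.argv = [
    'train.py',
    '--exp_dir', 'experiments/example_run',
    '--stylegan_weights', 'pretrained_models/stylegan2-cat-config-f.pt',
    '--max_steps', '60100',
]
opts = TrainOptions().parse()
print(opts.exp_dir, opts.learning_rate, opts.optim_name)  # experiments/example_run 0.001 ranger
```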
/scripts/train.py:
--------------------------------------------------------------------------------
1 | """
2 | This file runs the main training loop
3 | """
4 | import os
5 | import json
6 | import sys
7 | import pprint
8 |
9 | sys.path.append(".")
10 | sys.path.append("..")
11 |
12 | from options.train_options import TrainOptions
13 | from training.coach import Coach
14 |
15 |
16 | def main():
17 | opts = TrainOptions().parse()
18 | os.makedirs(opts.exp_dir, exist_ok=True)
19 |
20 | opts_dict = vars(opts)
21 | pprint.pprint(opts_dict)
22 | with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f:
23 | json.dump(opts_dict, f, indent=4, sort_keys=True)
24 |
25 | coach = Coach(opts)
26 | coach.train()
27 |
28 | if __name__ == '__main__':
29 | main()
30 |
--------------------------------------------------------------------------------
/training/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/training/__init__.py
--------------------------------------------------------------------------------
/training/coach.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math, random
3 | import numpy as np
4 | import matplotlib
5 | import matplotlib.pyplot as plt
6 | matplotlib.use('Agg')
7 |
8 | import torch
9 | from torch import nn
10 | from torch.utils.tensorboard import SummaryWriter
11 | import torch.nn.functional as F
12 |
13 | from utils import common
14 | from criteria.lpips.lpips import LPIPS
15 | from models.StyleGANControler import StyleGANControler
16 | from training.ranger import Ranger
17 |
18 | from expansion.submission import Expansion
19 | from expansion.utils.flowlib import point_vec
20 |
21 | class Coach:
22 | def __init__(self, opts):
23 | self.opts = opts
24 | if self.opts.checkpoint_path is None:
25 | self.global_step = 0
26 | else:
27 | self.global_step = int(os.path.splitext(os.path.basename(self.opts.checkpoint_path))[0].split('_')[-1])
28 |
29 | self.device = 'cuda:0' # TODO: Allow multiple GPU? currently using CUDA_VISIBLE_DEVICES
30 | self.opts.device = self.device
31 |
32 | # Initialize network
33 | self.net = StyleGANControler(self.opts).to(self.device)
34 |
35 | # Initialize loss
36 | if self.opts.lpips_lambda > 0:
37 | self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval()
38 | self.mse_loss = nn.MSELoss().to(self.device).eval()
39 |
40 | # Initialize optimizer
41 | self.optimizer = self.configure_optimizers()
42 |
43 | # Initialize logger
44 | log_dir = os.path.join(opts.exp_dir, 'logs')
45 | os.makedirs(log_dir, exist_ok=True)
46 | self.logger = SummaryWriter(log_dir=log_dir)
47 |
48 | # Initialize checkpoint dir
49 | self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints')
50 | os.makedirs(self.checkpoint_dir, exist_ok=True)
51 | self.best_val_loss = None
52 | if self.opts.save_interval is None:
53 | self.opts.save_interval = self.opts.max_steps
54 |
55 | # Initialize optical flow estimator
56 | self.ex = Expansion()
57 |
58 | # Set flow normalization values
59 | if 'ffhq' in self.opts.stylegan_weights:
60 | self.sigma_f = 4
61 | self.sigma_e = 0.02
62 | elif 'car' in self.opts.stylegan_weights:
63 | self.sigma_f = 5
64 | self.sigma_e = 0.03
65 | elif 'cat' in self.opts.stylegan_weights:
66 | self.sigma_f = 12
67 | self.sigma_e = 0.04
68 | elif 'church' in self.opts.stylegan_weights:
69 | self.sigma_f = 8
70 | self.sigma_e = 0.02
71 | elif 'anime' in self.opts.stylegan_weights:
72 | self.sigma_f = 7
73 | self.sigma_e = 0.025
74 |
75 | def train(self, truncation = 0.3, sigma = 0.1, target_layers = [0,1,2,3,4,5]):
76 |
77 | x = np.array(range(0,256,16)).astype(np.float32)/127.5-1.
78 | y = np.array(range(0,256,16)).astype(np.float32)/127.5-1.
79 | xx, yy = np.meshgrid(x,y)
80 | grid = np.concatenate([xx[:,:,None],yy[:,:,None]], axis=2)
81 | grid = torch.from_numpy(grid[None,:]).cuda()
82 | grid = grid.repeat(self.opts.batch_size,1,1,1)
83 |
84 | while self.global_step < self.opts.max_steps:
85 | with torch.no_grad():
86 | z1 = torch.randn(self.opts.batch_size,512).to("cuda")
87 | z2 = torch.randn(self.opts.batch_size,self.net.style_num, 512).to("cuda")
88 |
89 | x1, w1, f1 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_feature_map=True,return_latents=True,truncation=truncation, truncation_latent=self.net.latent_avg[0])
90 | x1 = self.net.face_pool(x1)
91 | x2, w2 = self.net.decoder([z2],input_is_latent=False,randomize_noise=False,return_latents=True, truncation_latent=self.net.latent_avg[0])
92 | x2 = self.net.face_pool(x2)
93 | w_mid = w1.clone()
94 | w_mid[:,target_layers] = w_mid[:,target_layers]+sigma*(w2[:,target_layers]-w_mid[:,target_layers])
95 | x_mid, _ = self.net.decoder([w_mid], input_is_latent=True, randomize_noise=False, return_latents=False)
96 | x_mid = self.net.face_pool(x_mid)
97 |
98 | flow, logexp = self.ex.run(x1.detach(),x_mid.detach())
99 | flow_feature = torch.cat([flow/self.sigma_f, logexp/self.sigma_e], dim=1)
100 | f1 = F.interpolate(f1, (flow_feature.shape[2:]))
101 | f1 = F.grid_sample(f1, grid, mode='nearest', align_corners=True)
102 | flow_feature = F.grid_sample(flow_feature, grid, mode='nearest', align_corners=True)
103 | flow_feature = flow_feature.view(flow_feature.shape[0], flow_feature.shape[1], -1).permute(0,2,1)
104 | f1 = f1.view(f1.shape[0], f1.shape[1], -1).permute(0,2,1)
105 |
106 | self.net.train()
107 | self.optimizer.zero_grad()
108 | w_hat = self.net.encoder(w1[:,target_layers].detach(), flow_feature.detach(), f1.detach())
109 | loss, loss_dict, id_logs = self.calc_loss(w_hat, w_mid[:,target_layers].detach())
110 | loss.backward()
111 | self.optimizer.step()
112 |
113 | w_mid[:,target_layers] = w_hat.detach()
114 | x_hat, _ = self.net.decoder([w_mid], input_is_latent=True, randomize_noise=False)
115 | x_hat = self.net.face_pool(x_hat)
116 | if self.global_step % self.opts.image_interval == 0 or (
117 | self.global_step < 1000 and self.global_step % 100 == 0):
118 | imgL_o = ((x1.detach()+1.)*127.5)[0].permute(1,2,0).cpu().numpy()
119 | flow = torch.cat((flow,torch.ones_like(flow)[:,:1]), dim=1)[0].permute(1,2,0).cpu().numpy()
120 | flowvis = point_vec(imgL_o, flow)
121 | flowvis = torch.from_numpy(flowvis[:,:,::-1].copy()).permute(2,0,1).unsqueeze(0)/127.5-1.
122 | self.parse_and_log_images(None, flowvis, x_mid, x_hat, title='trained_images')
123 | print(loss_dict)
124 |
125 | if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps:
126 | self.checkpoint_me(loss_dict, is_best=False)
127 |
128 | if self.global_step == self.opts.max_steps:
129 | print('OMG, finished training!')
130 | break
131 |
132 | self.global_step += 1
133 |
134 | def checkpoint_me(self, loss_dict, is_best):
135 | save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step)
136 | save_dict = self.__get_save_dict()
137 | checkpoint_path = os.path.join(self.checkpoint_dir, save_name)
138 | torch.save(save_dict, checkpoint_path)
139 | with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f:
140 | if is_best:
141 | f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict))
142 | else:
143 | f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict))
144 |
145 | def configure_optimizers(self):
146 | params = list(self.net.encoder.parameters())
147 | if self.opts.train_decoder:
148 | params += list(self.net.decoder.parameters())
149 | if self.opts.optim_name == 'adam':
150 | optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate)
151 | else:
152 | optimizer = Ranger(params, lr=self.opts.learning_rate)
153 | return optimizer
154 |
155 | def calc_loss(self, latent, w, y_hat=None, y=None):
156 | loss_dict = {}
157 | loss = 0.0
158 | id_logs = None
159 |
160 | if self.opts.l2_lambda > 0 and (y_hat is not None) and (y is not None):
161 | loss_l2 = F.mse_loss(y_hat, y)
162 | loss_dict['loss_l2'] = float(loss_l2)
163 | loss += loss_l2 * self.opts.l2_lambda
164 | if self.opts.lpips_lambda > 0 and (y_hat is not None) and (y is not None):
165 | loss_lpips = self.lpips_loss(y_hat, y)
166 | loss_dict['loss_lpips'] = float(loss_lpips)
167 | loss += loss_lpips * self.opts.lpips_lambda
168 | if self.opts.l2latent_lambda > 0:
169 | loss_l2 = F.mse_loss(latent, w)
170 | loss_dict['loss_l2latent'] = float(loss_l2)
171 | loss += loss_l2 * self.opts.l2latent_lambda
172 |
173 | loss_dict['loss'] = float(loss)
174 | return loss, loss_dict, id_logs
175 |
176 | def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=1):
177 | im_data = []
178 | for i in range(display_count):
179 | cur_im_data = {
180 | 'input_face': common.tensor2im(x[i]),
181 | 'target_face': common.tensor2im(y[i]),
182 | 'output_face': common.tensor2im(y_hat[i]),
183 | }
184 | if id_logs is not None:
185 | for key in id_logs[i]:
186 | cur_im_data[key] = id_logs[i][key]
187 | im_data.append(cur_im_data)
188 | self.log_images(title, im_data=im_data, subscript=subscript)
189 |
190 |
191 | def log_images(self, name, im_data, subscript=None, log_latest=False):
192 | fig = common.vis_faces(im_data)
193 | step = self.global_step
194 | if log_latest:
195 | step = 0
196 | if subscript:
197 | path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step))
198 | else:
199 | path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step))
200 | os.makedirs(os.path.dirname(path), exist_ok=True)
201 | fig.savefig(path)
202 | plt.close(fig)
203 |
204 | def __get_save_dict(self):
205 | save_dict = {
206 | 'state_dict': self.net.state_dict(),
207 | 'opts': vars(self.opts)
208 | }
209 |
210 | save_dict['latent_avg'] = self.net.latent_avg
211 | return save_dict
--------------------------------------------------------------------------------
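The training target constructed in `Coach.train` above is an interpolated latent: a second latent `w2` perturbs `w1` only on the target layers by a small factor `sigma`, the optical-flow estimator describes the resulting image motion, and the encoder is trained to recover the perturbed latents from `(w1, flow features, feature map)` via the `l2latent` term. A minimal sketch of just the target construction, with random stand-ins for the latents:

```python
# Minimal sketch (illustrative, not repo code) of the target construction in Coach.train.
import torch

sigma, target_layers = 0.1, [0, 1, 2, 3, 4, 5]
w1 = torch.randn(1, 14, 512)   # latents of the source image (style_num=14 layers)
w2 = torch.randn(1, 14, 512)   # latents of an unrelated random sample
w_mid = w1.clone()
w_mid[:, target_layers] += sigma * (w2[:, target_layers] - w1[:, target_layers])
# The encoder prediction w_hat = net.encoder(w1[:, target_layers], flow_feature, f1)
# is then pulled toward w_mid[:, target_layers] by the l2latent loss in calc_loss.
```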
/training/ranger.py:
--------------------------------------------------------------------------------
1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer.
2 |
3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
4 | # and/or
5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers
6 |
7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard.
8 |
9 | # This version = 20.4.11
10 |
11 | # Credits:
12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization
13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam
14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code.
15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610
16 |
17 | # summary of changes:
18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init.
19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights),
20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues.
21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold;
22 | # changed eps to 1e-5 as better default than 1e-8.
23 |
24 | import math
25 | import torch
26 | from torch.optim.optimizer import Optimizer
27 |
28 |
29 | class Ranger(Optimizer):
30 |
31 | def __init__(self, params, lr=1e-3, # lr
32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options
33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options
34 | use_gc=True, gc_conv_only=False
35 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers
36 | ):
37 |
38 | # parameter checks
39 | if not 0.0 <= alpha <= 1.0:
40 | raise ValueError(f'Invalid slow update rate: {alpha}')
41 | if not 1 <= k:
42 | raise ValueError(f'Invalid lookahead steps: {k}')
43 | if not lr > 0:
44 | raise ValueError(f'Invalid Learning Rate: {lr}')
45 | if not eps > 0:
46 | raise ValueError(f'Invalid eps: {eps}')
47 |
48 | # parameter comments:
49 | # beta1 (momentum) of .95 seems to work better than .90...
50 | # N_sma_threshold of 5 seems better in testing than 4.
51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
52 |
53 | # prep defaults and init torch.optim base
54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
55 | eps=eps, weight_decay=weight_decay)
56 | super().__init__(params, defaults)
57 |
58 | # adjustable threshold
59 | self.N_sma_threshhold = N_sma_threshhold
60 |
61 | # look ahead params
62 |
63 | self.alpha = alpha
64 | self.k = k
65 |
66 | # radam buffer for state
67 | self.radam_buffer = [[None, None, None] for ind in range(10)]
68 |
69 | # gc on or off
70 | self.use_gc = use_gc
71 |
72 | # level of gradient centralization
73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1
74 |
75 | def __setstate__(self, state):
76 | super(Ranger, self).__setstate__(state)
77 |
78 | def step(self, closure=None):
79 | loss = None
80 |
81 | # Evaluate averages and grad, update param tensors
82 | for group in self.param_groups:
83 |
84 | for p in group['params']:
85 | if p.grad is None:
86 | continue
87 | grad = p.grad.data.float()
88 |
89 | if grad.is_sparse:
90 | raise RuntimeError('Ranger optimizer does not support sparse gradients')
91 |
92 | p_data_fp32 = p.data.float()
93 |
94 | state = self.state[p] # get state dict for this param
95 |
96 | if len(state) == 0: # if first time to run...init dictionary with our desired entries
97 | # if self.first_run_check==0:
98 | # self.first_run_check=1
99 | # print("Initializing slow buffer...should not see this at load from saved model!")
100 | state['step'] = 0
101 | state['exp_avg'] = torch.zeros_like(p_data_fp32)
102 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
103 |
104 | # look ahead weight storage now in state dict
105 | state['slow_buffer'] = torch.empty_like(p.data)
106 | state['slow_buffer'].copy_(p.data)
107 |
108 | else:
109 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
110 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
111 |
112 | # begin computations
113 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
114 | beta1, beta2 = group['betas']
115 |
116 | # GC operation for Conv layers and FC layers
117 | if grad.dim() > self.gc_gradient_threshold:
118 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True))
119 |
120 | state['step'] += 1
121 |
122 | # compute variance mov avg
123 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
124 | # compute mean moving avg
125 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
126 |
127 | buffered = self.radam_buffer[int(state['step'] % 10)]
128 |
129 | if state['step'] == buffered[0]:
130 | N_sma, step_size = buffered[1], buffered[2]
131 | else:
132 | buffered[0] = state['step']
133 | beta2_t = beta2 ** state['step']
134 | N_sma_max = 2 / (1 - beta2) - 1
135 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
136 | buffered[1] = N_sma
137 | if N_sma > self.N_sma_threshhold:
138 | step_size = math.sqrt(
139 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
140 | N_sma_max - 2)) / (1 - beta1 ** state['step'])
141 | else:
142 | step_size = 1.0 / (1 - beta1 ** state['step'])
143 | buffered[2] = step_size
144 |
145 | if group['weight_decay'] != 0:
146 | p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
147 |
148 | # apply lr
149 | if N_sma > self.N_sma_threshhold:
150 | denom = exp_avg_sq.sqrt().add_(group['eps'])
151 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr'])
152 | else:
153 | p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr'])
154 |
155 | p.data.copy_(p_data_fp32)
156 |
157 | # integrated look ahead...
158 | # we do it at the param level instead of group level
159 | if state['step'] % group['k'] == 0:
160 | slow_p = state['slow_buffer'] # get access to slow param tensor
161 | slow_p.add_(p.data - slow_p, alpha=self.alpha) # (fast weights - slow weights) * alpha
162 | p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor
163 |
164 | return loss
--------------------------------------------------------------------------------
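Ranger is a drop-in `torch.optim.Optimizer`; the Coach selects it whenever `--optim_name` is not `adam`. A small self-contained usage sketch on a toy model:

```python
# Toy usage sketch of the Ranger optimizer defined above.
import torch
import torch.nn.functional as F
from training.ranger import Ranger

model = torch.nn.Linear(4, 1)
opt = Ranger(model.parameters(), lr=1e-3, alpha=0.5, k=6)

x, y = torch.randn(32, 4), torch.randn(32, 1)
for _ in range(20):
    opt.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    opt.step()   # RAdam-style update; slow weights synced every k=6 steps
```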
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/utils/__init__.py
--------------------------------------------------------------------------------
/utils/common.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from PIL import Image
4 | import matplotlib.pyplot as plt
5 | import random
6 |
7 |
8 | # Log images
9 | def log_input_image(x, opts):
10 | if opts.label_nc == 0:
11 | return tensor2im(x)
12 | elif opts.label_nc == 1:
13 | return tensor2sketch(x)
14 | else:
15 | return tensor2map(x)
16 |
17 |
18 | def tensor2im(var):
19 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy()
20 | var = ((var + 1) / 2)
21 | var[var < 0] = 0
22 | var[var > 1] = 1
23 | var = var * 255
24 | return Image.fromarray(var.astype('uint8'))
25 |
26 |
27 | def tensor2map(var):
28 | mask = np.argmax(var.data.cpu().numpy(), axis=0)
29 | colors = get_colors()
30 | mask_image = np.ones(shape=(mask.shape[0], mask.shape[1], 3))
31 | for class_idx in np.unique(mask):
32 | mask_image[mask == class_idx] = colors[class_idx]
33 | mask_image = mask_image.astype('uint8')
34 | return Image.fromarray(mask_image)
35 |
36 |
37 | def tensor2sketch(var):
38 | im = var[0].cpu().detach().numpy()
39 | im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
40 | im = (im * 255).astype(np.uint8)
41 | return Image.fromarray(im)
42 |
43 |
44 | # Visualization utils
45 | def get_colors():
46 | # currently support up to 19 classes (for the celebs-hq-mask dataset)
47 | colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255],
48 | [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204],
49 | [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
50 |
51 | # assign random colors for up to 200 additional classes
52 | random.seed(0)
53 | for i in range(200):
54 | colors.append([random.randint(0,255),random.randint(0,255),random.randint(0,255)])
55 | return colors
56 |
57 |
58 | def vis_faces(log_hooks):
59 | display_count = len(log_hooks)
60 | fig = plt.figure(figsize=(8, 4 * display_count))
61 | gs = fig.add_gridspec(display_count, 3)
62 | for i in range(display_count):
63 | hooks_dict = log_hooks[i]
64 | fig.add_subplot(gs[i, 0])
65 | if 'diff_input' in hooks_dict:
66 | vis_faces_with_id(hooks_dict, fig, gs, i)
67 | else:
68 | vis_faces_no_id(hooks_dict, fig, gs, i)
69 | plt.tight_layout()
70 | return fig
71 |
72 |
73 | def vis_faces_with_id(hooks_dict, fig, gs, i):
74 | plt.imshow(hooks_dict['input_face'])
75 | plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input'])))
76 | fig.add_subplot(gs[i, 1])
77 | plt.imshow(hooks_dict['target_face'])
78 | plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']),
79 | float(hooks_dict['diff_target'])))
80 | fig.add_subplot(gs[i, 2])
81 | plt.imshow(hooks_dict['output_face'])
82 | plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target'])))
83 |
84 |
85 | def vis_faces_no_id(hooks_dict, fig, gs, i):
86 | plt.imshow(hooks_dict['input_face'], cmap="gray")
87 | plt.title('Input')
88 | fig.add_subplot(gs[i, 1])
89 | plt.imshow(hooks_dict['target_face'])
90 | plt.title('Target')
91 | fig.add_subplot(gs[i, 2])
92 | plt.imshow(hooks_dict['output_face'])
93 | plt.title('Output')
94 |
--------------------------------------------------------------------------------
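`tensor2im` above expects a single (C, H, W) tensor in [-1, 1] and returns a PIL image; it is what `Coach.parse_and_log_images` relies on for logging. A one-line usage sketch with a random tensor:

```python
# Illustrative: convert a [-1, 1] CHW tensor to a PIL image with utils.common.tensor2im.
import torch
from utils.common import tensor2im

fake = torch.rand(3, 256, 256) * 2.0 - 1.0   # stand-in for a decoder output
im = tensor2im(fake)
im.save('example.jpg')                        # hypothetical output path
```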