├── LICENSE ├── README.md ├── criteria ├── __init__.py └── lpips │ ├── __init__.py │ ├── lpips.py │ ├── networks.py │ └── utils.py ├── docs ├── anime.gif ├── car.gif ├── church.gif ├── ffhq.gif ├── teaser.jpg └── thumb.gif ├── env.yaml ├── expansion ├── __init__.py ├── dataloader │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ └── seqlist.cpython-38.pyc │ ├── chairslist.py │ ├── chairssdlist.py │ ├── depth_transforms.py │ ├── depthloader.py │ ├── flow_transforms.py │ ├── hd1klist.py │ ├── kitti12list.py │ ├── kitti15list.py │ ├── kitti15list_train.py │ ├── kitti15list_train_lidar.py │ ├── kitti15list_val.py │ ├── kitti15list_val_lidar.py │ ├── kitti15list_val_mr.py │ ├── robloader.py │ ├── sceneflowlist.py │ ├── seqlist.py │ ├── sintellist.py │ ├── sintellist_clean.py │ ├── sintellist_final.py │ ├── sintellist_train.py │ ├── sintellist_val.py │ └── thingslist.py ├── models │ ├── VCN_exp.py │ ├── __init__.py │ ├── __pycache__ │ │ ├── VCN_exp.cpython-38.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── conv4d.cpython-38.pyc │ │ └── submodule.cpython-38.pyc │ ├── conv4d.py │ └── submodule.py ├── submission.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── flowlib.cpython-38.pyc │ ├── io.cpython-38.pyc │ ├── pfm.cpython-38.pyc │ └── util_flow.cpython-38.pyc │ ├── flowlib.py │ ├── io.py │ ├── logger.py │ ├── multiscaleloss.py │ ├── pfm.py │ ├── readpfm.py │ ├── sintel_io.py │ └── util_flow.py ├── interface ├── flask_app.py ├── inference.py └── templates │ └── index.html ├── licenses ├── LICENSE_ gengshan-y_expansion ├── LICENSE_HuangYG123 ├── LICENSE_S-aiueo32 ├── LICENSE_TreB1eN ├── LICENSE_lessw2020 ├── LICENSE_pixel2style2pixel └── LICENSE_rosinality ├── models ├── StyleGANControler.py ├── __init__.py ├── networks │ ├── __init__.py │ └── latent_transformer.py └── stylegan2 │ ├── __init__.py │ ├── model.py │ └── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── options ├── __init__.py └── train_options.py ├── scripts └── train.py ├── training ├── __init__.py ├── coach.py └── ranger.py └── utils ├── __init__.py └── common.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yuki Endo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # User-Controllable Latent Transformer for StyleGAN Image Layout Editing 2 | 3 | 4 |

5 | 6 |

7 | 8 | This repository contains our implementation of the following paper: 9 | 10 | Yuki Endo: "User-Controllable Latent Transformer for StyleGAN Image Layout Editing," Computer Graphics Forum (Pacific Graphics 2022) [[Project](http://www.cgg.cs.tsukuba.ac.jp/~endo/projects/UserControllableLT)] [[PDF (preprint)](http://arxiv.org/abs/2208.12408)] 11 | 12 | ## Prerequisites 13 | 1. Python 3.8 14 | 2. PyTorch 1.9.0 15 | 3. Flask 16 | 4. Others (see env.yaml) 17 | 18 | ## Preparation 19 | Download and decompress our pre-trained models. 20 | 21 | ## Inference with our pre-trained models 22 |
 23 | We provide an interactive interface based on Flask. This interface can be launched locally with 24 | ``` 25 | python interface/flask_app.py --checkpoint_path=pretrained_models/latent_transformer/cat.pt 26 | ``` 27 | The interface can be accessed via http://localhost:8000/. 28 | 29 | ## Training 30 | The latent transformer can be trained with 31 | ``` 32 | python scripts/train.py --exp_dir=results --stylegan_weights=pretrained_models/stylegan2-cat-config-f.pt 33 | ``` 34 | To perform training with your dataset, you first need to train StyleGAN2 on your dataset using [rosinality's code](https://github.com/rosinality/stylegan2-pytorch) and then run the above script, specifying the trained weights. 35 | 36 | ## Link 37 | [Gradio demo](https://huggingface.co/spaces/radames/UserControllableLT-Latent-Transformer) by Radamés Ajna 38 | 39 | ## Citation 40 | Please cite our paper if you find the code useful: 41 | ``` 42 | @Article{endoPG2022, 43 | Title = {User-Controllable Latent Transformer for StyleGAN Image Layout Editing}, 44 | Author = {Yuki Endo}, 45 | Journal = {Computer Graphics Forum}, 46 | volume = {41}, 47 | number = {7}, 48 | pages = {395-406}, 49 | doi = {10.1111/cgf.14686}, 50 | Year = {2022} 51 | } 52 | ``` 53 | 54 | ## Acknowledgements 55 | This code heavily borrows from the [pixel2style2pixel](https://github.com/eladrich/pixel2style2pixel) and [expansion](https://github.com/gengshan-y/expansion) repositories. 56 | -------------------------------------------------------------------------------- /criteria/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/criteria/__init__.py -------------------------------------------------------------------------------- /criteria/lpips/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/criteria/lpips/__init__.py -------------------------------------------------------------------------------- /criteria/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from criteria.lpips.networks import get_network, LinLayers 5 | from criteria.lpips.utils import get_state_dict 6 | 7 | 8 | class LPIPS(nn.Module): 9 | r"""Creates a criterion that measures 10 | Learned Perceptual Image Patch Similarity (LPIPS). 11 | Arguments: 12 | net_type (str): the network type to compare the features: 13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'. 14 | version (str): the version of LPIPS. Default: 0.1.
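        Example (a minimal usage sketch; it assumes the inputs are two RGB
            batches scaled to [-1, 1] and placed on the GPU, since the
            networks are moved to "cuda" in __init__):
            >>> lpips = LPIPS(net_type='alex')
            >>> x = torch.rand(4, 3, 256, 256, device='cuda') * 2 - 1
            >>> y = torch.rand(4, 3, 256, 256, device='cuda') * 2 - 1
            >>> dist = lpips(x, y)  # 0-dim tensor: layer-summed LPIPS averaged over the batch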
15 | """ 16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'): 17 | 18 | assert version in ['0.1'], 'v0.1 is only supported now' 19 | 20 | super(LPIPS, self).__init__() 21 | 22 | # pretrained network 23 | self.net = get_network(net_type).to("cuda") 24 | 25 | # linear layers 26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda") 27 | self.lin.load_state_dict(get_state_dict(net_type, version)) 28 | 29 | def forward(self, x: torch.Tensor, y: torch.Tensor): 30 | feat_x, feat_y = self.net(x), self.net(y) 31 | 32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)] 33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)] 34 | 35 | return torch.sum(torch.cat(res, 0)) / x.shape[0] 36 | -------------------------------------------------------------------------------- /criteria/lpips/networks.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | from itertools import chain 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torchvision import models 8 | 9 | from criteria.lpips.utils import normalize_activation 10 | 11 | 12 | def get_network(net_type: str): 13 | if net_type == 'alex': 14 | return AlexNet() 15 | elif net_type == 'squeeze': 16 | return SqueezeNet() 17 | elif net_type == 'vgg': 18 | return VGG16() 19 | else: 20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].') 21 | 22 | 23 | class LinLayers(nn.ModuleList): 24 | def __init__(self, n_channels_list: Sequence[int]): 25 | super(LinLayers, self).__init__([ 26 | nn.Sequential( 27 | nn.Identity(), 28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False) 29 | ) for nc in n_channels_list 30 | ]) 31 | 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | class BaseNet(nn.Module): 37 | def __init__(self): 38 | super(BaseNet, self).__init__() 39 | 40 | # register buffer 41 | self.register_buffer( 42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 43 | self.register_buffer( 44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None]) 45 | 46 | def set_requires_grad(self, state: bool): 47 | for param in chain(self.parameters(), self.buffers()): 48 | param.requires_grad = state 49 | 50 | def z_score(self, x: torch.Tensor): 51 | return (x - self.mean) / self.std 52 | 53 | def forward(self, x: torch.Tensor): 54 | x = self.z_score(x) 55 | 56 | output = [] 57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1): 58 | x = layer(x) 59 | if i in self.target_layers: 60 | output.append(normalize_activation(x)) 61 | if len(output) == len(self.target_layers): 62 | break 63 | return output 64 | 65 | 66 | class SqueezeNet(BaseNet): 67 | def __init__(self): 68 | super(SqueezeNet, self).__init__() 69 | 70 | self.layers = models.squeezenet1_1(True).features 71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13] 72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512] 73 | 74 | self.set_requires_grad(False) 75 | 76 | 77 | class AlexNet(BaseNet): 78 | def __init__(self): 79 | super(AlexNet, self).__init__() 80 | 81 | self.layers = models.alexnet(True).features 82 | self.target_layers = [2, 5, 8, 10, 12] 83 | self.n_channels_list = [64, 192, 384, 256, 256] 84 | 85 | self.set_requires_grad(False) 86 | 87 | 88 | class VGG16(BaseNet): 89 | def __init__(self): 90 | super(VGG16, self).__init__() 91 | 92 | self.layers = models.vgg16(True).features 93 | self.target_layers = [4, 9, 16, 23, 30] 94 | self.n_channels_list = [64, 128, 256, 512, 512] 95 | 96 | 
self.set_requires_grad(False) -------------------------------------------------------------------------------- /criteria/lpips/utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | 5 | 6 | def normalize_activation(x, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True)) 8 | return x / (norm_factor + eps) 9 | 10 | 11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'): 12 | # build url 13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \ 14 | + f'master/lpips/weights/v{version}/{net_type}.pth' 15 | 16 | # download 17 | old_state_dict = torch.hub.load_state_dict_from_url( 18 | url, progress=True, 19 | map_location=None if torch.cuda.is_available() else torch.device('cpu') 20 | ) 21 | 22 | # rename keys 23 | new_state_dict = OrderedDict() 24 | for key, val in old_state_dict.items(): 25 | new_key = key 26 | new_key = new_key.replace('lin', '') 27 | new_key = new_key.replace('model.', '') 28 | new_state_dict[new_key] = val 29 | 30 | return new_state_dict 31 | -------------------------------------------------------------------------------- /docs/anime.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/anime.gif -------------------------------------------------------------------------------- /docs/car.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/car.gif -------------------------------------------------------------------------------- /docs/church.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/church.gif -------------------------------------------------------------------------------- /docs/ffhq.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/ffhq.gif -------------------------------------------------------------------------------- /docs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/teaser.jpg -------------------------------------------------------------------------------- /docs/thumb.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/docs/thumb.gif -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: uclt 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0 10 | - _libgcc_mutex=0.1=conda_forge 11 | - _openmp_mutex=4.5=1_llvm 12 | - absl-py=0.13.0=pyhd8ed1ab_0 13 | - aiohttp=3.7.4.post0=py38h497a2fe_0 14 | - albumentations=1.0.3=pyhd8ed1ab_0 15 | - alsa-lib=1.2.3=h516909a_0 16 | - anaconda-client=1.8.0=py38h06a4308_0 17 | - 
anaconda-navigator=2.0.4=py38_0 18 | - anyio=2.2.0=py38h06a4308_1 19 | - appdirs=1.4.4=pyh9f0ad1d_0 20 | - argon2-cffi=20.1.0=py38h27cfd23_1 21 | - async-timeout=3.0.1=py_1000 22 | - async_generator=1.10=pyhd3eb1b0_0 23 | - attrs=21.2.0=pyhd3eb1b0_0 24 | - babel=2.9.1=pyhd3eb1b0_0 25 | - backcall=0.2.0=pyhd3eb1b0_0 26 | - backports=1.0=pyhd3eb1b0_2 27 | - backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0 28 | - backports.tempfile=1.0=pyhd3eb1b0_1 29 | - backports.weakref=1.0.post1=py_1 30 | - beautifulsoup4=4.9.3=pyha847dfd_0 31 | - blas=1.0=mkl 32 | - bleach=4.0.0=pyhd3eb1b0_0 33 | - blinker=1.4=py_1 34 | - brotli=1.0.9=h7f98852_5 35 | - brotli-bin=1.0.9=h7f98852_5 36 | - brotlipy=0.7.0=py38h27cfd23_1003 37 | - bzip2=1.0.8=h7b6447c_0 38 | - c-ares=1.17.1=h27cfd23_0 39 | - ca-certificates=2021.10.8=ha878542_0 40 | - cachetools=4.2.2=pyhd8ed1ab_0 41 | - cairo=1.16.0=hf32fb01_1 42 | - certifi=2021.10.8=py38h578d9bd_1 43 | - cffi=1.14.6=py38h400218f_0 44 | - chardet=4.0.0=py38h06a4308_1003 45 | - click=8.0.1=pyhd3eb1b0_0 46 | - cloudpickle=1.6.0=py_0 47 | - clyent=1.2.2=py38_1 48 | - conda=4.11.0=py38h578d9bd_0 49 | - conda-build=3.21.4=py38h06a4308_0 50 | - conda-content-trust=0.1.1=pyhd3eb1b0_0 51 | - conda-env=2.6.0=1 52 | - conda-package-handling=1.7.3=py38h27cfd23_1 53 | - conda-repo-cli=1.0.4=pyhd3eb1b0_0 54 | - conda-token=0.3.0=pyhd3eb1b0_0 55 | - conda-verify=3.4.2=py_1 56 | - cryptography=3.4.7=py38hd23ed53_0 57 | - cudatoolkit=11.1.74=h6bb024c_0 58 | - cycler=0.10.0=py_2 59 | - cytoolz=0.11.0=py38h497a2fe_3 60 | - dask-core=2021.8.1=pyhd8ed1ab_0 61 | - dbus=1.13.18=hb2f20db_0 62 | - decorator=5.0.9=pyhd3eb1b0_0 63 | - defusedxml=0.7.1=pyhd3eb1b0_0 64 | - dill=0.3.4=pyhd8ed1ab_0 65 | - dominate=2.6.0=pyhd8ed1ab_0 66 | - entrypoints=0.3=py38_0 67 | - enum34=1.1.10=py38h32f6830_2 68 | - expat=2.4.1=h2531618_2 69 | - ffmpeg=4.3.2=hca11adc_0 70 | - filelock=3.0.12=pyhd3eb1b0_1 71 | - flask=1.1.2=pyh9f0ad1d_0 72 | - flask-httpauth=4.4.0=pyhd8ed1ab_0 73 | - fontconfig=2.13.1=h6c09931_0 74 | - fonttools=4.25.0=pyhd3eb1b0_0 75 | - freetype=2.10.4=h5ab3b9f_0 76 | - fsspec=2021.7.0=pyhd8ed1ab_0 77 | - ftfy=6.0.3=pyhd8ed1ab_0 78 | - func_timeout=4.3.5=py_0 79 | - future=0.18.2=py38_1 80 | - gdown=4.2.0=pyhd8ed1ab_0 81 | - geos=3.10.0=h9c3ff4c_0 82 | - gettext=0.19.8.1=h0b5b191_1005 83 | - git=2.23.0=pl526hacde149_0 84 | - glib=2.68.4=h9c3ff4c_0 85 | - glib-tools=2.68.4=h9c3ff4c_0 86 | - glob2=0.7=pyhd3eb1b0_0 87 | - gmp=6.2.1=h58526e2_0 88 | - gnutls=3.6.13=h85f3911_1 89 | - google-auth=1.35.0=pyh6c4a22f_0 90 | - google-auth-oauthlib=0.4.5=pyhd8ed1ab_0 91 | - gputil=1.4.0=pyh9f0ad1d_0 92 | - graphite2=1.3.13=h58526e2_1001 93 | - gst-plugins-base=1.18.4=hf529b03_2 94 | - gstreamer=1.18.4=h76c114f_2 95 | - harfbuzz=2.9.0=h83ec7ef_0 96 | - hdf5=1.10.6=nompi_h6a2412b_1114 97 | - icu=68.1=h58526e2_0 98 | - idna=2.10=pyhd3eb1b0_0 99 | - imagecodecs-lite=2019.12.3=py38h5c078b8_3 100 | - imageio=2.9.0=py_0 101 | - imageio-ffmpeg=0.4.5=pyhd8ed1ab_0 102 | - imgaug=0.4.0=py_1 103 | - importlib-metadata=3.10.0=py38h06a4308_0 104 | - importlib_metadata=3.10.0=hd3eb1b0_0 105 | - intel-openmp=2021.3.0=h06a4308_3350 106 | - ipykernel=5.3.4=py38h5ca1d4c_0 107 | - ipympl=0.8.2=pyhd8ed1ab_0 108 | - ipython=7.26.0=py38hb070fc8_0 109 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 110 | - ipywidgets=7.6.3=pyhd3eb1b0_1 111 | - itsdangerous=2.0.1=pyhd3eb1b0_0 112 | - jasper=1.900.1=h07fcdf6_1006 113 | - jedi=0.18.0=py38h06a4308_1 114 | - jinja2=2.11.3=pyhd3eb1b0_0 115 | - joblib=1.1.0=pyhd8ed1ab_0 116 | - jpeg=9d=h36c2ea0_0 
117 | - json5=0.9.6=pyhd3eb1b0_0 118 | - jsonnet=0.17.0=py38hadf7658_0 119 | - jsonschema=3.2.0=py_2 120 | - jupyter_client=6.1.12=pyhd3eb1b0_0 121 | - jupyter_core=4.7.1=py38h06a4308_0 122 | - jupyter_server=1.4.1=py38h06a4308_0 123 | - jupyterlab=3.1.7=pyhd3eb1b0_0 124 | - jupyterlab_pygments=0.1.2=py_0 125 | - jupyterlab_server=2.7.1=pyhd3eb1b0_0 126 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 127 | - kiwisolver=1.3.1=py38h1fd1430_1 128 | - krb5=1.19.2=hcc1bbae_0 129 | - lame=3.100=h7f98852_1001 130 | - lcms2=2.12=h3be6417_0 131 | - ld_impl_linux-64=2.35.1=h7274673_9 132 | - libarchive=3.4.2=h62408e4_0 133 | - libblas=3.9.0=11_linux64_mkl 134 | - libbrotlicommon=1.0.9=h7f98852_5 135 | - libbrotlidec=1.0.9=h7f98852_5 136 | - libbrotlienc=1.0.9=h7f98852_5 137 | - libcblas=3.9.0=11_linux64_mkl 138 | - libcurl=7.78.0=h2574ce0_0 139 | - libedit=3.1.20191231=he28a2e2_2 140 | - libev=4.33=h516909a_1 141 | - libevent=2.1.10=hcdb4288_3 142 | - libffi=3.3=he6710b0_2 143 | - libgcc-ng=11.1.0=hc902ee8_8 144 | - libgfortran-ng=11.1.0=h69a702a_8 145 | - libgfortran5=11.1.0=h6c583b3_8 146 | - libglib=2.68.4=h3e27bee_0 147 | - libiconv=1.16=h516909a_0 148 | - liblapack=3.9.0=11_linux64_mkl 149 | - liblapacke=3.9.0=11_linux64_mkl 150 | - liblief=0.10.1=he6710b0_0 151 | - libllvm11=11.1.0=hf817b99_2 152 | - libnghttp2=1.43.0=h812cca2_0 153 | - libogg=1.3.4=h7f98852_1 154 | - libopencv=4.5.2=py38hcdf9bf1_0 155 | - libopus=1.3.1=h7f98852_1 156 | - libpng=1.6.37=hbc83047_0 157 | - libpq=13.3=hd57d9b9_0 158 | - libprotobuf=3.15.8=h780b84a_0 159 | - libsodium=1.0.18=h7b6447c_0 160 | - libssh2=1.9.0=ha56f1ee_6 161 | - libstdcxx-ng=11.1.0=h56837e0_8 162 | - libtiff=4.2.0=h85742a9_0 163 | - libuuid=1.0.3=h1bed415_2 164 | - libuv=1.40.0=h7b6447c_0 165 | - libvorbis=1.3.7=h9c3ff4c_0 166 | - libwebp-base=1.2.0=h27cfd23_0 167 | - libxcb=1.14=h7b6447c_0 168 | - libxkbcommon=1.0.3=he3ba5ed_0 169 | - libxml2=2.9.12=h72842e0_0 170 | - llvm-openmp=12.0.1=h4bd325d_1 171 | - locket=0.2.0=py_2 172 | - lz4-c=1.9.3=h295c915_1 173 | - markdown=3.3.4=pyhd8ed1ab_0 174 | - markupsafe=2.0.1=py38h27cfd23_0 175 | - matplotlib=3.4.2=py38h578d9bd_0 176 | - matplotlib-base=3.4.2=py38hab158f2_0 177 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 178 | - mistune=0.8.4=py38h7b6447c_1000 179 | - mkl=2021.3.0=h06a4308_520 180 | - mkl-service=2.4.0=py38h7f8727e_0 181 | - mkl_fft=1.3.0=py38h42c9631_2 182 | - mkl_random=1.2.2=py38h51133e4_0 183 | - multidict=5.1.0=py38h497a2fe_1 184 | - munkres=1.1.4=pyh9f0ad1d_0 185 | - mysql-common=8.0.25=ha770c72_0 186 | - mysql-libs=8.0.25=h935591d_0 187 | - navigator-updater=0.2.1=py38_0 188 | - nbclassic=0.2.6=pyhd3eb1b0_0 189 | - nbclient=0.5.3=pyhd3eb1b0_0 190 | - nbconvert=6.1.0=py38h06a4308_0 191 | - nbformat=5.1.3=pyhd3eb1b0_0 192 | - ncurses=6.2=he6710b0_1 193 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 194 | - nettle=3.6=he412f7d_0 195 | - networkx=2.3=py_0 196 | - ninja=1.10.2=hff7bd54_1 197 | - notebook=6.4.3=py38h06a4308_0 198 | - nspr=4.30=h9c3ff4c_0 199 | - nss=3.69=hb5efdd6_0 200 | - numpy=1.20.3=py38hf144106_0 201 | - numpy-base=1.20.3=py38h74d4b33_0 202 | - oauthlib=3.1.1=pyhd8ed1ab_0 203 | - olefile=0.46=py_0 204 | - opencv=4.5.2=py38h578d9bd_0 205 | - openh264=2.1.1=h780b84a_0 206 | - openjpeg=2.3.0=h05c96fa_1 207 | - openssl=1.1.1l=h7f98852_0 208 | - packaging=21.0=pyhd3eb1b0_0 209 | - pandas=1.3.2=py38h43a58ef_0 210 | - pandocfilters=1.4.3=py38h06a4308_1 211 | - parso=0.8.2=pyhd3eb1b0_0 212 | - partd=1.2.0=pyhd8ed1ab_0 213 | - patchelf=0.12=h2531618_1 214 | - pathlib=1.0.1=py38h578d9bd_4 215 | - 
patsy=0.5.1=py_0 216 | - pcre=8.45=h295c915_0 217 | - perl=5.26.2=h14c3975_0 218 | - pexpect=4.8.0=pyhd3eb1b0_3 219 | - pickleshare=0.7.5=pyhd3eb1b0_1003 220 | - pillow=8.3.1=py38h2c7a002_0 221 | - pip=21.2.2=py38h06a4308_0 222 | - pixman=0.40.0=h36c2ea0_0 223 | - pkginfo=1.7.1=py38h06a4308_0 224 | - pooch=1.5.1=pyhd8ed1ab_0 225 | - portalocker=1.7.0=py38h578d9bd_1 226 | - prometheus_client=0.11.0=pyhd3eb1b0_0 227 | - prompt-toolkit=3.0.17=pyh06a4308_0 228 | - protobuf=3.15.8=py38h709712a_0 229 | - psutil=5.8.0=py38h27cfd23_1 230 | - ptyprocess=0.7.0=pyhd3eb1b0_2 231 | - py-lief=0.10.1=py38h403a769_0 232 | - py-opencv=4.5.2=py38hd0cf306_0 233 | - pyasn1=0.4.8=py_0 234 | - pyasn1-modules=0.2.7=py_0 235 | - pycosat=0.6.3=py38h7b6447c_1 236 | - pycparser=2.20=py_2 237 | - pygments=2.10.0=pyhd3eb1b0_0 238 | - pyjwt=2.1.0=pyhd8ed1ab_0 239 | - pyopenssl=20.0.1=pyhd3eb1b0_1 240 | - pyparsing=2.4.7=pyhd3eb1b0_0 241 | - pypng=0.0.20=py_0 242 | - pyqt=5.12.3=py38h578d9bd_7 243 | - pyqt-impl=5.12.3=py38h7400c14_7 244 | - pyqt5-sip=4.19.18=py38h709712a_7 245 | - pyqtchart=5.12=py38h7400c14_7 246 | - pyqtwebengine=5.12.1=py38h7400c14_7 247 | - pyrsistent=0.17.3=py38h7b6447c_0 248 | - pysocks=1.7.1=py38h06a4308_0 249 | - python=3.8.10=h12debd9_8 250 | - python-dateutil=2.8.2=pyhd3eb1b0_0 251 | - python-libarchive-c=2.9=pyhd3eb1b0_1 252 | - python-lmdb=0.99=py38h709712a_0 253 | - python_abi=3.8=2_cp38 254 | - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0 255 | - pytz=2021.1=pyhd3eb1b0_0 256 | - pyu2f=0.1.5=pyhd8ed1ab_0 257 | - pywavelets=1.1.1=py38h5c078b8_3 258 | - pyyaml=5.4.1=py38h27cfd23_1 259 | - pyzmq=22.2.1=py38h295c915_1 260 | - qt=5.12.9=hda022c4_4 261 | - qtpy=1.9.0=py_0 262 | - readline=8.1=h27cfd23_0 263 | - regex=2021.8.28=py38h497a2fe_0 264 | - requests=2.25.1=pyhd3eb1b0_0 265 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 266 | - ripgrep=12.1.1=0 267 | - rsa=4.7.2=pyh44b312d_0 268 | - ruamel_yaml=0.15.100=py38h27cfd23_0 269 | - scikit-image=0.18.3=py38h43a58ef_0 270 | - scikit-learn=1.0=py38hacb3eff_1 271 | - scipy=1.7.1=py38h56a6a73_0 272 | - seaborn=0.11.2=hd8ed1ab_0 273 | - seaborn-base=0.11.2=pyhd8ed1ab_0 274 | - send2trash=1.5.0=pyhd3eb1b0_1 275 | - setuptools=52.0.0=py38h06a4308_0 276 | - shapely=1.8.0=py38hf7953bd_1 277 | - sip=4.19.13=py38he6710b0_0 278 | - sniffio=1.2.0=py38h06a4308_1 279 | - soupsieve=2.2.1=pyhd3eb1b0_0 280 | - sqlite=3.36.0=hc218d9a_0 281 | - statsmodels=0.12.2=py38h5c078b8_0 282 | - tensorboard=2.6.0=pyhd8ed1ab_1 283 | - tensorboard-data-server=0.6.0=py38h2b97feb_0 284 | - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 285 | - tensorboardx=2.4=pyhd8ed1ab_0 286 | - terminado=0.9.4=py38h06a4308_0 287 | - testpath=0.5.0=pyhd3eb1b0_0 288 | - threadpoolctl=3.0.0=pyh8a188c0_0 289 | - tifffile=2019.7.26.2=py38_0 290 | - tk=8.6.10=hbc83047_0 291 | - toolz=0.11.1=py_0 292 | - torchfile=0.1.0=py_0 293 | - tornado=6.1=py38h27cfd23_0 294 | - tqdm=4.62.1=pyhd3eb1b0_1 295 | - traitlets=5.0.5=pyhd3eb1b0_0 296 | - typing_extensions=3.10.0.0=pyh06a4308_0 297 | - urllib3=1.26.6=pyhd3eb1b0_1 298 | - wcwidth=0.2.5=py_0 299 | - webencodings=0.5.1=py38_1 300 | - werkzeug=1.0.1=pyhd3eb1b0_0 301 | - wheel=0.37.0=pyhd3eb1b0_0 302 | - widgetsnbextension=3.5.1=py38_0 303 | - x264=1!161.3030=h7f98852_1 304 | - xmltodict=0.12.0=py_0 305 | - xz=5.2.5=h7b6447c_0 306 | - yacs=0.1.6=py_0 307 | - yaml=0.2.5=h7b6447c_0 308 | - yarl=1.6.3=py38h497a2fe_2 309 | - zeromq=4.3.4=h2531618_0 310 | - zipp=3.5.0=pyhd3eb1b0_0 311 | - zlib=1.2.11=h7b6447c_3 312 | - zstd=1.4.9=haebb681_0 313 | - pip: 314 | - addict==2.4.0 
315 | - altair==4.2.0 316 | - astor==0.8.1 317 | - astunparse==1.6.3 318 | - backports-zoneinfo==0.2.1 319 | - base58==2.1.1 320 | - basicsr==1.3.4.1 321 | - boto3==1.18.33 322 | - botocore==1.21.33 323 | - clang==5.0 324 | - clean-fid==0.1.22 325 | - clip==1.0 326 | - colorama==0.4.4 327 | - commonmark==0.9.1 328 | - cython==0.29.30 329 | - einops==0.3.2 330 | - enum-compat==0.0.3 331 | - facexlib==0.2.0.3 332 | - filterpy==1.4.5 333 | - flatbuffers==1.12 334 | - gast==0.4.0 335 | - google-pasta==0.2.0 336 | - grpcio==1.39.0 337 | - h5py==3.1.0 338 | - ipdb==0.13.9 339 | - jacinle==1.0.0 340 | - jmespath==0.10.0 341 | - jsonpickle==2.2.0 342 | - keras==2.7.0 343 | - keras-preprocessing==1.1.2 344 | - libclang==12.0.0 345 | - llvmlite==0.37.0 346 | - lpips==0.1.4 347 | - numba==0.54.0 348 | - opencv-python==4.5.3.56 349 | - opt-einsum==3.3.0 350 | - pkgconfig==1.5.5 351 | - pyarrow==8.0.0 352 | - pydantic==1.8.2 353 | - pydeck==0.7.1 354 | - pyhocon==0.3.58 355 | - pytz-deprecation-shim==0.1.0.post0 356 | - pyvis==0.2.1 357 | - realesrgan==0.2.2.3 358 | - rich==10.9.0 359 | - s3transfer==0.5.0 360 | - six==1.15.0 361 | - sklearn==0.0 362 | - streamlit==0.64.0 363 | - tabulate==0.8.9 364 | - tb-nightly==2.7.0a20210827 365 | - tensorflow-estimator==2.7.0 366 | - tensorflow-gpu==2.7.0 367 | - tensorflow-io-gcs-filesystem==0.21.0 368 | - tensorfn==0.1.19 369 | - termcolor==1.1.0 370 | - toml==0.10.2 371 | - torchsample==0.1.3 372 | - torchvision==0.10.0+cu111 373 | - typing-extensions==3.7.4.3 374 | - tzdata==2022.1 375 | - tzlocal==4.2 376 | - validators==0.19.0 377 | - vit-pytorch==0.24.3 378 | - watchdog==2.1.8 379 | - wrapt==1.12.1 380 | - yapf==0.31.0 -------------------------------------------------------------------------------- /expansion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/__init__.py -------------------------------------------------------------------------------- /expansion/dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/dataloader/__init__.py -------------------------------------------------------------------------------- /expansion/dataloader/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/dataloader/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /expansion/dataloader/__pycache__/seqlist.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/dataloader/__pycache__/seqlist.cpython-38.pyc -------------------------------------------------------------------------------- /expansion/dataloader/chairslist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import glob 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 
12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | l0_train = [] 20 | l1_train = [] 21 | flow_train = [] 22 | for flow_map in sorted(glob.glob(os.path.join(filepath,'*_flow.flo'))): 23 | root_filename = flow_map[:-9] 24 | img1 = root_filename+'_img1.ppm' 25 | img2 = root_filename+'_img2.ppm' 26 | if not (os.path.isfile(os.path.join(filepath,img1)) and os.path.isfile(os.path.join(filepath,img2))): 27 | continue 28 | 29 | l0_train.append(img1) 30 | l1_train.append(img2) 31 | flow_train.append(flow_map) 32 | 33 | return l0_train, l1_train, flow_train 34 | -------------------------------------------------------------------------------- /expansion/dataloader/chairssdlist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import glob 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | l0_train = [] 20 | l1_train = [] 21 | flow_train = [] 22 | for flow_map in sorted(glob.glob('%s/flow/*.pfm'%filepath)): 23 | img1 = flow_map.replace('flow','t0').replace('.pfm','.png') 24 | img2 = flow_map.replace('flow','t1').replace('.pfm','.png') 25 | 26 | l0_train.append(img1) 27 | l1_train.append(img2) 28 | flow_train.append(flow_map) 29 | 30 | return l0_train, l1_train, flow_train 31 | -------------------------------------------------------------------------------- /expansion/dataloader/depthloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numbers 3 | import torch 4 | import torch.utils.data as data 5 | import torch 6 | import torchvision.transforms as transforms 7 | import random 8 | from PIL import Image, ImageOps 9 | import numpy as np 10 | import torchvision 11 | from . 
import depth_transforms as flow_transforms 12 | import pdb 13 | import cv2 14 | from utils.flowlib import read_flow 15 | from utils.util_flow import readPFM, load_calib_cam_to_cam 16 | 17 | def default_loader(path): 18 | return Image.open(path).convert('RGB') 19 | 20 | def flow_loader(path): 21 | if '.pfm' in path: 22 | data = readPFM(path)[0] 23 | data[:,:,2] = 1 24 | return data 25 | else: 26 | return read_flow(path) 27 | 28 | def load_exts(cam_file): 29 | with open(cam_file, 'r') as f: 30 | lines = f.readlines() 31 | 32 | l_exts = [] 33 | r_exts = [] 34 | for l in lines: 35 | if 'L ' in l: 36 | l_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4)) 37 | if 'R ' in l: 38 | r_exts.append(np.asarray([float(i) for i in l[2:].strip().split(' ')]).reshape(4,4)) 39 | return l_exts,r_exts 40 | 41 | def disparity_loader(path): 42 | if '.png' in path: 43 | data = Image.open(path) 44 | data = np.ascontiguousarray(data,dtype=np.float32)/256 45 | return data 46 | else: 47 | return readPFM(path)[0] 48 | 49 | # triangulation 50 | def triangulation(disp, xcoord, ycoord, bl=1, fl = 450, cx = 479.5, cy = 269.5): 51 | depth = bl*fl / disp # 450px->15mm focal length 52 | X = (xcoord - cx) * depth / fl 53 | Y = (ycoord - cy) * depth / fl 54 | Z = depth 55 | P = np.concatenate((X[np.newaxis],Y[np.newaxis],Z[np.newaxis]),0).reshape(3,-1) 56 | P = np.concatenate((P,np.ones((1,P.shape[-1]))),0) 57 | return P 58 | 59 | class myImageFloder(data.Dataset): 60 | def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1.,sc=False,disp0=None,disp1=None,calib=None ): 61 | self.iml0 = iml0 62 | self.iml1 = iml1 63 | self.flowl0 = flowl0 64 | self.loader = loader 65 | self.dploader = dploader 66 | self.scale=scale 67 | self.shape=shape 68 | self.order=order 69 | self.noise = noise 70 | self.pca_augmentor = pca_augmentor 71 | self.prob = prob 72 | self.sc = sc 73 | self.disp0 = disp0 74 | self.disp1 = disp1 75 | self.calib = calib 76 | 77 | def __getitem__(self, index): 78 | iml0 = self.iml0[index] 79 | iml1 = self.iml1[index] 80 | flowl0= self.flowl0[index] 81 | th, tw = self.shape 82 | 83 | iml0 = self.loader(iml0) 84 | iml1 = self.loader(iml1) 85 | 86 | # get disparity 87 | if self.sc: 88 | flowl0 = self.dploader(flowl0) 89 | flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32) 90 | flowl0[np.isnan(flowl0)] = 1e6 # set to max 91 | if 'camera_data.txt' in self.calib[index]: 92 | bl=1 93 | if '15mm_' in self.calib[index]: 94 | fl=450 # 450 95 | else: 96 | fl=1050 97 | cx = 479.5 98 | cy = 269.5 99 | # negative disp 100 | d1 = np.abs(disparity_loader(self.disp0[index])) 101 | d2 = np.abs(disparity_loader(self.disp1[index]) + d1) 102 | elif 'Sintel' in self.calib[index]: 103 | fl = 1000 104 | bl = 1 105 | cx = 511.5 106 | cy = 217.5 107 | d1 = np.zeros(flowl0.shape[:2]) 108 | d2 = np.zeros(flowl0.shape[:2]) 109 | else: 110 | ints = load_calib_cam_to_cam(self.calib[index]) 111 | fl = ints['K_cam2'][0,0] 112 | cx = ints['K_cam2'][0,2] 113 | cy = ints['K_cam2'][1,2] 114 | bl = ints['b20']-ints['b30'] 115 | d1 = disparity_loader(self.disp0[index]) 116 | d2 = disparity_loader(self.disp1[index]) 117 | #flowl0[:,:,2] = (flowl0[:,:,2]==1).astype(float) 118 | flowl0[:,:,2] = np.logical_and(np.logical_and(flowl0[:,:,2]==1, d1!=0), d2!=0).astype(float) 119 | 120 | shape = d1.shape 121 | mesh = np.meshgrid(range(shape[1]),range(shape[0])) 122 | xcoord = mesh[0].astype(float) 123 | ycoord = 
mesh[1].astype(float) 124 | 125 | # triangulation in two frames 126 | P0 = triangulation(d1, xcoord, ycoord, bl=bl, fl = fl, cx = cx, cy = cy) 127 | P1 = triangulation(d2, xcoord + flowl0[:,:,0], ycoord + flowl0[:,:,1], bl=bl, fl = fl, cx = cx, cy = cy) 128 | dis0 = P0[2] 129 | dis1 = P1[2] 130 | 131 | change_size = dis0.reshape(shape).astype(np.float32) 132 | flow3d = (P1-P0)[:3].reshape((3,)+shape).transpose((1,2,0)) 133 | 134 | gt_normal = np.concatenate((d1[:,:,np.newaxis],d2[:,:,np.newaxis],d2[:,:,np.newaxis]),-1) 135 | change_size = np.concatenate((change_size[:,:,np.newaxis],gt_normal,flow3d),2) 136 | else: 137 | shape = iml0.size 138 | shape=[shape[1],shape[0]] 139 | flowl0 = np.zeros((shape[0],shape[1],3)) 140 | change_size = np.zeros((shape[0],shape[1],7)) 141 | depth = disparity_loader(self.iml1[index].replace('camera','groundtruth')) 142 | change_size[:,:,0] = depth 143 | 144 | seqid = self.iml0[index].split('/')[-5].rsplit('_',3)[0] 145 | ints = load_calib_cam_to_cam('/data/gengshay/KITTI/%s/calib_cam_to_cam.txt'%seqid) 146 | fl = ints['K_cam2'][0,0] 147 | cx = ints['K_cam2'][0,2] 148 | cy = ints['K_cam2'][1,2] 149 | bl = ints['b20']-ints['b30'] 150 | 151 | 152 | iml1 = np.asarray(iml1)/255. 153 | iml0 = np.asarray(iml0)/255. 154 | iml0 = iml0[:,:,::-1].copy() 155 | iml1 = iml1[:,:,::-1].copy() 156 | 157 | ## following data augmentation procedure in PWCNet 158 | ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu 159 | import __main__ # a workaround for "discount_coeff" 160 | try: 161 | with open('/scratch/gengshay/iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f: 162 | iter_counts = int(f.readline()) 163 | except: 164 | iter_counts = 0 165 | schedule = [0.5, 1., 50000.] 
# initial coeff, final_coeff, half life 166 | schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \ 167 | (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1) 168 | 169 | if self.pca_augmentor: 170 | pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff) 171 | else: 172 | pca_augmentor = flow_transforms.Scale(1., order=0) 173 | 174 | if np.random.binomial(1,self.prob): 175 | co_transform1 = flow_transforms.Compose([ 176 | flow_transforms.SpatialAug([th,tw], 177 | scale=[0.2,0.,0.1], 178 | rot=[0.4,0.], 179 | trans=[0.4,0.], 180 | squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order), 181 | ]) 182 | else: 183 | co_transform1 = flow_transforms.Compose([ 184 | flow_transforms.RandomCrop([th,tw]), 185 | ]) 186 | 187 | co_transform2 = flow_transforms.Compose([ 188 | flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff), 189 | #flow_transforms.PCAAug(schedule_coeff=schedule_coeff), 190 | flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise), 191 | ]) 192 | 193 | flowl0 = np.concatenate([flowl0,change_size],-1) 194 | augmented,flowl0,intr = co_transform1([iml0, iml1], flowl0, [fl,cx,cy,bl]) 195 | imol0 = augmented[0] 196 | imol1 = augmented[1] 197 | augmented,flowl0,intr = co_transform2(augmented, flowl0, intr) 198 | 199 | iml0 = augmented[0] 200 | iml1 = augmented[1] 201 | flowl0 = flowl0.astype(np.float32) 202 | change_size = flowl0[:,:,3:] 203 | flowl0 = flowl0[:,:,:3] 204 | 205 | # randomly cover a region 206 | sx=0;sy=0;cx=0;cy=0 207 | if np.random.binomial(1,0.5): 208 | sx = int(np.random.uniform(25,100)) 209 | sy = int(np.random.uniform(25,100)) 210 | #sx = int(np.random.uniform(50,150)) 211 | #sy = int(np.random.uniform(50,150)) 212 | cx = int(np.random.uniform(sx,iml1.shape[0]-sx)) 213 | cy = int(np.random.uniform(sy,iml1.shape[1]-sy)) 214 | iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis] 215 | 216 | iml0 = torch.Tensor(np.transpose(iml0,(2,0,1))) 217 | iml1 = torch.Tensor(np.transpose(iml1,(2,0,1))) 218 | 219 | return iml0, iml1, flowl0, change_size, intr, imol0, imol1, np.asarray([cx-sx,cx+sx,cy-sy,cy+sy]) 220 | 221 | def __len__(self): 222 | return len(self.iml0) 223 | -------------------------------------------------------------------------------- /expansion/dataloader/hd1klist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import pdb 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | 20 | left_fold = 'image_2/' 21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('HD1K2018') > -1] 22 | train = sorted(train) 23 | 24 | l0_train = [filepath+left_fold+img for img in train] 25 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] 26 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%04d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] 27 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train] 28 | 29 | return l0_train, l1_train, flow_train 30 | -------------------------------------------------------------------------------- 
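The `schedule_coeff` ramp that depthloader.py above (and robloader.py below) computes from `iter_counts` starts at the initial coefficient 0.5 and approaches the final coefficient 1.0, reaching the midpoint 0.75 after one half-life of 50,000 iterations; the constant 1.0986 is roughly ln 3, which is what places the midpoint exactly at the half-life. A minimal standalone sketch of that curve, using the same constants as those files:
```
import numpy as np

def schedule_coeff(iter_counts, initial=0.5, final=1.0, half_life=50000.0):
    # Same ramp as in depthloader.py / robloader.py: 0 -> initial, half_life -> midpoint, inf -> final.
    return initial + (final - initial) * (2 / (1 + np.exp(-1.0986 * iter_counts / half_life)) - 1)

for it in [0, 50000, 200000]:
    print(it, round(schedule_coeff(it), 3))  # 0.5, 0.75, ~0.988
```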
/expansion/dataloader/kitti12list.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'colored_0/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 23 | 24 | l0_train = [filepath+left_fold+img for img in train] 25 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 26 | flow_train = [filepath+flow_noc+img for img in train] 27 | 28 | 29 | return l0_train, l1_train, flow_train 30 | -------------------------------------------------------------------------------- /expansion/dataloader/kitti15list.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'image_2/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 23 | 24 | l0_train = [filepath+left_fold+img for img in train] 25 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 26 | flow_train = [filepath+flow_noc+img for img in train] 27 | 28 | 29 | return sorted(l0_train), sorted(l1_train), sorted(flow_train) 30 | -------------------------------------------------------------------------------- /expansion/dataloader/kitti15list_train.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'image_2/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 23 | 24 | train = [i for i in train if int(i.split('_')[0])%5!=0] 25 | 26 | l0_train = [filepath+left_fold+img for img in train] 27 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 28 | flow_train = [filepath+flow_noc+img for img in train] 29 | 30 | 31 | return sorted(l0_train), sorted(l1_train), sorted(flow_train) 32 | -------------------------------------------------------------------------------- /expansion/dataloader/kitti15list_train_lidar.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 
14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'image_2/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 23 | 24 | # train = [i for i in train if int(i.split('_')[0])%5!=0] 25 | with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f: 26 | flags = [True if len(i)>1 else False for i in f.readlines()] 27 | train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it] ][:100] 28 | 29 | l0_train = [filepath+left_fold+img for img in train] 30 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 31 | flow_train = [filepath+flow_noc+img for img in train] 32 | 33 | 34 | return sorted(l0_train), sorted(l1_train), sorted(flow_train) 35 | -------------------------------------------------------------------------------- /expansion/dataloader/kitti15list_val.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'image_2/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 23 | 24 | train = [i for i in train if int(i.split('_')[0])%5==0] 25 | 26 | l0_train = [filepath+left_fold+img for img in train] 27 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 28 | flow_train = [filepath+flow_noc+img for img in train] 29 | 30 | 31 | return sorted(l0_train), sorted(l1_train), sorted(flow_train) 32 | -------------------------------------------------------------------------------- /expansion/dataloader/kitti15list_val_lidar.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'image_2/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if img.find('_10') > -1] 23 | 24 | # train = [i for i in train if int(i.split('_')[0])%5!=0] 25 | with open('/data/gengshay/kitti_scene/devkit/mapping/train_mapping.txt','r') as f: 26 | flags = [True if len(i)>1 else False for i in f.readlines()] 27 | train = [fn for (it,fn) in enumerate(sorted(train)) if flags[it] ][100:] 28 | 29 | l0_train = [filepath+left_fold+img for img in train] 30 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 31 | flow_train = [filepath+flow_noc+img for img in train] 32 | 33 | 34 | return sorted(l0_train), sorted(l1_train), sorted(flow_train) 35 | -------------------------------------------------------------------------------- /expansion/dataloader/kitti15list_val_mr.py: -------------------------------------------------------------------------------- 1 | 
import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | 19 | left_fold = 'image_2/' 20 | flow_noc = 'flow_occ/' 21 | 22 | train = [img for img in os.listdir(filepath+left_fold) if 'Kitti' in img and img.find('_10') > -1] 23 | 24 | # train = [i for i in train if int(i.split('_')[1])%5==0] 25 | import pdb; pdb.set_trace() 26 | train = sorted([i for i in train if int(i.split('_')[1])%5==0])[0:1] 27 | 28 | l0_train = [filepath+left_fold+img for img in train] 29 | l1_train = [filepath+left_fold+img.replace('_10','_11') for img in train] 30 | flow_train = [filepath+flow_noc+img for img in train] 31 | 32 | l0_train += [filepath+left_fold+img.replace('_10','_09') for img in train] 33 | l1_train += [filepath+left_fold+img for img in train] 34 | flow_train += flow_train 35 | 36 | tmp = l0_train 37 | l0_train = l0_train+ [i.replace('rob_flow', 'kitti_scene').replace('Kitti2015_','') for i in l1_train] 38 | l1_train = l1_train+tmp 39 | flow_train += flow_train 40 | 41 | return l0_train, l1_train, flow_train 42 | -------------------------------------------------------------------------------- /expansion/dataloader/robloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numbers 3 | import torch 4 | import torch.utils.data as data 5 | import torch 6 | import torchvision.transforms as transforms 7 | import random 8 | from PIL import Image, ImageOps 9 | import numpy as np 10 | import torchvision 11 | from . import flow_transforms 12 | import pdb 13 | import cv2 14 | from utils.flowlib import read_flow 15 | from utils.util_flow import readPFM 16 | 17 | 18 | def default_loader(path): 19 | return Image.open(path).convert('RGB') 20 | 21 | def flow_loader(path): 22 | if '.pfm' in path: 23 | data = readPFM(path)[0] 24 | data[:,:,2] = 1 25 | return data 26 | else: 27 | return read_flow(path) 28 | 29 | 30 | def disparity_loader(path): 31 | if '.png' in path: 32 | data = Image.open(path) 33 | data = np.ascontiguousarray(data,dtype=np.float32)/256 34 | return data 35 | else: 36 | return readPFM(path)[0] 37 | 38 | class myImageFloder(data.Dataset): 39 | def __init__(self, iml0, iml1, flowl0, loader=default_loader, dploader= flow_loader, scale=1.,shape=[320,448], order=1, noise=0.06, pca_augmentor=True, prob = 1., cover=False, black=False, scale_aug=[0.4,0.2]): 40 | self.iml0 = iml0 41 | self.iml1 = iml1 42 | self.flowl0 = flowl0 43 | self.loader = loader 44 | self.dploader = dploader 45 | self.scale=scale 46 | self.shape=shape 47 | self.order=order 48 | self.noise = noise 49 | self.pca_augmentor = pca_augmentor 50 | self.prob = prob 51 | self.cover = cover 52 | self.black = black 53 | self.scale_aug = scale_aug 54 | 55 | def __getitem__(self, index): 56 | iml0 = self.iml0[index] 57 | iml1 = self.iml1[index] 58 | flowl0= self.flowl0[index] 59 | th, tw = self.shape 60 | 61 | iml0 = self.loader(iml0) 62 | iml1 = self.loader(iml1) 63 | iml1 = np.asarray(iml1)/255. 64 | iml0 = np.asarray(iml0)/255. 
65 | iml0 = iml0[:,:,::-1].copy() 66 | iml1 = iml1[:,:,::-1].copy() 67 | flowl0 = self.dploader(flowl0) 68 | #flowl0[:,:,-1][flowl0[:,:,0]==np.inf]=0 # for gtav window pfm files 69 | #flowl0[:,:,0][~flowl0[:,:,2].astype(bool)]=0 70 | #flowl0[:,:,1][~flowl0[:,:,2].astype(bool)]=0 # avoid nan in grad 71 | flowl0 = np.ascontiguousarray(flowl0,dtype=np.float32) 72 | flowl0[np.isnan(flowl0)] = 1e6 # set to max 73 | 74 | ## following data augmentation procedure in PWCNet 75 | ## https://github.com/lmb-freiburg/flownet2/blob/master/src/caffe/layers/data_augmentation_layer.cu 76 | import __main__ # a workaround for "discount_coeff" 77 | try: 78 | with open('iter_counts-%d.txt'%int(__main__.args.logname.split('-')[-1]), 'r') as f: 79 | iter_counts = int(f.readline()) 80 | except: 81 | iter_counts = 0 82 | schedule = [0.5, 1., 50000.] # initial coeff, final_coeff, half life 83 | schedule_coeff = schedule[0] + (schedule[1] - schedule[0]) * \ 84 | (2/(1+np.exp(-1.0986*iter_counts/schedule[2])) - 1) 85 | 86 | if self.pca_augmentor: 87 | pca_augmentor = flow_transforms.pseudoPCAAug( schedule_coeff=schedule_coeff) 88 | else: 89 | pca_augmentor = flow_transforms.Scale(1., order=0) 90 | 91 | if np.random.binomial(1,self.prob): 92 | co_transform = flow_transforms.Compose([ 93 | flow_transforms.Scale(self.scale, order=self.order), 94 | #flow_transforms.SpatialAug([th,tw], trans=[0.2,0.03], order=self.order, black=self.black), 95 | flow_transforms.SpatialAug([th,tw],scale=[self.scale_aug[0],0.03,self.scale_aug[1]], 96 | rot=[0.4,0.03], 97 | trans=[0.4,0.03], 98 | squeeze=[0.3,0.], schedule_coeff=schedule_coeff, order=self.order, black=self.black), 99 | #flow_transforms.pseudoPCAAug(schedule_coeff=schedule_coeff), 100 | flow_transforms.PCAAug(schedule_coeff=schedule_coeff), 101 | flow_transforms.ChromaticAug( schedule_coeff=schedule_coeff, noise=self.noise), 102 | ]) 103 | else: 104 | co_transform = flow_transforms.Compose([ 105 | flow_transforms.Scale(self.scale, order=self.order), 106 | flow_transforms.SpatialAug([th,tw], trans=[0.4,0.03], order=self.order, black=self.black) 107 | ]) 108 | 109 | augmented,flowl0 = co_transform([iml0, iml1], flowl0) 110 | iml0 = augmented[0] 111 | iml1 = augmented[1] 112 | 113 | if self.cover: 114 | ## randomly cover a region 115 | # following sec. 
3.2 of http://openaccess.thecvf.com/content_CVPR_2019/html/Yang_Hierarchical_Deep_Stereo_Matching_on_High-Resolution_Images_CVPR_2019_paper.html 116 | if np.random.binomial(1,0.5): 117 | #sx = int(np.random.uniform(25,100)) 118 | #sy = int(np.random.uniform(25,100)) 119 | sx = int(np.random.uniform(50,125)) 120 | sy = int(np.random.uniform(50,125)) 121 | #sx = int(np.random.uniform(50,150)) 122 | #sy = int(np.random.uniform(50,150)) 123 | cx = int(np.random.uniform(sx,iml1.shape[0]-sx)) 124 | cy = int(np.random.uniform(sy,iml1.shape[1]-sy)) 125 | iml1[cx-sx:cx+sx,cy-sy:cy+sy] = np.mean(np.mean(iml1,0),0)[np.newaxis,np.newaxis] 126 | 127 | iml0 = torch.Tensor(np.transpose(iml0,(2,0,1))) 128 | iml1 = torch.Tensor(np.transpose(iml1,(2,0,1))) 129 | 130 | return iml0, iml1, flowl0 131 | 132 | def __len__(self): 133 | return len(self.iml0) 134 | -------------------------------------------------------------------------------- /expansion/dataloader/sceneflowlist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import glob 4 | 5 | def dataloader(filepath, level=6): 6 | iml0 = [] 7 | iml1 = [] 8 | flowl0 = [] 9 | disp0 = [] 10 | dispc = [] 11 | calib = [] 12 | level_stars = '/*'*level 13 | candidate_pool = glob.glob('%s/optical_flow%s'%(filepath,level_stars)) 14 | for flow_path in sorted(candidate_pool): 15 | if 'TEST' in flow_path: continue 16 | if 'flower_storm_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path: 17 | continue 18 | if 'flower_storm_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path: 19 | continue 20 | if 'flower_storm_augmented0_x2/into_future/right/OpticalFlowIntoFuture_0023_R.pfm' in flow_path: 21 | continue 22 | if 'flower_storm_augmented0_x2/into_future/left/OpticalFlowIntoFuture_0023_L.pfm' in flow_path: 23 | continue 24 | if 'FlyingThings' in flow_path and '_0014_' in flow_path: 25 | continue 26 | if 'FlyingThings' in flow_path and '_0015_' in flow_path: 27 | continue 28 | idd = flow_path.split('/')[-1].split('_')[-2] 29 | if 'into_future' in flow_path: 30 | idd_p1 = '%04d'%(int(idd)+1) 31 | else: 32 | idd_p1 = '%04d'%(int(idd)-1) 33 | if os.path.exists(flow_path.replace(idd,idd_p1)): 34 | d0_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','disparity') 35 | d0_path = '%s/%s.pfm'%(d0_path.rsplit('/',1)[0],idd) 36 | dc_path = flow_path.replace('optical_flow','disparity_change') 37 | dc_path = '%s/%s.pfm'%(dc_path.rsplit('/',1)[0],idd) 38 | im_path = flow_path.replace('/into_future/','/').replace('/into_past/','/').replace('optical_flow','frames_cleanpass') 39 | im0_path = '%s/%s.png'%(im_path.rsplit('/',1)[0],idd) 40 | im1_path = '%s/%s.png'%(im_path.rsplit('/',1)[0],idd_p1) 41 | #with open('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0]),'r') as f: 42 | # if 'FlyingThings' in flow_path and len(f.readlines())!=40: 43 | # print(flow_path) 44 | # continue 45 | iml0.append(im0_path) 46 | iml1.append(im1_path) 47 | flowl0.append(flow_path) 48 | disp0.append(d0_path) 49 | dispc.append(dc_path) 50 | calib.append('%s/camera_data.txt'%(im0_path.replace('frames_cleanpass','camera_data').rsplit('/',2)[0])) 51 | return iml0, iml1, flowl0, disp0, dispc, calib 52 | -------------------------------------------------------------------------------- /expansion/dataloader/seqlist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 
| from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import glob 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | 20 | train = [img for img in sorted(glob.glob('%s/*'%filepath))] 21 | 22 | l0_train = train[:-1] 23 | l1_train = train[1:] 24 | 25 | 26 | return sorted(l0_train), sorted(l1_train), sorted(l0_train) 27 | -------------------------------------------------------------------------------- /expansion/dataloader/sintellist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import pdb 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | 20 | left_fold = 'image_2/' 21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1] 22 | 23 | l0_train = [filepath+left_fold+img for img in train] 24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] 25 | 26 | #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val 27 | 28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] 29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train] 30 | 31 | 32 | return l0_train, l1_train, flow_train 33 | -------------------------------------------------------------------------------- /expansion/dataloader/sintellist_clean.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import pdb 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | 20 | left_fold = 'image_2/' 21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_clean') > -1] 22 | 23 | l0_train = [filepath+left_fold+img for img in train] 24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] 25 | 26 | #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val 27 | 28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] 29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train] 30 | 31 | return l0_train, l1_train, flow_train 32 | -------------------------------------------------------------------------------- /expansion/dataloader/sintellist_final.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import pdb 8 | 9 | IMG_EXTENSIONS = [ 10 | 
'.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | 20 | left_fold = 'image_2/' 21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel_final') > -1] 22 | 23 | l0_train = [filepath+left_fold+img for img in train] 24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] 25 | 26 | #l0_train = [i for i in l0_train if not '10.png' in i] # remove 10 as val 27 | 28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] 29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train] 30 | 31 | pdb.set_trace() 32 | return l0_train, l1_train, flow_train 33 | -------------------------------------------------------------------------------- /expansion/dataloader/sintellist_train.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import pdb 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | 20 | left_fold = 'image_2/' 21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1] 22 | 23 | l0_train = [filepath+left_fold+img for img in train] 24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] 25 | 26 | l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # remove 10 as val 27 | 28 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] 29 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train] 30 | 31 | 32 | return l0_train, l1_train, flow_train 33 | -------------------------------------------------------------------------------- /expansion/dataloader/sintellist_val.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | import pdb 8 | 9 | IMG_EXTENSIONS = [ 10 | '.jpg', '.JPG', '.jpeg', '.JPEG', 11 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 12 | ] 13 | 14 | 15 | def is_image_file(filename): 16 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 17 | 18 | def dataloader(filepath): 19 | 20 | left_fold = 'image_2/' 21 | train = [img for img in os.listdir(filepath+left_fold) if img.find('Sintel') > -1] 22 | 23 | l0_train = [filepath+left_fold+img for img in train] 24 | l0_train = [img for img in l0_train if '%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) in l0_train ] 25 | 26 | l0_train = [i for i in l0_train if ('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i)] # remove 10 as val 27 | #l0_train = [i for i in l0_train if not(('_2_' in i) and ('alley' not in i) and ('bandage' not in i) and ('sleeping' not in i))] # 
remove 10 as val 28 | 29 | l1_train = ['%s_%s.png'%(img.rsplit('_',1)[0],'%02d'%(1+int(img.split('.')[0].split('_')[-1])) ) for img in l0_train] 30 | flow_train = [img.replace('image_2','flow_occ') for img in l0_train] 31 | 32 | 33 | return sorted(l0_train)[::3], sorted(l1_train)[::3], sorted(flow_train)[::3] 34 | # return sorted(l0_train)[::10], sorted(l1_train)[::10], sorted(flow_train)[::10] 35 | -------------------------------------------------------------------------------- /expansion/dataloader/thingslist.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | 3 | from PIL import Image 4 | import os 5 | import os.path 6 | import numpy as np 7 | 8 | IMG_EXTENSIONS = [ 9 | '.jpg', '.JPG', '.jpeg', '.JPEG', 10 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 11 | ] 12 | 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 16 | 17 | def dataloader(filepath): 18 | exc_list = [ 19 | '0004117.flo', 20 | '0003149.flo', 21 | '0001203.flo', 22 | '0003147.flo', 23 | '0003666.flo', 24 | '0006337.flo', 25 | '0006336.flo', 26 | '0007126.flo', 27 | '0004118.flo', 28 | ] 29 | 30 | left_fold = 'image_clean/left/' 31 | flow_noc = 'flow/left/into_future/' 32 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] 33 | 34 | l0_trainlf = [filepath+left_fold+img.replace('flo','png') for img in train] 35 | l1_trainlf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlf] 36 | flow_trainlf = [filepath+flow_noc+img for img in train] 37 | 38 | 39 | exc_list = [ 40 | '0003148.flo', 41 | '0004117.flo', 42 | '0002890.flo', 43 | '0003149.flo', 44 | '0001203.flo', 45 | '0003666.flo', 46 | '0006337.flo', 47 | '0006336.flo', 48 | '0004118.flo', 49 | ] 50 | 51 | left_fold = 'image_clean/right/' 52 | flow_noc = 'flow/right/into_future/' 53 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] 54 | 55 | l0_trainrf = [filepath+left_fold+img.replace('flo','png') for img in train] 56 | l1_trainrf = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrf] 57 | flow_trainrf = [filepath+flow_noc+img for img in train] 58 | 59 | 60 | exc_list = [ 61 | '0004237.flo', 62 | '0004705.flo', 63 | '0004045.flo', 64 | '0004346.flo', 65 | '0000161.flo', 66 | '0000931.flo', 67 | '0000121.flo', 68 | '0010822.flo', 69 | '0004117.flo', 70 | '0006023.flo', 71 | '0005034.flo', 72 | '0005054.flo', 73 | '0000162.flo', 74 | '0000053.flo', 75 | '0005055.flo', 76 | '0003147.flo', 77 | '0004876.flo', 78 | '0000163.flo', 79 | '0006878.flo', 80 | ] 81 | 82 | left_fold = 'image_clean/left/' 83 | flow_noc = 'flow/left/into_past/' 84 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] 85 | 86 | l0_trainlp = [filepath+left_fold+img.replace('flo','png') for img in train] 87 | l1_trainlp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainlp] 88 | flow_trainlp = [filepath+flow_noc+img for img in train] 89 | 90 | exc_list = [ 91 | '0003148.flo', 92 | '0004705.flo', 93 | '0000161.flo', 94 | '0000121.flo', 95 | '0004117.flo', 96 | '0000160.flo', 97 | '0005034.flo', 98 | '0005054.flo', 99 | '0000162.flo', 100 | '0000053.flo', 101 | '0005055.flo', 102 | '0003147.flo', 103 | '0001549.flo', 104 | '0000163.flo', 105 | 
'0006336.flo', 106 | '0001648.flo', 107 | '0006878.flo', 108 | ] 109 | 110 | left_fold = 'image_clean/right/' 111 | flow_noc = 'flow/right/into_past/' 112 | train = [img for img in os.listdir(filepath+flow_noc) if np.sum([(k in img) for k in exc_list])==0] 113 | 114 | l0_trainrp = [filepath+left_fold+img.replace('flo','png') for img in train] 115 | l1_trainrp = ['%s/%s.png'%(img.rsplit('/',1)[0],'%07d'%(-1+int(img.split('.')[0].split('/')[-1])) ) for img in l0_trainrp] 116 | flow_trainrp = [filepath+flow_noc+img for img in train] 117 | 118 | 119 | l0_train = l0_trainlf + l0_trainrf + l0_trainlp + l0_trainrp 120 | l1_train = l1_trainlf + l1_trainrf + l1_trainlp + l1_trainrp 121 | flow_train = flow_trainlf + flow_trainrf + flow_trainlp + flow_trainrp 122 | return l0_train, l1_train, flow_train 123 | -------------------------------------------------------------------------------- /expansion/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/models/__init__.py -------------------------------------------------------------------------------- /expansion/models/__pycache__/VCN_exp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/models/__pycache__/VCN_exp.cpython-38.pyc -------------------------------------------------------------------------------- /expansion/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /expansion/models/__pycache__/conv4d.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/models/__pycache__/conv4d.cpython-38.pyc -------------------------------------------------------------------------------- /expansion/models/__pycache__/submodule.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/models/__pycache__/submodule.cpython-38.pyc -------------------------------------------------------------------------------- /expansion/models/conv4d.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import torch.nn as nn 3 | import math 4 | import torch 5 | from torch.nn.parameter import Parameter 6 | import torch.nn.functional as F 7 | from torch.nn import Module 8 | from torch.nn.modules.conv import _ConvNd 9 | from torch.nn.modules.utils import _quadruple 10 | from torch.autograd import Variable 11 | from torch.nn import Conv2d 12 | 13 | def conv4d(data,filters,bias=None,permute_filters=True,use_half=False): 14 | """ 15 | This is done by stacking results of multiple 3D convolutions, and is very slow. 
16 | Taken from https://github.com/ignacio-rocco/ncnet 17 | """ 18 | b,c,h,w,d,t=data.size() 19 | 20 | data=data.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop 21 | 22 | # Same permutation is done with filters, unless already provided with permutation 23 | if permute_filters: 24 | filters=filters.permute(2,0,1,3,4,5).contiguous() # permute to avoid making contiguous inside loop 25 | 26 | c_out=filters.size(1) 27 | if use_half: 28 | output = Variable(torch.HalfTensor(h,b,c_out,w,d,t),requires_grad=data.requires_grad) 29 | else: 30 | output = Variable(torch.zeros(h,b,c_out,w,d,t),requires_grad=data.requires_grad) 31 | 32 | padding=filters.size(0)//2 33 | if use_half: 34 | Z=Variable(torch.zeros(padding,b,c,w,d,t).half()) 35 | else: 36 | Z=Variable(torch.zeros(padding,b,c,w,d,t)) 37 | 38 | if data.is_cuda: 39 | Z=Z.cuda(data.get_device()) 40 | output=output.cuda(data.get_device()) 41 | 42 | data_padded = torch.cat((Z,data,Z),0) 43 | 44 | 45 | for i in range(output.size(0)): # loop on first feature dimension 46 | # convolve with center channel of filter (at position=padding) 47 | output[i,:,:,:,:,:]=F.conv3d(data_padded[i+padding,:,:,:,:,:], 48 | filters[padding,:,:,:,:,:], bias=bias, stride=1, padding=padding) 49 | # convolve with upper/lower channels of filter (at postions [:padding] [padding+1:]) 50 | for p in range(1,padding+1): 51 | output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding-p,:,:,:,:,:], 52 | filters[padding-p,:,:,:,:,:], bias=None, stride=1, padding=padding) 53 | output[i,:,:,:,:,:]=output[i,:,:,:,:,:]+F.conv3d(data_padded[i+padding+p,:,:,:,:,:], 54 | filters[padding+p,:,:,:,:,:], bias=None, stride=1, padding=padding) 55 | 56 | output=output.permute(1,2,0,3,4,5).contiguous() 57 | return output 58 | 59 | class Conv4d(_ConvNd): 60 | """Applies a 4D convolution over an input signal composed of several input 61 | planes. 
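    Example (editor's illustrative sketch, not taken from the original code):
    the module-level conv4d helper that this class wraps can be exercised
    directly; the shapes below are assumptions chosen only to show the 6-D
    layout (batch, channels, then the four convolved dimensions).

    >>> import torch
    >>> data = torch.zeros(1, 2, 5, 5, 8, 8)     # toy 6-D input volume
    >>> filters = torch.zeros(4, 2, 3, 3, 3, 3)  # (c_out, c_in, 3, 3, 3, 3)
    >>> tuple(conv4d(data, filters).shape)
    (1, 4, 5, 5, 8, 8)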
62 | """ 63 | 64 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True): 65 | # stride, dilation and groups !=1 functionality not tested 66 | stride=1 67 | dilation=1 68 | groups=1 69 | # zero padding is added automatically in conv4d function to preserve tensor size 70 | padding = 0 71 | kernel_size = _quadruple(kernel_size) 72 | stride = _quadruple(stride) 73 | padding = _quadruple(padding) 74 | dilation = _quadruple(dilation) 75 | super(Conv4d, self).__init__( 76 | in_channels, out_channels, kernel_size, stride, padding, dilation, 77 | False, _quadruple(0), groups, bias) 78 | # weights will be sliced along one dimension during convolution loop 79 | # make the looping dimension to be the first one in the tensor, 80 | # so that we don't need to call contiguous() inside the loop 81 | self.pre_permuted_filters=pre_permuted_filters 82 | if self.pre_permuted_filters: 83 | self.weight.data=self.weight.data.permute(2,0,1,3,4,5).contiguous() 84 | self.use_half=False 85 | # self.isbias = bias 86 | # if not self.isbias: 87 | # self.bn = torch.nn.BatchNorm1d(out_channels) 88 | 89 | 90 | def forward(self, input): 91 | out = conv4d(input, self.weight, bias=self.bias,permute_filters=not self.pre_permuted_filters,use_half=self.use_half) # filters pre-permuted in constructor 92 | # if not self.isbias: 93 | # b,c,u,v,h,w = out.shape 94 | # out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w) 95 | return out 96 | 97 | class fullConv4d(torch.nn.Module): 98 | def __init__(self, in_channels, out_channels, kernel_size, bias=True, pre_permuted_filters=True): 99 | super(fullConv4d, self).__init__() 100 | self.conv = Conv4d(in_channels, out_channels, kernel_size, bias=bias, pre_permuted_filters=pre_permuted_filters) 101 | self.isbias = bias 102 | if not self.isbias: 103 | self.bn = torch.nn.BatchNorm1d(out_channels) 104 | 105 | def forward(self, input): 106 | out = self.conv(input) 107 | if not self.isbias: 108 | b,c,u,v,h,w = out.shape 109 | out = self.bn(out.view(b,c,-1)).view(b,c,u,v,h,w) 110 | return out 111 | 112 | class butterfly4D(torch.nn.Module): 113 | ''' 114 | butterfly 4d 115 | ''' 116 | def __init__(self, fdima, fdimb, withbn=True, full=True,groups=1): 117 | super(butterfly4D, self).__init__() 118 | self.proj = nn.Sequential(projfeat4d(fdima, fdimb, 1, with_bn=withbn,groups=groups), 119 | nn.ReLU(inplace=True),) 120 | self.conva1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups) 121 | self.conva2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(2,1,1),full=full,groups=groups) 122 | self.convb3 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups) 123 | self.convb2 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups) 124 | self.convb1 = sepConv4dBlock(fdimb,fdimb,with_bn=withbn, stride=(1,1,1),full=full,groups=groups) 125 | 126 | #@profile 127 | def forward(self,x): 128 | out = self.proj(x) 129 | b,c,u,v,h,w = out.shape # 9x9 130 | 131 | out1 = self.conva1(out) # 5x5, 3 132 | _,c1,u1,v1,h1,w1 = out1.shape 133 | 134 | out2 = self.conva2(out1) # 3x3, 9 135 | _,c2,u2,v2,h2,w2 = out2.shape 136 | 137 | out2 = self.convb3(out2) # 3x3, 9 138 | 139 | tout1 = F.upsample(out2.view(b,c,u2,v2,-1),(u1,v1,h2*w2),mode='trilinear').view(b,c,u1,v1,h2,w2) # 5x5 140 | tout1 = F.upsample(tout1.view(b,c,-1,h2,w2),(u1*v1,h1,w1),mode='trilinear').view(b,c,u1,v1,h1,w1) # 5x5 141 | out1 = tout1 + out1 142 | out1 = self.convb2(out1) 143 | 144 | tout = 
F.upsample(out1.view(b,c,u1,v1,-1),(u,v,h1*w1),mode='trilinear').view(b,c,u,v,h1,w1) 145 | tout = F.upsample(tout.view(b,c,-1,h1,w1),(u*v,h,w),mode='trilinear').view(b,c,u,v,h,w) 146 | out = tout + out 147 | out = self.convb1(out) 148 | 149 | return out 150 | 151 | 152 | 153 | class projfeat4d(torch.nn.Module): 154 | ''' 155 | Turn 3d projection into 2d projection 156 | ''' 157 | def __init__(self, in_planes, out_planes, stride, with_bn=True,groups=1): 158 | super(projfeat4d, self).__init__() 159 | self.with_bn = with_bn 160 | self.stride = stride 161 | self.conv1 = nn.Conv3d(in_planes, out_planes, 1, (stride,stride,1), padding=0,bias=not with_bn,groups=groups) 162 | self.bn = nn.BatchNorm3d(out_planes) 163 | 164 | def forward(self,x): 165 | b,c,u,v,h,w = x.size() 166 | x = self.conv1(x.view(b,c,u,v,h*w)) 167 | if self.with_bn: 168 | x = self.bn(x) 169 | _,c,u,v,_ = x.shape 170 | x = x.view(b,c,u,v,h,w) 171 | return x 172 | 173 | class sepConv4d(torch.nn.Module): 174 | ''' 175 | Separable 4d convolution block as 2 3D convolutions 176 | ''' 177 | def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, ksize=3, full=True,groups=1): 178 | super(sepConv4d, self).__init__() 179 | bias = not with_bn 180 | self.isproj = False 181 | self.stride = stride[0] 182 | expand = 1 183 | 184 | if with_bn: 185 | if in_planes != out_planes: 186 | self.isproj = True 187 | self.proj = nn.Sequential(nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups), 188 | nn.BatchNorm2d(out_planes)) 189 | if full: 190 | self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups), 191 | nn.BatchNorm3d(in_planes)) 192 | else: 193 | self.conv1 = nn.Sequential(nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups), 194 | nn.BatchNorm3d(in_planes)) 195 | self.conv2 = nn.Sequential(nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups), 196 | nn.BatchNorm3d(in_planes*expand)) 197 | else: 198 | if in_planes != out_planes: 199 | self.isproj = True 200 | self.proj = nn.Conv2d(in_planes, out_planes, 1, bias=bias, padding=0,groups=groups) 201 | if full: 202 | self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=(1,self.stride,self.stride), bias=bias, padding=(0,ksize//2,ksize//2),groups=groups) 203 | else: 204 | self.conv1 = nn.Conv3d(in_planes*expand, in_planes, (1,ksize,ksize), stride=1, bias=bias, padding=(0,ksize//2,ksize//2),groups=groups) 205 | self.conv2 = nn.Conv3d(in_planes, in_planes*expand, (ksize,ksize,1), stride=(self.stride,self.stride,1), bias=bias, padding=(ksize//2,ksize//2,0),groups=groups) 206 | self.relu = nn.ReLU(inplace=True) 207 | 208 | #@profile 209 | def forward(self,x): 210 | b,c,u,v,h,w = x.shape 211 | x = self.conv2(x.view(b,c,u,v,-1)) 212 | b,c,u,v,_ = x.shape 213 | x = self.relu(x) 214 | x = self.conv1(x.view(b,c,-1,h,w)) 215 | b,c,_,h,w = x.shape 216 | 217 | if self.isproj: 218 | x = self.proj(x.view(b,c,-1,w)) 219 | x = x.view(b,-1,u,v,h,w) 220 | return x 221 | 222 | 223 | class sepConv4dBlock(torch.nn.Module): 224 | ''' 225 | Separable 4d convolution block as 2 2D convolutions and a projection 226 | layer 227 | ''' 228 | def __init__(self, in_planes, out_planes, stride=(1,1,1), with_bn=True, full=True,groups=1): 229 | super(sepConv4dBlock, self).__init__() 230 | if in_planes 
== out_planes and stride==(1,1,1): 231 | self.downsample = None 232 | else: 233 | if full: 234 | self.downsample = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn,ksize=1, full=full,groups=groups) 235 | else: 236 | self.downsample = projfeat4d(in_planes, out_planes,stride[0], with_bn=with_bn,groups=groups) 237 | self.conv1 = sepConv4d(in_planes, out_planes, stride, with_bn=with_bn, full=full ,groups=groups) 238 | self.conv2 = sepConv4d(out_planes, out_planes,(1,1,1), with_bn=with_bn, full=full,groups=groups) 239 | self.relu1 = nn.ReLU(inplace=True) 240 | self.relu2 = nn.ReLU(inplace=True) 241 | 242 | #@profile 243 | def forward(self,x): 244 | out = self.relu1(self.conv1(x)) 245 | if self.downsample: 246 | x = self.downsample(x) 247 | out = self.relu2(x + self.conv2(out)) 248 | return out 249 | 250 | 251 | ##import torch.backends.cudnn as cudnn 252 | ##cudnn.benchmark = True 253 | #import time 254 | ##im = torch.randn(9,64,9,160,224).cuda() 255 | ##net = torch.nn.Conv3d(64, 64, 3).cuda() 256 | ##net = Conv4d(1,1,3,bias=True,pre_permuted_filters=True).cuda() 257 | ##net = sepConv4dBlock(2,2,stride=(1,1,1)).cuda() 258 | # 259 | ##im = torch.randn(1,16,9,9,96,320).cuda() 260 | ##net = sepConv4d(16,16,with_bn=False).cuda() 261 | # 262 | ##im = torch.randn(1,16,81,96,320).cuda() 263 | ##net = torch.nn.Conv3d(16,16,(1,3,3),padding=(0,1,1)).cuda() 264 | # 265 | ##im = torch.randn(1,16,9,9,96*320).cuda() 266 | ##net = torch.nn.Conv3d(16,16,(3,3,1),padding=(1,1,0)).cuda() 267 | # 268 | ##im = torch.randn(10000,10,9,9).cuda() 269 | ##net = torch.nn.Conv2d(10,10,3,padding=1).cuda() 270 | # 271 | ##im = torch.randn(81,16,96,320).cuda() 272 | ##net = torch.nn.Conv2d(16,16,3,padding=1).cuda() 273 | #c= int(16 *1) 274 | #cp = int(16 *1) 275 | #h=int(96 *4) 276 | #w=int(320 *4) 277 | #k=3 278 | #im = torch.randn(1,c,h,w).cuda() 279 | #net = torch.nn.Conv2d(c,cp,k,padding=k//2).cuda() 280 | # 281 | #im2 = torch.randn(cp,k*k*c).cuda() 282 | #im1 = F.unfold(im, (k,k), padding=k//2)[0] 283 | # 284 | # 285 | #net(im) 286 | #net(im) 287 | #torch.mm(im2,im1) 288 | #torch.mm(im2,im1) 289 | #torch.cuda.synchronize() 290 | #beg = time.time() 291 | #for i in range(100): 292 | # net(im) 293 | # #im1 = F.unfold(im, (k,k), padding=k//2)[0] 294 | # torch.mm(im2,im1) 295 | #torch.cuda.synchronize() 296 | #print('%f'%((time.time()-beg)*10.)) 297 | -------------------------------------------------------------------------------- /expansion/submission.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import sys 3 | import cv2 4 | import argparse 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.backends.cudnn as cudnn 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | cudnn.benchmark = False 12 | 13 | class Expansion(): 14 | 15 | def __init__(self, loadmodel = 'pretrained_models/optical_expansion/robust.pth', testres = 1, maxdisp = 256, fac = 1): 16 | 17 | maxw,maxh = [int(testres*1280), int(testres*384)] 18 | 19 | max_h = int(maxh // 64 * 64) 20 | max_w = int(maxw // 64 * 64) 21 | if max_h < maxh: max_h += 64 22 | if max_w < maxw: max_w += 64 23 | maxh = max_h 24 | maxw = max_w 25 | 26 | mean_L = [[0.33,0.33,0.33]] 27 | mean_R = [[0.33,0.33,0.33]] 28 | 29 | # construct model, VCN-expansion 30 | from expansion.models.VCN_exp import VCN 31 | model = VCN([1, maxw, maxh], md=[int(4*(maxdisp/256)),4,4,4,4], fac=fac, 32 | exp_unc=('robust' in loadmodel)) # expansion uncertainty only in 
the new model 33 | model = nn.DataParallel(model, device_ids=[0]) 34 | model.cuda() 35 | 36 | if loadmodel is not None: 37 | pretrained_dict = torch.load(loadmodel) 38 | mean_L=pretrained_dict['mean_L'] 39 | mean_R=pretrained_dict['mean_R'] 40 | pretrained_dict['state_dict'] = {k:v for k,v in pretrained_dict['state_dict'].items()} 41 | model.load_state_dict(pretrained_dict['state_dict'],strict=False) 42 | else: 43 | print('dry run') 44 | 45 | model.eval() 46 | # resize 47 | maxh = 256 48 | maxw = 256 49 | max_h = int(maxh // 64 * 64) 50 | max_w = int(maxw // 64 * 64) 51 | if max_h < maxh: max_h += 64 52 | if max_w < maxw: max_w += 64 53 | 54 | # modify module according to inputs 55 | from expansion.models.VCN_exp import WarpModule, flow_reg 56 | for i in range(len(model.module.reg_modules)): 57 | model.module.reg_modules[i] = flow_reg([1,max_w//(2**(6-i)), max_h//(2**(6-i))], 58 | ent=getattr(model.module, 'flow_reg%d'%2**(6-i)).ent,\ 59 | maxdisp=getattr(model.module, 'flow_reg%d'%2**(6-i)).md,\ 60 | fac=getattr(model.module, 'flow_reg%d'%2**(6-i)).fac).cuda() 61 | for i in range(len(model.module.warp_modules)): 62 | model.module.warp_modules[i] = WarpModule([1,max_w//(2**(6-i)), max_h//(2**(6-i))]).cuda() 63 | 64 | mean_L = torch.from_numpy(np.asarray(mean_L).astype(np.float32).mean(0)[np.newaxis,:,np.newaxis,np.newaxis]).cuda() 65 | mean_R = torch.from_numpy(np.asarray(mean_R).astype(np.float32).mean(0)[np.newaxis,:,np.newaxis,np.newaxis]).cuda() 66 | 67 | self.max_h = max_h 68 | self.max_w = max_w 69 | self.model = model 70 | self.mean_L = mean_L 71 | self.mean_R = mean_R 72 | 73 | def run(self, imgL_o, imgR_o): 74 | model = self.model 75 | mean_L = self.mean_L 76 | mean_R = self.mean_R 77 | 78 | imgL_o[imgL_o<-1] = -1 79 | imgL_o[imgL_o>1] = 1 80 | imgR_o[imgR_o<-1] = -1 81 | imgR_o[imgR_o>1] = 1 82 | imgL = (imgL_o+1.)*0.5-mean_L 83 | imgR = (imgR_o*1.)*0.5-mean_R 84 | 85 | with torch.no_grad(): 86 | imgLR = torch.cat([imgL,imgR],0) 87 | model.eval() 88 | torch.cuda.synchronize() 89 | rts = model(imgLR) 90 | torch.cuda.synchronize() 91 | flow, occ, logmid, logexp = rts 92 | 93 | torch.cuda.empty_cache() 94 | 95 | return flow, logexp 96 | -------------------------------------------------------------------------------- /expansion/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__init__.py -------------------------------------------------------------------------------- /expansion/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /expansion/utils/__pycache__/flowlib.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__pycache__/flowlib.cpython-38.pyc -------------------------------------------------------------------------------- /expansion/utils/__pycache__/io.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__pycache__/io.cpython-38.pyc -------------------------------------------------------------------------------- /expansion/utils/__pycache__/pfm.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__pycache__/pfm.cpython-38.pyc -------------------------------------------------------------------------------- /expansion/utils/__pycache__/util_flow.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/expansion/utils/__pycache__/util_flow.cpython-38.pyc -------------------------------------------------------------------------------- /expansion/utils/io.py: -------------------------------------------------------------------------------- 1 | import errno 2 | import os 3 | import shutil 4 | import sys 5 | import traceback 6 | import zipfile 7 | 8 | if sys.version_info[0] == 2: 9 | import urllib2 10 | else: 11 | import urllib.request 12 | 13 | 14 | # Converts a string to bytes (for writing the string into a file). Provided for 15 | # compatibility with Python 2 and 3. 16 | def StrToBytes(text): 17 | if sys.version_info[0] == 2: 18 | return text 19 | else: 20 | return bytes(text, 'UTF-8') 21 | 22 | 23 | # Outputs the given text and lets the user input a response (submitted by 24 | # pressing the return key). Provided for compatibility with Python 2 and 3. 25 | def GetUserInput(text): 26 | if sys.version_info[0] == 2: 27 | return raw_input(text) 28 | else: 29 | return input(text) 30 | 31 | 32 | # Creates the given directory (hierarchy), which may already exist. Provided for 33 | # compatibility with Python 2 and 3. 34 | def MakeDirsExistOk(directory_path): 35 | try: 36 | os.makedirs(directory_path) 37 | except OSError as exception: 38 | if exception.errno != errno.EEXIST: 39 | raise 40 | 41 | 42 | # Deletes all files and folders within the given folder. 43 | def DeleteFolderContents(folder_path): 44 | for file_name in os.listdir(folder_path): 45 | file_path = os.path.join(folder_path, file_name) 46 | try: 47 | if os.path.isfile(file_path): 48 | os.unlink(file_path) 49 | else: #if os.path.isdir(file_path): 50 | shutil.rmtree(file_path) 51 | except Exception as e: 52 | print('Exception in DeleteFolderContents():') 53 | print(e) 54 | print('Stack trace:') 55 | print(traceback.format_exc()) 56 | 57 | 58 | # Creates the given directory, respectively deletes all content of the directory 59 | # in case it already exists. 60 | def MakeCleanDirectory(folder_path): 61 | if os.path.isdir(folder_path): 62 | DeleteFolderContents(folder_path) 63 | else: 64 | MakeDirsExistOk(folder_path) 65 | 66 | 67 | # Downloads the given URL to a file in the given directory. Returns the 68 | # path to the downloaded file. 
69 | # In part adapted from: https://stackoverflow.com/questions/22676 70 | def DownloadFile(url, dest_dir_path): 71 | file_name = url.split('/')[-1] 72 | dest_file_path = os.path.join(dest_dir_path, file_name) 73 | 74 | if os.path.isfile(dest_file_path): 75 | print('The following file already exists:') 76 | print(dest_file_path) 77 | print('Please choose whether to re-download and overwrite the file [o] or to skip downloading this file [s] by entering o or s.') 78 | while True: 79 | response = GetUserInput("> ") 80 | if response == 's': 81 | return dest_file_path 82 | elif response == 'o': 83 | break 84 | else: 85 | print('Please enter o or s.') 86 | 87 | url_object = None 88 | if sys.version_info[0] == 2: 89 | url_object = urllib2.urlopen(url) 90 | else: 91 | url_object = urllib.request.urlopen(url) 92 | 93 | with open(dest_file_path, 'wb') as outfile: 94 | meta = url_object.info() 95 | file_size = 0 96 | if sys.version_info[0] == 2: 97 | file_size = int(meta.getheaders("Content-Length")[0]) 98 | else: 99 | file_size = int(meta["Content-Length"]) 100 | print("Downloading: %s (size [bytes]: %s)" % (url, file_size)) 101 | 102 | file_size_downloaded = 0 103 | block_size = 8192 104 | while True: 105 | buffer = url_object.read(block_size) 106 | if not buffer: 107 | break 108 | 109 | file_size_downloaded += len(buffer) 110 | outfile.write(buffer) 111 | 112 | sys.stdout.write("%d / %d (%3f%%)\r" % (file_size_downloaded, file_size, file_size_downloaded * 100. / file_size)) 113 | sys.stdout.flush() 114 | 115 | return dest_file_path 116 | 117 | 118 | # Unzips the given zip file into the given directory. 119 | def UnzipFile(file_path, unzip_dir_path, overwrite=True): 120 | zip_ref = zipfile.ZipFile(open(file_path, 'rb')) 121 | 122 | if not overwrite: 123 | for f in zip_ref.namelist(): 124 | if not os.path.isfile(os.path.join(unzip_dir_path, f)): 125 | zip_ref.extract(f, path=unzip_dir_path) 126 | else: 127 | print('Not overwriting {}'.format(f)) 128 | else: 129 | zip_ref.extractall(unzip_dir_path) 130 | zip_ref.close() 131 | 132 | 133 | # Creates a zip file with the contents of the given directory. 134 | # The archive_base_path must not include the extension .zip. The full, final 135 | # path of the archive is returned by the function. 136 | def ZipDirectory(archive_base_path, root_dir_path): 137 | # return shutil.make_archive(archive_base_path, 'zip', root_dir_path) # THIS WILL ALWAYS HAVE ./ FOLDER INCLUDED 138 | with zipfile.ZipFile(archive_base_path+'.zip', "w", compression=zipfile.ZIP_DEFLATED) as zf: 139 | base_path = os.path.normpath(root_dir_path) 140 | for dirpath, dirnames, filenames in os.walk(root_dir_path): 141 | for name in sorted(dirnames): 142 | path = os.path.normpath(os.path.join(dirpath, name)) 143 | zf.write(path, os.path.relpath(path, base_path)) 144 | for name in filenames: 145 | path = os.path.normpath(os.path.join(dirpath, name)) 146 | if os.path.isfile(path): 147 | zf.write(path, os.path.relpath(path, base_path)) 148 | 149 | return archive_base_path+'.zip' 150 | 151 | 152 | # Downloads a zip file and directly unzips it. 
153 | def DownloadAndUnzipFile(url, archive_dir_path, unzip_dir_path, overwrite=True): 154 | archive_path = DownloadFile(url, archive_dir_path) 155 | UnzipFile(archive_path, unzip_dir_path, overwrite=overwrite) 156 | 157 | def mkdir_p(path): 158 | try: 159 | os.makedirs(path) 160 | except OSError as exc: # Python >2.5 161 | if exc.errno == errno.EEXIST and os.path.isdir(path): 162 | pass 163 | else: 164 | raise 165 | -------------------------------------------------------------------------------- /expansion/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | File: logger.py 3 | Modified by: Senthil Purushwalkam 4 | Code referenced from https://gist.github.com/gyglim/1f8dfb1b5c82627ae3efcfbbadb9f514 5 | Email: spurushwandrewcmuedu 6 | Github: https://github.com/senthilps8 7 | Description: 8 | """ 9 | import pdb 10 | import tensorflow as tf 11 | from torch.autograd import Variable 12 | import numpy as np 13 | import scipy.misc 14 | import os 15 | try: 16 | from StringIO import StringIO # Python 2.7 17 | except ImportError: 18 | from io import BytesIO # Python 3.x 19 | 20 | 21 | class Logger(object): 22 | 23 | def __init__(self, log_dir, name=None): 24 | """Create a summary writer logging to log_dir.""" 25 | if name is None: 26 | name = 'temp' 27 | self.name = name 28 | if name is not None: 29 | try: 30 | os.makedirs(os.path.join(log_dir, name)) 31 | except: 32 | pass 33 | self.writer = tf.summary.FileWriter(os.path.join(log_dir, name), 34 | filename_suffix=name) 35 | else: 36 | self.writer = tf.summary.FileWriter(log_dir, filename_suffix=name) 37 | 38 | def scalar_summary(self, tag, value, step): 39 | """Log a scalar variable.""" 40 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)]) 41 | self.writer.add_summary(summary, step) 42 | 43 | def image_summary(self, tag, images, step): 44 | """Log a list of images.""" 45 | 46 | img_summaries = [] 47 | for i, img in enumerate(images): 48 | # Write the image to a string 49 | try: 50 | s = StringIO() 51 | except: 52 | s = BytesIO() 53 | scipy.misc.toimage(img).save(s, format="png") 54 | 55 | # Create an Image object 56 | img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(), 57 | height=img.shape[0], 58 | width=img.shape[1]) 59 | # Create a Summary value 60 | img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum)) 61 | 62 | # Create and write Summary 63 | summary = tf.Summary(value=img_summaries) 64 | self.writer.add_summary(summary, step) 65 | 66 | def histo_summary(self, tag, values, step, bins=1000): 67 | """Log a histogram of the tensor of values.""" 68 | 69 | # Create a histogram using numpy 70 | counts, bin_edges = np.histogram(values, bins=bins) 71 | 72 | # Fill the fields of the histogram proto 73 | hist = tf.HistogramProto() 74 | hist.min = float(np.min(values)) 75 | hist.max = float(np.max(values)) 76 | hist.num = int(np.prod(values.shape)) 77 | hist.sum = float(np.sum(values)) 78 | hist.sum_squares = float(np.sum(values**2)) 79 | 80 | # Drop the start of the first bin 81 | bin_edges = bin_edges[1:] 82 | 83 | # Add bin edges and counts 84 | for edge in bin_edges: 85 | hist.bucket_limit.append(edge) 86 | for c in counts: 87 | hist.bucket.append(c) 88 | 89 | # Create and write Summary 90 | summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)]) 91 | self.writer.add_summary(summary, step) 92 | self.writer.flush() 93 | 94 | def to_np(self, x): 95 | return x.data.cpu().numpy() 96 | 97 | def to_var(self, x): 98 | if 
torch.cuda.is_available(): 99 | x = x.cuda() 100 | return Variable(x) 101 | 102 | def model_param_histo_summary(self, model, step): 103 | """log histogram summary of model's parameters 104 | and parameter gradients 105 | """ 106 | for tag, value in model.named_parameters(): 107 | if value.grad is None: 108 | continue 109 | tag = tag.replace('.', '/') 110 | tag = self.name+'/'+tag 111 | self.histo_summary(tag, self.to_np(value), step) 112 | self.histo_summary(tag+'/grad', self.to_np(value.grad), step) 113 | 114 | -------------------------------------------------------------------------------- /expansion/utils/multiscaleloss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from https://github.com/ClementPinard/FlowNetPytorch 3 | """ 4 | import pdb 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | 9 | def EPE(input_flow, target_flow, mask, sparse=False, mean=True): 10 | #mask = target_flow[:,2]>0 11 | target_flow = target_flow[:,:2] 12 | EPE_map = torch.norm(target_flow-input_flow,2,1) 13 | batch_size = EPE_map.size(0) 14 | if sparse: 15 | # invalid flow is defined with both flow coordinates to be exactly 0 16 | mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0) 17 | 18 | EPE_map = EPE_map[~mask] 19 | if mean: 20 | return EPE_map[mask].mean() 21 | else: 22 | return EPE_map[mask].sum()/batch_size 23 | 24 | def rob_EPE(input_flow, target_flow, mask, sparse=False, mean=True): 25 | #mask = target_flow[:,2]>0 26 | target_flow = target_flow[:,:2] 27 | #TODO 28 | # EPE_map = torch.norm(target_flow-input_flow,2,1) 29 | EPE_map = (torch.norm(target_flow-input_flow,1,1)+0.01).pow(0.4) 30 | batch_size = EPE_map.size(0) 31 | if sparse: 32 | # invalid flow is defined with both flow coordinates to be exactly 0 33 | mask = (target_flow[:,0] == 0) & (target_flow[:,1] == 0) 34 | 35 | EPE_map = EPE_map[~mask] 36 | if mean: 37 | return EPE_map[mask].mean() 38 | else: 39 | return EPE_map[mask].sum()/batch_size 40 | 41 | def sparse_max_pool(input, size): 42 | '''Downsample the input by considering 0 values as invalid. 43 | 44 | Unfortunately, no generic interpolation mode can resize a sparse map correctly, 45 | the strategy here is to use max pooling for positive values and "min pooling" 46 | for negative values, the two results are then summed. 
47 | This technique allows sparsity to be minized, contrary to nearest interpolation, 48 | which could potentially lose information for isolated data points.''' 49 | 50 | positive = (input > 0).float() 51 | negative = (input < 0).float() 52 | output = F.adaptive_max_pool2d(input * positive, size) - F.adaptive_max_pool2d(-input * negative, size) 53 | return output 54 | 55 | 56 | def multiscaleEPE(network_output, target_flow, mask, weights=None, sparse=False, rob_loss = False): 57 | def one_scale(output, target, mask, sparse): 58 | 59 | b, _, h, w = output.size() 60 | 61 | if sparse: 62 | target_scaled = sparse_max_pool(target, (h, w)) 63 | else: 64 | target_scaled = F.interpolate(target, (h, w), mode='area') 65 | mask = F.interpolate(mask.float().unsqueeze(1), (h, w), mode='bilinear').squeeze(1)==1 66 | if rob_loss: 67 | return rob_EPE(output, target_scaled, mask, sparse, mean=False) 68 | else: 69 | return EPE(output, target_scaled, mask, sparse, mean=False) 70 | 71 | if type(network_output) not in [tuple, list]: 72 | network_output = [network_output] 73 | if weights is None: 74 | weights = [0.005, 0.01, 0.02, 0.08, 0.32] # as in original article 75 | assert(len(weights) == len(network_output)) 76 | 77 | loss = 0 78 | for output, weight in zip(network_output, weights): 79 | loss += weight * one_scale(output, target_flow, mask, sparse) 80 | return loss 81 | 82 | 83 | def realEPE(output, target, mask, sparse=False): 84 | b, _, h, w = target.size() 85 | upsampled_output = F.interpolate(output, (h,w), mode='bilinear', align_corners=False) 86 | return EPE(upsampled_output, target,mask, sparse, mean=True) 87 | -------------------------------------------------------------------------------- /expansion/utils/pfm.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import sys 4 | 5 | def readPFM(file): 6 | file = open(file, 'rb') 7 | 8 | color = None 9 | width = None 10 | height = None 11 | scale = None 12 | endian = None 13 | 14 | header = file.readline().rstrip() 15 | if (sys.version[0]) == '3': 16 | header = header.decode('utf-8') 17 | if header == 'PF': 18 | color = True 19 | elif header == 'Pf': 20 | color = False 21 | else: 22 | raise Exception('Not a PFM file.') 23 | 24 | if (sys.version[0]) == '3': 25 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 26 | else: 27 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline()) 28 | if dim_match: 29 | width, height = map(int, dim_match.groups()) 30 | else: 31 | raise Exception('Malformed PFM header.') 32 | 33 | if (sys.version[0]) == '3': 34 | scale = float(file.readline().rstrip().decode('utf-8')) 35 | else: 36 | scale = float(file.readline().rstrip()) 37 | 38 | if scale < 0: # little-endian 39 | endian = '<' 40 | scale = -scale 41 | else: 42 | endian = '>' # big-endian 43 | 44 | data = np.fromfile(file, endian + 'f') 45 | shape = (height, width, 3) if color else (height, width) 46 | 47 | data = np.reshape(data, shape) 48 | data = np.flipud(data) 49 | return data, scale 50 | 51 | 52 | def writePFM(file, image, scale=1): 53 | file = open(file, 'wb') 54 | 55 | color = None 56 | 57 | if image.dtype.name != 'float32': 58 | raise Exception('Image dtype must be float32.') 59 | 60 | image = np.flipud(image) 61 | 62 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 63 | color = True 64 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 65 | color = False 66 | else: 67 | raise Exception('Image must 
have H x W x 3, H x W x 1 or H x W dimensions.') 68 | 69 | file.write('PF\n' if color else 'Pf\n') 70 | file.write('%d %d\n' % (image.shape[1], image.shape[0])) 71 | 72 | endian = image.dtype.byteorder 73 | 74 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 75 | scale = -scale 76 | 77 | file.write('%f\n' % scale) 78 | 79 | image.tofile(file) 80 | -------------------------------------------------------------------------------- /expansion/utils/readpfm.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | import sys 4 | 5 | 6 | def readPFM(file): 7 | file = open(file, 'rb') 8 | 9 | color = None 10 | width = None 11 | height = None 12 | scale = None 13 | endian = None 14 | 15 | header = file.readline().rstrip() 16 | if (sys.version[0]) == '3': 17 | header = header.decode('utf-8') 18 | if header == 'PF': 19 | color = True 20 | elif header == 'Pf': 21 | color = False 22 | else: 23 | raise Exception('Not a PFM file.') 24 | 25 | if (sys.version[0]) == '3': 26 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 27 | else: 28 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline()) 29 | if dim_match: 30 | width, height = map(int, dim_match.groups()) 31 | else: 32 | raise Exception('Malformed PFM header.') 33 | 34 | if (sys.version[0]) == '3': 35 | scale = float(file.readline().rstrip().decode('utf-8')) 36 | else: 37 | scale = float(file.readline().rstrip()) 38 | 39 | if scale < 0: # little-endian 40 | endian = '<' 41 | scale = -scale 42 | else: 43 | endian = '>' # big-endian 44 | 45 | data = np.fromfile(file, endian + 'f') 46 | shape = (height, width, 3) if color else (height, width) 47 | 48 | data = np.reshape(data, shape) 49 | data = np.flipud(data) 50 | return data, scale 51 | 52 | -------------------------------------------------------------------------------- /expansion/utils/sintel_io.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python2 2 | 3 | """ 4 | I/O script to save and load the data coming with the MPI-Sintel low-level 5 | computer vision benchmark. 6 | 7 | For more details about the benchmark, please visit www.mpi-sintel.de 8 | 9 | CHANGELOG: 10 | v1.0 (2015/02/03): First release 11 | 12 | Copyright (c) 2015 Jonas Wulff 13 | Max Planck Institute for Intelligent Systems, Tuebingen, Germany 14 | 15 | """ 16 | 17 | # Requirements: Numpy as PIL/Pillow 18 | import numpy as np 19 | from PIL import Image 20 | 21 | # Check for endianness, based on Daniel Scharstein's optical flow code. 22 | # Using little-endian architecture, these two should be equal. 23 | TAG_FLOAT = 202021.25 24 | TAG_CHAR = 'PIEH' 25 | 26 | def flow_read(filename): 27 | """ Read optical flow from file, return (U,V) tuple. 28 | 29 | Original code by Deqing Sun, adapted from Daniel Scharstein. 30 | """ 31 | f = open(filename,'rb') 32 | check = np.fromfile(f,dtype=np.float32,count=1)[0] 33 | assert check == TAG_FLOAT, ' flow_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? 
'.format(TAG_FLOAT,check) 34 | width = np.fromfile(f,dtype=np.int32,count=1)[0] 35 | height = np.fromfile(f,dtype=np.int32,count=1)[0] 36 | size = width*height 37 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' flow_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) 38 | tmp = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width*2)) 39 | u = tmp[:,np.arange(width)*2] 40 | v = tmp[:,np.arange(width)*2 + 1] 41 | return u,v 42 | 43 | def flow_write(filename,uv,v=None): 44 | """ Write optical flow to file. 45 | 46 | If v is None, uv is assumed to contain both u and v channels, 47 | stacked in depth. 48 | 49 | Original code by Deqing Sun, adapted from Daniel Scharstein. 50 | """ 51 | nBands = 2 52 | 53 | if v is None: 54 | assert(uv.ndim == 3) 55 | assert(uv.shape[2] == 2) 56 | u = uv[:,:,0] 57 | v = uv[:,:,1] 58 | else: 59 | u = uv 60 | 61 | assert(u.shape == v.shape) 62 | height,width = u.shape 63 | f = open(filename,'wb') 64 | # write the header 65 | f.write(TAG_CHAR) 66 | np.array(width).astype(np.int32).tofile(f) 67 | np.array(height).astype(np.int32).tofile(f) 68 | # arrange into matrix form 69 | tmp = np.zeros((height, width*nBands)) 70 | tmp[:,np.arange(width)*2] = u 71 | tmp[:,np.arange(width)*2 + 1] = v 72 | tmp.astype(np.float32).tofile(f) 73 | f.close() 74 | 75 | 76 | def depth_read(filename): 77 | """ Read depth data from file, return as numpy array. """ 78 | f = open(filename,'rb') 79 | check = np.fromfile(f,dtype=np.float32,count=1)[0] 80 | assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) 81 | width = np.fromfile(f,dtype=np.int32,count=1)[0] 82 | height = np.fromfile(f,dtype=np.int32,count=1)[0] 83 | size = width*height 84 | assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height) 85 | depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width)) 86 | return depth 87 | 88 | def depth_write(filename, depth): 89 | """ Write depth to file. """ 90 | height,width = depth.shape[:2] 91 | f = open(filename,'wb') 92 | # write the header 93 | f.write(TAG_CHAR) 94 | np.array(width).astype(np.int32).tofile(f) 95 | np.array(height).astype(np.int32).tofile(f) 96 | 97 | depth.astype(np.float32).tofile(f) 98 | f.close() 99 | 100 | 101 | def disparity_write(filename,disparity,bitdepth=16): 102 | """ Write disparity to file. 103 | 104 | bitdepth can be either 16 (default) or 32. 105 | 106 | The maximum disparity is 1024, since the image width in Sintel 107 | is 1024. 108 | """ 109 | d = disparity.copy() 110 | 111 | # Clip disparity. 112 | d[d>1024] = 1024 113 | d[d<0] = 0 114 | 115 | d_r = (d / 4.0).astype('uint8') 116 | d_g = ((d * (2.0**6)) % 256).astype('uint8') 117 | 118 | out = np.zeros((d.shape[0],d.shape[1],3),dtype='uint8') 119 | out[:,:,0] = d_r 120 | out[:,:,1] = d_g 121 | 122 | if bitdepth > 16: 123 | d_b = (d * (2**14) % 256).astype('uint8') 124 | out[:,:,2] = d_b 125 | 126 | Image.fromarray(out,'RGB').save(filename,'PNG') 127 | 128 | 129 | def disparity_read(filename): 130 | """ Return disparity read from filename. 
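    Worked example (editor's note; the value is illustrative, not from the
    original file): with the 16-bit encoding used by disparity_write above, a
    disparity of 10.25 is stored as R = int(10.25 / 4) = 2 and
    G = (10.25 * 2**6) % 256 = 144 (B stays 0), and is recovered here as
    2 * 4 + 144 / 2**6 + 0 / 2**14 = 10.25.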
""" 131 | f_in = np.array(Image.open(filename)) 132 | d_r = f_in[:,:,0].astype('float64') 133 | d_g = f_in[:,:,1].astype('float64') 134 | d_b = f_in[:,:,2].astype('float64') 135 | 136 | depth = d_r * 4 + d_g / (2**6) + d_b / (2**14) 137 | return depth 138 | 139 | 140 | #def cam_read(filename): 141 | # """ Read camera data, return (M,N) tuple. 142 | # 143 | # M is the intrinsic matrix, N is the extrinsic matrix, so that 144 | # 145 | # x = M*N*X, 146 | # with x being a point in homogeneous image pixel coordinates, X being a 147 | # point in homogeneous world coordinates. 148 | # """ 149 | # txtdata = np.loadtxt(filename) 150 | # intrinsic = txtdata[0,:9].reshape((3,3)) 151 | # extrinsic = textdata[1,:12].reshape((3,4)) 152 | # return intrinsic,extrinsic 153 | # 154 | # 155 | #def cam_write(filename,M,N): 156 | # """ Write intrinsic matrix M and extrinsic matrix N to file. """ 157 | # Z = np.zeros((2,12)) 158 | # Z[0,:9] = M.ravel() 159 | # Z[1,:12] = N.ravel() 160 | # np.savetxt(filename,Z) 161 | 162 | def cam_read(filename): 163 | """ Read camera data, return (M,N) tuple. 164 | 165 | M is the intrinsic matrix, N is the extrinsic matrix, so that 166 | 167 | x = M*N*X, 168 | with x being a point in homogeneous image pixel coordinates, X being a 169 | point in homogeneous world coordinates. 170 | """ 171 | f = open(filename,'rb') 172 | check = np.fromfile(f,dtype=np.float32,count=1)[0] 173 | assert check == TAG_FLOAT, ' cam_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check) 174 | M = np.fromfile(f,dtype='float64',count=9).reshape((3,3)) 175 | N = np.fromfile(f,dtype='float64',count=12).reshape((3,4)) 176 | return M,N 177 | 178 | def cam_write(filename, M, N): 179 | """ Write intrinsic matrix M and extrinsic matrix N to file. """ 180 | f = open(filename,'wb') 181 | # write the header 182 | f.write(TAG_CHAR) 183 | M.astype('float64').tofile(f) 184 | N.astype('float64').tofile(f) 185 | f.close() 186 | 187 | 188 | def segmentation_write(filename,segmentation): 189 | """ Write segmentation to file. """ 190 | 191 | segmentation_ = segmentation.astype('int32') 192 | seg_r = np.floor(segmentation_ / (256**2)).astype('uint8') 193 | seg_g = np.floor((segmentation_ % (256**2)) / 256).astype('uint8') 194 | seg_b = np.floor(segmentation_ % 256).astype('uint8') 195 | 196 | out = np.zeros((segmentation.shape[0],segmentation.shape[1],3),dtype='uint8') 197 | out[:,:,0] = seg_r 198 | out[:,:,1] = seg_g 199 | out[:,:,2] = seg_b 200 | 201 | Image.fromarray(out,'RGB').save(filename,'PNG') 202 | 203 | 204 | def segmentation_read(filename): 205 | """ Return disparity read from filename. 
""" 206 | f_in = np.array(Image.open(filename)) 207 | seg_r = f_in[:,:,0].astype('int32') 208 | seg_g = f_in[:,:,1].astype('int32') 209 | seg_b = f_in[:,:,2].astype('int32') 210 | 211 | segmentation = (seg_r * 256 + seg_g) * 256 + seg_b 212 | return segmentation 213 | 214 | 215 | -------------------------------------------------------------------------------- /expansion/utils/util_flow.py: -------------------------------------------------------------------------------- 1 | import math 2 | import png 3 | import struct 4 | import array 5 | import numpy as np 6 | import cv2 7 | import pdb 8 | 9 | from io import * 10 | 11 | UNKNOWN_FLOW_THRESH = 1e9; 12 | UNKNOWN_FLOW = 1e10; 13 | 14 | # Middlebury checks 15 | TAG_STRING = 'PIEH' # use this when WRITING the file 16 | TAG_FLOAT = 202021.25 # check for this when READING the file 17 | 18 | def readPFM(file): 19 | import re 20 | file = open(file, 'rb') 21 | 22 | color = None 23 | width = None 24 | height = None 25 | scale = None 26 | endian = None 27 | 28 | header = file.readline().rstrip() 29 | if header == b'PF': 30 | color = True 31 | elif header == b'Pf': 32 | color = False 33 | else: 34 | raise Exception('Not a PFM file.') 35 | 36 | dim_match = re.match(b'^(\d+)\s(\d+)\s$', file.readline()) 37 | if dim_match: 38 | width, height = map(int, dim_match.groups()) 39 | else: 40 | raise Exception('Malformed PFM header.') 41 | 42 | scale = float(file.readline().rstrip()) 43 | if scale < 0: # little-endian 44 | endian = '<' 45 | scale = -scale 46 | else: 47 | endian = '>' # big-endian 48 | 49 | data = np.fromfile(file, endian + 'f') 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | return data, scale 55 | 56 | 57 | def save_pfm(file, image, scale = 1): 58 | import sys 59 | color = None 60 | 61 | if image.dtype.name != 'float32': 62 | raise Exception('Image dtype must be float32.') 63 | 64 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 65 | color = True 66 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 67 | color = False 68 | else: 69 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 70 | 71 | file.write('PF\n' if color else 'Pf\n') 72 | file.write('%d %d\n' % (image.shape[1], image.shape[0])) 73 | 74 | endian = image.dtype.byteorder 75 | 76 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 77 | scale = -scale 78 | 79 | file.write('%f\n' % scale) 80 | 81 | image.tofile(file) 82 | 83 | 84 | def ReadMiddleburyFloFile(path): 85 | """ Read .FLO file as specified by Middlebury. 86 | 87 | Returns tuple (width, height, u, v, mask), where u, v, mask are flat 88 | arrays of values. 89 | """ 90 | 91 | with open(path, 'rb') as fil: 92 | tag = struct.unpack('f', fil.read(4))[0] 93 | width = struct.unpack('i', fil.read(4))[0] 94 | height = struct.unpack('i', fil.read(4))[0] 95 | 96 | assert tag == TAG_FLOAT 97 | 98 | #data = np.fromfile(path, dtype=np.float, count=-1) 99 | #data = data[3:] 100 | 101 | fmt = 'f' * width*height*2 102 | data = struct.unpack(fmt, fil.read(4*width*height*2)) 103 | 104 | u = data[::2] 105 | v = data[1::2] 106 | 107 | mask = map(lambda x,y: abs(x) 0: 144 | # print(u[ind], v[ind], mask[ind], row[3*x], row[3*x+1], row[3*x+2]) 145 | 146 | #png_reader.close() 147 | 148 | return (width, height, u, v, mask) 149 | 150 | 151 | def WriteMiddleburyFloFile(path, width, height, u, v, mask=None): 152 | """ Write .FLO file as specified by Middlebury. 
153 | """ 154 | 155 | if mask is not None: 156 | u_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, u, mask) 157 | v_masked = map(lambda x,y: x if y else UNKNOWN_FLOW, v, mask) 158 | else: 159 | u_masked = u 160 | v_masked = v 161 | 162 | fmt = 'f' * width*height*2 163 | # Interleave lists 164 | data = [x for t in zip(u_masked,v_masked) for x in t] 165 | 166 | with open(path, 'wb') as fil: 167 | fil.write(str.encode(TAG_STRING)) 168 | fil.write(struct.pack('i', width)) 169 | fil.write(struct.pack('i', height)) 170 | fil.write(struct.pack(fmt, *data)) 171 | 172 | 173 | def write_flow(path,flow): 174 | 175 | invalid_idx = (flow[:, :, 2] == 0) 176 | flow[:, :, 0:2] = flow[:, :, 0:2]*64.+ 2 ** 15 177 | flow[invalid_idx, 0] = 0 178 | flow[invalid_idx, 1] = 0 179 | 180 | flow = flow.astype(np.uint16) 181 | flow = cv2.imwrite(path, flow[:,:,::-1]) 182 | 183 | #WriteKittiPngFile(path, 184 | # flow.shape[1], flow.shape[0], flow[:,:,0].flatten(), 185 | # flow[:,:,1].flatten(), flow[:,:,2].flatten()) 186 | 187 | 188 | 189 | def WriteKittiPngFile(path, width, height, u, v, mask=None): 190 | """ Write 16-bit .PNG file as specified by KITTI-2015 (flow). 191 | 192 | u, v are lists of float values 193 | mask is a list of floats, denoting the *valid* pixels. 194 | """ 195 | 196 | data = array.array('H',[0])*width*height*3 197 | 198 | for i,(u_,v_,mask_) in enumerate(zip(u,v,mask)): 199 | data[3*i] = int(u_*64.0+2**15) 200 | data[3*i+1] = int(v_*64.0+2**15) 201 | data[3*i+2] = int(mask_) 202 | 203 | # if mask_ > 0: 204 | # print(data[3*i], data[3*i+1],data[3*i+2]) 205 | 206 | with open(path, 'wb') as png_file: 207 | png_writer = png.Writer(width=width, height=height, bitdepth=16, compression=3, greyscale=False) 208 | png_writer.write_array(png_file, data) 209 | 210 | 211 | def ConvertMiddleburyFloToKittiPng(src_path, dest_path): 212 | width, height, u, v, mask = ReadMiddleburyFloFile(src_path) 213 | WriteKittiPngFile(dest_path, width, height, u, v, mask=mask) 214 | 215 | def ConvertKittiPngToMiddleburyFlo(src_path, dest_path): 216 | width, height, u, v, mask = ReadKittiPngFile(src_path) 217 | WriteMiddleburyFloFile(dest_path, width, height, u, v, mask=mask) 218 | 219 | 220 | def ParseFilenameKitti(filename): 221 | # Parse kitti filename (seq_frameno.xx), 222 | # return seq, frameno, ext. 
223 | # Be aware that seq might contain the dataset name (if contained as prefix) 224 | ext = filename[filename.rfind('.'):] 225 | frameno = filename[filename.rfind('_')+1:filename.rfind('.')] 226 | frameno = int(frameno) 227 | seq = filename[:filename.rfind('_')] 228 | return seq, frameno, ext 229 | 230 | 231 | def read_calib_file(filepath): 232 | """Read in a calibration file and parse into a dictionary.""" 233 | data = {} 234 | 235 | with open(filepath, 'r') as f: 236 | for line in f.readlines(): 237 | key, value = line.split(':', 1) 238 | # The only non-float values in these files are dates, which 239 | # we don't care about anyway 240 | try: 241 | data[key] = np.array([float(x) for x in value.split()]) 242 | except ValueError: 243 | pass 244 | 245 | return data 246 | 247 | def load_calib_cam_to_cam(cam_to_cam_file): 248 | # We'll return the camera calibration as a dictionary 249 | data = {} 250 | 251 | # Load and parse the cam-to-cam calibration data 252 | filedata = read_calib_file(cam_to_cam_file) 253 | 254 | # Create 3x4 projection matrices 255 | P_rect_00 = np.reshape(filedata['P_rect_00'], (3, 4)) 256 | P_rect_10 = np.reshape(filedata['P_rect_01'], (3, 4)) 257 | P_rect_20 = np.reshape(filedata['P_rect_02'], (3, 4)) 258 | P_rect_30 = np.reshape(filedata['P_rect_03'], (3, 4)) 259 | 260 | # Compute the camera intrinsics 261 | data['K_cam0'] = P_rect_00[0:3, 0:3] 262 | data['K_cam1'] = P_rect_10[0:3, 0:3] 263 | data['K_cam2'] = P_rect_20[0:3, 0:3] 264 | data['K_cam3'] = P_rect_30[0:3, 0:3] 265 | 266 | data['b00'] = P_rect_00[0, 3] / P_rect_00[0, 0] 267 | data['b10'] = P_rect_10[0, 3] / P_rect_10[0, 0] 268 | data['b20'] = P_rect_20[0, 3] / P_rect_20[0, 0] 269 | data['b30'] = P_rect_30[0, 3] / P_rect_30[0, 0] 270 | 271 | return data 272 | 273 | -------------------------------------------------------------------------------- /interface/flask_app.py: -------------------------------------------------------------------------------- 1 | from flask import Flask, render_template, request, redirect, url_for, abort 2 | import json 3 | 4 | app = Flask(__name__) 5 | 6 | import sys 7 | sys.path.append(".") 8 | sys.path.append("..") 9 | 10 | import argparse 11 | from PIL import Image, ImageOps 12 | import numpy as np 13 | import base64 14 | import cv2 15 | from inference import demo 16 | 17 | def Base64ToNdarry(img_base64): 18 | img_data = base64.b64decode(img_base64) 19 | img_np = np.fromstring(img_data, np.uint8) 20 | src = cv2.imdecode(img_np, cv2.IMREAD_ANYCOLOR) 21 | 22 | return src 23 | 24 | def NdarrayToBase64(dst): 25 | result, dst_data = cv2.imencode('.png', dst) 26 | dst_base64 = base64.b64encode(dst_data) 27 | 28 | return dst_base64 29 | 30 | parser = argparse.ArgumentParser(description='User controllable latent transformer') 31 | parser.add_argument('--checkpoint_path', default='pretrained_models/latent_transformer/cat.pt') 32 | args = parser.parse_args() 33 | 34 | demo = demo(args.checkpoint_path) 35 | 36 | @app.route("/", methods=["GET", "POST"]) 37 | #@auth.login_required 38 | def init(): 39 | if request.method == "GET": 40 | input_img = demo.run() 41 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode() 42 | return render_template("index.html", filepath1=input_base64, canvas_img=input_base64, result=True) 43 | if request.method == "POST": 44 | if 'zi' in request.form.keys(): 45 | input_img = demo.move(z=-0.05) 46 | elif 'zo' in request.form.keys(): 47 | input_img = demo.move(z=0.05) 48 | elif 'u' in request.form.keys(): 49 | input_img = demo.move(y=-0.5, 
z=-0.0) 50 | elif 'd' in request.form.keys(): 51 | input_img = demo.move(y=0.5, z=-0.0) 52 | elif 'l' in request.form.keys(): 53 | input_img = demo.move(x=-0.5, z=-0.0) 54 | elif 'r' in request.form.keys(): 55 | input_img = demo.move(x=0.5, z=-0.0) 56 | else: 57 | input_img = demo.run() 58 | 59 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode() 60 | return render_template("index.html", filepath1=input_base64, canvas_img=input_base64, result=True) 61 | 62 | @app.route('/zoom', methods=["POST"]) 63 | def zoom_func(): 64 | 65 | dz = json.loads(request.form['dz']) 66 | sx = json.loads(request.form['sx']) 67 | sy = json.loads(request.form['sy']) 68 | stop_points = json.loads(request.form['stop_points']) 69 | 70 | input_img = demo.zoom(dz,sxsy=[sx,sy],stop_points=stop_points) 71 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode() 72 | res = {'img':input_base64} 73 | return json.dumps(res) 74 | 75 | @app.route('/translate', methods=["POST"]) 76 | def translate_func(): 77 | 78 | dx = json.loads(request.form['dx']) 79 | dy = json.loads(request.form['dy']) 80 | dz = json.loads(request.form['dz']) 81 | sx = json.loads(request.form['sx']) 82 | sy = json.loads(request.form['sy']) 83 | stop_points = json.loads(request.form['stop_points']) 84 | zi = json.loads(request.form['zi']) 85 | zo = json.loads(request.form['zo']) 86 | 87 | input_img = demo.translate([dx,dy],sxsy=[sx,sy],stop_points=stop_points,zoom_in=zi,zoom_out=zo) 88 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode() 89 | res = {'img':input_base64} 90 | return json.dumps(res) 91 | 92 | @app.route('/changestyle', methods=["POST"]) 93 | def changestyle_func(): 94 | input_img = demo.change_style() 95 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode() 96 | res = {'img':input_base64} 97 | return json.dumps(res) 98 | 99 | @app.route('/reset', methods=["POST"]) 100 | def reset_func(): 101 | input_img = demo.reset() 102 | input_base64 = "data:image/png;base64,"+NdarrayToBase64(input_img).decode() 103 | res = {'img':input_base64} 104 | return json.dumps(res) 105 | 106 | if __name__ == "__main__": 107 | app.run(debug=False, host='0.0.0.0', port=8000) -------------------------------------------------------------------------------- /interface/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import Namespace 3 | import numpy as np 4 | import torch 5 | import sys 6 | 7 | sys.path.append(".") 8 | sys.path.append("..") 9 | 10 | from models.StyleGANControler import StyleGANControler 11 | 12 | class demo(): 13 | 14 | def __init__(self, checkpoint_path, truncation = 0.5, use_average_code_as_input = False): 15 | self.truncation = truncation 16 | self.use_average_code_as_input = use_average_code_as_input 17 | ckpt = torch.load(checkpoint_path, map_location='cpu') 18 | opts = ckpt['opts'] 19 | opts['checkpoint_path'] = checkpoint_path 20 | self.opts = Namespace(**ckpt['opts']) 21 | 22 | self.net = StyleGANControler(self.opts) 23 | self.net.eval() 24 | self.net.cuda() 25 | self.target_layers = [0,1,2,3,4,5] 26 | 27 | self.w1 = None 28 | self.w1_after = None 29 | self.f1 = None 30 | 31 | def run(self): 32 | z1 = torch.randn(1,512).to("cuda") 33 | x1, self.w1, self.f1 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_feature_map=True,return_latents=True,truncation=self.truncation, truncation_latent=self.net.latent_avg[0]) 34 | self.w1_after = self.w1.clone() 35 | 
x1 = self.net.face_pool(x1) 36 | result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1] 37 | return result 38 | 39 | def translate(self, dxy, sxsy=[0,0], stop_points=[], zoom_in=False, zoom_out=False): 40 | dz = -5. if zoom_in else 0. 41 | dz = 5. if zoom_out else dz 42 | 43 | dxyz = np.array([dxy[0],dxy[1],dz], dtype=np.float32) 44 | dxy_norm = np.linalg.norm(dxyz[:2], ord=2) 45 | dxyz[:2] = dxyz[:2]/dxy_norm 46 | vec_num = dxy_norm/10 47 | 48 | x = torch.from_numpy(np.array([[dxyz]],dtype=np.float32)).cuda() 49 | f1 = torch.nn.functional.interpolate(self.f1, (256,256)) 50 | y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0) 51 | 52 | if len(stop_points)>0: 53 | x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1) 54 | tmp = [] 55 | for sp in stop_points: 56 | tmp.append(f1[:,:,sp[1],sp[0]].unsqueeze(1)) 57 | y = torch.cat([y,torch.cat(tmp, dim=1)],dim=1) 58 | 59 | if not self.use_average_code_as_input: 60 | w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num) 61 | w1 = self.w1.clone() 62 | w1[:,self.target_layers] = w_hat 63 | else: 64 | w_hat = self.net.encoder(self.net.latent_avg.unsqueeze(0)[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num) 65 | w1 = self.w1.clone() 66 | w1[:,self.target_layers] = self.w1.clone()[:,self.target_layers] + w_hat - self.net.latent_avg.unsqueeze(0)[:,self.target_layers] 67 | 68 | x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False) 69 | 70 | self.w1_after = w1.clone() 71 | x1 = self.net.face_pool(x1) 72 | result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1] 73 | return result 74 | 75 | def zoom(self, dz, sxsy=[0,0], stop_points=[]): 76 | vec_num = abs(dz)/5 77 | dz = 100*np.sign(dz) 78 | x = torch.from_numpy(np.array([[[1.,0,dz]]],dtype=np.float32)).cuda() 79 | f1 = torch.nn.functional.interpolate(self.f1, (256,256)) 80 | y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0) 81 | 82 | if len(stop_points)>0: 83 | x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1) 84 | tmp = [] 85 | for sp in stop_points: 86 | tmp.append(f1[:,:,sp[1],sp[0]].unsqueeze(1)) 87 | y = torch.cat([y,torch.cat(tmp, dim=1)],dim=1) 88 | 89 | if not self.use_average_code_as_input: 90 | w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num) 91 | w1 = self.w1.clone() 92 | w1[:,self.target_layers] = w_hat 93 | else: 94 | w_hat = self.net.encoder(self.net.latent_avg.unsqueeze(0)[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num) 95 | w1 = self.w1.clone() 96 | w1[:,self.target_layers] = self.w1.clone()[:,self.target_layers] + w_hat - self.net.latent_avg.unsqueeze(0)[:,self.target_layers] 97 | 98 | 99 | x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False) 100 | 101 | x1 = self.net.face_pool(x1) 102 | result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1] 103 | return result 104 | 105 | def change_style(self): 106 | z1 = torch.randn(1,512).to("cuda") 107 | x1, w2 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_latents=True, truncation=self.truncation, truncation_latent=self.net.latent_avg[0]) 108 | self.w1_after[:,6:] = w2.detach()[:,0] 109 | x1, _ = self.net.decoder([self.w1_after], input_is_latent=True, randomize_noise=False, return_latents=False) 110 | result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1] 111 | return result 112 | 113 | def 
reset(self): 114 | x1, _ = self.net.decoder([self.w1], input_is_latent=True, randomize_noise=False, return_latents=False) 115 | result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1] 116 | return result 117 | -------------------------------------------------------------------------------- /interface/templates/index.html: -------------------------------------------------------------------------------- 1 |  2 | 3 | 4 | 5 | 6 | 7 |
[index.html markup was lost in extraction; the page appears to draw the generated image on a canvas and lists the following mouse/keyboard controls]
Mouse drag : Translation
Middle mouse button : Set anchor point
Mouse wheel : Zoom in & out
'i' or 'o' key + mouse drag : Translation with zooming in & out
's' key : Style mixing
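The routes above can also be driven without the browser UI. The following minimal client is a sketch, not part of the repository: it assumes the server from interface/flask_app.py is running locally on port 8000 and that the third-party requests package is installed; the drag values and pixel coordinates are illustrative. Field names and their JSON encoding mirror translate_func(), and the initial GET is needed because the "/" handler is what first calls demo.run() and caches the feature map used by later edits.

import base64
import json

import requests  # third-party HTTP client; assumed installed, not part of this repo

base = "http://localhost:8000"
requests.get(base + "/")  # generate an initial image; the GET handler calls demo.run()

# Example drag of (+10, 0) starting at pixel (128, 128) of the 256x256 editing grid.
payload = {
    "dx": json.dumps(10.0),
    "dy": json.dumps(0.0),
    "dz": json.dumps(0.0),
    "sx": json.dumps(128),
    "sy": json.dumps(128),
    "stop_points": json.dumps([]),   # optional [x, y] anchor points to keep fixed
    "zi": json.dumps(False),         # translate while zooming in ('i' key behaviour)
    "zo": json.dumps(False),         # translate while zooming out ('o' key behaviour)
}
r = requests.post(base + "/translate", data=payload)
data_url = r.json()["img"]                        # "data:image/png;base64,...."
with open("edited.png", "wb") as f:
    f.write(base64.b64decode(data_url.split(",", 1)[1]))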
27 | 28 | 194 | 195 | -------------------------------------------------------------------------------- /licenses/LICENSE_ gengshan-y_expansion: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Carnegie Mellon University 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /licenses/LICENSE_HuangYG123: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 HuangYG123 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /licenses/LICENSE_S-aiueo32: -------------------------------------------------------------------------------- 1 | BSD 2-Clause License 2 | 3 | Copyright (c) 2020, Sou Uchida 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. 
Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /licenses/LICENSE_TreB1eN: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 TreB1eN 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /licenses/LICENSE_lessw2020: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /licenses/LICENSE_pixel2style2pixel: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Elad Richardson, Yuval Alaluf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /licenses/LICENSE_rosinality: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /models/StyleGANControler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from models.networks import latent_transformer 4 | from models.stylegan2.model import Generator 5 | import numpy as np 6 | 7 | def get_keys(d, name): 8 | if 'state_dict' in d: 9 | d = d['state_dict'] 10 | d_filt = {k[len(name) + 1:]: v for k, v in d.items() if k[:len(name)] == name} 11 | return d_filt 12 | 13 | 14 | class StyleGANControler(nn.Module): 15 | 16 | def __init__(self, opts): 17 | super(StyleGANControler, self).__init__() 18 | self.set_opts(opts) 19 | # Define architecture 20 | 21 | if 'ffhq' in self.opts.stylegan_weights: 22 | self.style_num = 18 23 | elif 'car' in self.opts.stylegan_weights: 24 | self.style_num = 16 25 | elif 'cat' in self.opts.stylegan_weights: 26 | self.style_num = 14 27 | elif 'church' in self.opts.stylegan_weights: 28 | self.style_num = 14 29 | elif 'anime' in self.opts.stylegan_weights: 30 | self.style_num = 16 31 | else: 32 | self.style_num = 18 #Please modify to adjust network architecture to your pre-trained StyleGAN2 33 | 34 | self.encoder = self.set_encoder() 35 | if self.style_num==18: 36 | self.decoder = Generator(1024, 512, 8, channel_multiplier=2) 37 | elif self.style_num==16: 38 | self.decoder = Generator(512, 512, 8, channel_multiplier=2) 39 | elif self.style_num==14: 40 | self.decoder = Generator(256, 512, 8, channel_multiplier=2) 41 | 42 | self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) 43 | 44 | # Load weights if needed 45 | self.load_weights() 46 | 47 | def set_encoder(self): 48 | encoder = latent_transformer.Network(self.opts) 49 | return encoder 50 | 51 | def load_weights(self): 52 | if self.opts.checkpoint_path is not None: 53 | print('Loading from checkpoint: {}'.format(self.opts.checkpoint_path)) 54 | ckpt = torch.load(self.opts.checkpoint_path, map_location='cpu') 55 | self.encoder.load_state_dict(get_keys(ckpt, 'encoder'), strict=True) 56 | self.decoder.load_state_dict(get_keys(ckpt, 'decoder'), strict=True) 57 | self.__load_latent_avg(ckpt) 58 | else: 59 | print('Loading decoder weights from pretrained!') 60 | ckpt = torch.load(self.opts.stylegan_weights) 61 | self.decoder.load_state_dict(ckpt['g_ema'], strict=True) 62 | self.__load_latent_avg(ckpt, repeat=self.opts.style_num) 63 | 64 | def set_opts(self, opts): 65 | self.opts = opts 66 | 67 | def __load_latent_avg(self, ckpt, repeat=None): 68 | if 'latent_avg' in ckpt: 69 | self.latent_avg = ckpt['latent_avg'].to(self.opts.device) 70 | if repeat is not None: 71 | self.latent_avg = self.latent_avg.repeat(repeat, 1) 72 | else: 73 | self.latent_avg = None 74 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/models/__init__.py -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/models/networks/__init__.py -------------------------------------------------------------------------------- /models/networks/latent_transformer.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from einops import rearrange 6 | 7 | # classes 8 | class PreNorm(nn.Module): 9 | def __init__(self, dim, fn): 10 | super().__init__() 11 | self.norm = nn.LayerNorm(dim) 12 | self.fn = fn 13 | def forward(self, x, **kwargs): 14 | return self.fn(self.norm(x), **kwargs) 15 | 16 | class FeedForward(nn.Module): 17 | def __init__(self, dim, hidden_dim, dropout = 0.): 18 | super().__init__() 19 | self.net = nn.Sequential( 20 | nn.Linear(dim, hidden_dim), 21 | nn.GELU(), 22 | nn.Dropout(dropout), 23 | nn.Linear(hidden_dim, dim), 24 | nn.Dropout(dropout) 25 | ) 26 | def forward(self, x): 27 | return self.net(x) 28 | 29 | class Attention(nn.Module): 30 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 31 | super().__init__() 32 | inner_dim = dim_head * heads 33 | project_out = not (heads == 1 and dim_head == dim) 34 | 35 | self.heads = heads 36 | self.scale = dim_head ** -0.5 37 | 38 | self.attend = nn.Softmax(dim = -1) 39 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 40 | 41 | self.to_out = nn.Sequential( 42 | nn.Linear(inner_dim, dim), 43 | nn.Dropout(dropout) 44 | ) if project_out else nn.Identity() 45 | 46 | def forward(self, x): 47 | qkv = self.to_qkv(x).chunk(3, dim = -1) 48 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 49 | 50 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 51 | 52 | attn = self.attend(dots) 53 | 54 | out = torch.matmul(attn, v) 55 | out = rearrange(out, 'b h n d -> b n (h d)') 56 | return self.to_out(out) 57 | 58 | class CrossAttention(nn.Module): 59 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 60 | super().__init__() 61 | inner_dim = dim_head * heads 62 | project_out = not (heads == 1 and dim_head == dim) 63 | 64 | self.heads = heads 65 | self.scale = dim_head ** -0.5 66 | 67 | self.to_k = nn.Linear(dim, inner_dim , bias=False) 68 | self.to_v = nn.Linear(dim, inner_dim , bias = False) 69 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 70 | 71 | self.to_out = nn.Sequential( 72 | nn.Linear(inner_dim, dim), 73 | nn.Dropout(dropout) 74 | ) if project_out else nn.Identity() 75 | 76 | def forward(self, x_qkv, query_length=1): 77 | h = self.heads 78 | 79 | k = self.to_k(x_qkv)[:, query_length:] 80 | k = rearrange(k, 'b n (h d) -> b h n d', h = h) 81 | 82 | v = self.to_v(x_qkv)[:, query_length:] 83 | v = rearrange(v, 'b n (h d) -> b h n d', h = h) 84 | 85 | q = self.to_q(x_qkv)[:, :query_length] 86 | q = rearrange(q, 'b n (h d) -> b h n d', h = h) 87 | 88 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 89 | 90 | attn = dots.softmax(dim=-1) 91 | 92 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 93 | out = rearrange(out, 'b h n d -> b n (h d)') 94 | out = self.to_out(out) 95 | 96 | return out 97 | 98 | class TransformerEncoder(nn.Module): 99 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 100 | super().__init__() 101 | self.layers = nn.ModuleList([]) 102 | for _ in range(depth): 103 | self.layers.append(nn.ModuleList([ 104 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 105 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 106 | ])) 107 | def forward(self, x): 108 | for attn, ff in self.layers: 109 | x = attn(x) + x 110 | x = ff(x) + x 111 | return x 112 | 113 | class 
TransformerDecoder(nn.Module): 114 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 115 | super().__init__() 116 | self.pos_embedding = nn.Parameter(torch.randn(1, 6, dim)) 117 | self.layers = nn.ModuleList([]) 118 | for _ in range(depth): 119 | self.layers.append(nn.ModuleList([ 120 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 121 | PreNorm(dim, CrossAttention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 122 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 123 | ])) 124 | def forward(self, x, y): 125 | x = x + self.pos_embedding[:, :x.shape[1]] 126 | for sattn, cattn, ff in self.layers: 127 | x = sattn(x) + x 128 | xy = torch.cat((x,y), dim=1) 129 | x = cattn(xy, query_length=x.shape[1]) + x 130 | x = ff(x) + x 131 | return x 132 | 133 | class Network(nn.Module): 134 | def __init__(self, opts): 135 | super(Network, self).__init__() 136 | 137 | self.transformer_encoder = TransformerEncoder(dim=512, depth=6, heads=8, dim_head=64, mlp_dim=512, dropout=0) 138 | self.transformer_decoder = TransformerDecoder(dim=512, depth=6, heads=8, dim_head=64, mlp_dim=512, dropout=0) 139 | self.layer1 = nn.Linear(3, 256) 140 | self.layer2 = nn.Linear(512, 256) 141 | self.layer3 = nn.Linear(512, 512) 142 | self.layer4 = nn.Linear(512, 512) 143 | self.mlp_head = nn.Sequential( 144 | nn.Linear(512, 512) 145 | ) 146 | 147 | def forward(self, w, x, y, alpha=1.): 148 | #w: latent vectors 149 | #x: flow vectors 150 | #y: StyleGAN features 151 | xh = F.relu(self.layer1(x)) 152 | yh = F.relu(self.layer2(y)) 153 | xyh = torch.cat([xh,yh], dim=2) 154 | xyh = F.relu(self.layer3(xyh)) 155 | xyh = self.transformer_encoder(xyh) 156 | 157 | wh = F.relu(self.layer4(w)) 158 | 159 | h = self.transformer_decoder(wh, xyh) 160 | h = self.mlp_head(h) 161 | w_hat = w+alpha*h 162 | return w_hat 163 | -------------------------------------------------------------------------------- /models/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/models/stylegan2/__init__.py -------------------------------------------------------------------------------- /models/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | module_path = os.path.dirname(__file__) 9 | fused = load( 10 | 'fused', 11 | sources=[ 12 | os.path.join(module_path, 'fused_bias_act.cpp'), 13 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class FusedLeakyReLUFunctionBackward(Function): 19 | @staticmethod 20 | def forward(ctx, grad_output, out, negative_slope, scale): 21 | ctx.save_for_backward(out) 22 | ctx.negative_slope = negative_slope 23 | ctx.scale = scale 24 | 25 | empty = grad_output.new_empty(0) 26 | 27 | grad_input = fused.fused_bias_act( 28 | grad_output, empty, out, 3, 1, negative_slope, scale 29 | ) 30 | 31 | dim = [0] 32 | 33 | if grad_input.ndim > 2: 34 | dim += list(range(2, 
grad_input.ndim)) 35 | 36 | grad_bias = grad_input.sum(dim).detach() 37 | 38 | return grad_input, grad_bias 39 | 40 | @staticmethod 41 | def backward(ctx, gradgrad_input, gradgrad_bias): 42 | out, = ctx.saved_tensors 43 | gradgrad_out = fused.fused_bias_act( 44 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 45 | ) 46 | 47 | return gradgrad_out, None, None, None 48 | 49 | 50 | class FusedLeakyReLUFunction(Function): 51 | @staticmethod 52 | def forward(ctx, input, bias, negative_slope, scale): 53 | empty = input.new_empty(0) 54 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 55 | ctx.save_for_backward(out) 56 | ctx.negative_slope = negative_slope 57 | ctx.scale = scale 58 | 59 | return out 60 | 61 | @staticmethod 62 | def backward(ctx, grad_output): 63 | out, = ctx.saved_tensors 64 | 65 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 66 | grad_output, out, ctx.negative_slope, ctx.scale 67 | ) 68 | 69 | return grad_input, grad_bias, None, None 70 | 71 | 72 | class FusedLeakyReLU(nn.Module): 73 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 74 | super().__init__() 75 | 76 | self.bias = nn.Parameter(torch.zeros(channel)) 77 | self.negative_slope = negative_slope 78 | self.scale = scale 79 | 80 | def forward(self, input): 81 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 82 | 83 | 84 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 85 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 86 | -------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /models/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } 
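For reference, the op bound here is the generic upsample-filter-downsample primitive used by StyleGAN2-style blur and resampling layers, and it is re-exported as a plain function by the Python wrapper in the next file. The snippet below is a usage sketch only, not repository code: it assumes the CUDA extension builds on the local machine and a GPU is available, and the tensor sizes and kernel are illustrative.

import torch
from models.stylegan2.op import upfirdn2d

k = torch.tensor([1., 3., 3., 1.])
kernel = k[None, :] * k[:, None]      # separable 4x4 binomial blur kernel
kernel = kernel / kernel.sum()        # normalize so the filter preserves the mean
kernel = kernel * (2 ** 2)            # compensate for the zeros inserted by up=2

x = torch.randn(1, 512, 32, 32, device='cuda')
y = upfirdn2d(x, kernel.to('cuda'), up=2, down=1, pad=(2, 1))
print(y.shape)  # torch.Size([1, 512, 64, 64]): (32*2 + 2 + 1 - 4) + 1 = 64 per side

The pure-PyTorch upfirdn2d_native function in the next file performs the same sequence of operations and is the easiest place to read off the exact semantics of the kernel.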
-------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | module_path = os.path.dirname(__file__) 8 | upfirdn2d_op = load( 9 | 'upfirdn2d', 10 | sources=[ 11 | os.path.join(module_path, 'upfirdn2d.cpp'), 12 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 13 | ], 14 | ) 15 | 16 | 17 | class UpFirDn2dBackward(Function): 18 | @staticmethod 19 | def forward( 20 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 21 | ): 22 | up_x, up_y = up 23 | down_x, down_y = down 24 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 25 | 26 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 27 | 28 | grad_input = upfirdn2d_op.upfirdn2d( 29 | grad_output, 30 | grad_kernel, 31 | down_x, 32 | down_y, 33 | up_x, 34 | up_y, 35 | g_pad_x0, 36 | g_pad_x1, 37 | g_pad_y0, 38 | g_pad_y1, 39 | ) 40 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 41 | 42 | ctx.save_for_backward(kernel) 43 | 44 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 45 | 46 | ctx.up_x = up_x 47 | ctx.up_y = up_y 48 | ctx.down_x = down_x 49 | ctx.down_y = down_y 50 | ctx.pad_x0 = pad_x0 51 | ctx.pad_x1 = pad_x1 52 | ctx.pad_y0 = pad_y0 53 | ctx.pad_y1 = pad_y1 54 | ctx.in_size = in_size 55 | ctx.out_size = out_size 56 | 57 | return grad_input 58 | 59 | @staticmethod 60 | def backward(ctx, gradgrad_input): 61 | kernel, = ctx.saved_tensors 62 | 63 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 64 | 65 | gradgrad_out = upfirdn2d_op.upfirdn2d( 66 | gradgrad_input, 67 | kernel, 68 | ctx.up_x, 69 | ctx.up_y, 70 | ctx.down_x, 71 | ctx.down_y, 72 | ctx.pad_x0, 73 | ctx.pad_x1, 74 | ctx.pad_y0, 75 | ctx.pad_y1, 76 | ) 77 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 78 | gradgrad_out = gradgrad_out.view( 79 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 80 | ) 81 | 82 | return gradgrad_out, None, None, None, None, None, None, None, None 83 | 84 | 85 | class UpFirDn2d(Function): 86 | @staticmethod 87 | def forward(ctx, input, kernel, up, down, pad): 88 | up_x, up_y = up 89 | down_x, down_y = down 90 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 91 | 92 | kernel_h, kernel_w = kernel.shape 93 | batch, channel, in_h, in_w = input.shape 94 | ctx.in_size = input.shape 95 | 96 | input = input.reshape(-1, in_h, in_w, 1) 97 | 98 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 99 | 100 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 101 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 102 | ctx.out_size = (out_h, out_w) 103 | 104 | ctx.up = (up_x, up_y) 105 | ctx.down = (down_x, down_y) 106 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 107 | 108 | g_pad_x0 = kernel_w - pad_x0 - 1 109 | g_pad_y0 = kernel_h - pad_y0 - 1 110 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 111 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 112 | 113 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 114 | 115 | out = upfirdn2d_op.upfirdn2d( 116 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 117 | ) 118 | # out = out.view(major, out_h, out_w, minor) 119 | out = out.view(-1, channel, out_h, out_w) 120 | 121 | return out 122 | 123 | 
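# Note on the backward pass below: forward() saved a flipped copy of the kernel and
# the g_pad_* paddings computed above; backward() hands grad_output to
# UpFirDn2dBackward, which re-runs the same compiled op with up and down swapped,
# the flipped kernel, and those paddings. That is the transpose of the forward
# resampling, so no dedicated gradient kernel is required.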
@staticmethod 124 | def backward(ctx, grad_output): 125 | kernel, grad_kernel = ctx.saved_tensors 126 | 127 | grad_input = UpFirDn2dBackward.apply( 128 | grad_output, 129 | kernel, 130 | grad_kernel, 131 | ctx.up, 132 | ctx.down, 133 | ctx.pad, 134 | ctx.g_pad, 135 | ctx.in_size, 136 | ctx.out_size, 137 | ) 138 | 139 | return grad_input, None, None, None, None 140 | 141 | 142 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 143 | out = UpFirDn2d.apply( 144 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 145 | ) 146 | 147 | return out 148 | 149 | 150 | def upfirdn2d_native( 151 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 152 | ): 153 | _, in_h, in_w, minor = input.shape 154 | kernel_h, kernel_w = kernel.shape 155 | 156 | out = input.view(-1, in_h, 1, in_w, 1, minor) 157 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 158 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 159 | 160 | out = F.pad( 161 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 162 | ) 163 | out = out[ 164 | :, 165 | max(-pad_y0, 0): out.shape[1] - max(-pad_y1, 0), 166 | max(-pad_x0, 0): out.shape[2] - max(-pad_x1, 0), 167 | :, 168 | ] 169 | 170 | out = out.permute(0, 3, 1, 2) 171 | out = out.reshape( 172 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 173 | ) 174 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 175 | out = F.conv2d(out, w) 176 | out = out.reshape( 177 | -1, 178 | minor, 179 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 180 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 181 | ) 182 | out = out.permute(0, 2, 3, 1) 183 | 184 | return out[:, ::down_y, ::down_x, :] 185 | -------------------------------------------------------------------------------- /models/stylegan2/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 
4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; in_idx += blockDim.x) { 93 | int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114 | int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int 
rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>( 229 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 |
upfirdn2d_kernel<scalar_t, 1, 1, 1, 1, 3, 3, 16, 64><<<grid_size, block_size, 0, stream>>>( 236 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 4, 4, 16, 64><<<grid_size, block_size, 0, stream>>>( 243 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<scalar_t, 2, 2, 1, 1, 2, 2, 16, 64><<<grid_size, block_size, 0, stream>>>( 250 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 4, 4, 8, 32><<<grid_size, block_size, 0, stream>>>( 257 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<scalar_t, 1, 1, 2, 2, 2, 2, 8, 32><<<grid_size, block_size, 0, stream>>>( 264 | out.data_ptr<scalar_t>(), x.data_ptr<scalar_t>(), k.data_ptr<scalar_t>(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/options/__init__.py -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | class TrainOptions: 4 | 5 | def __init__(self): 6 | self.parser = ArgumentParser() 7 | self.initialize() 8 | 9 | def initialize(self): 10 | self.parser.add_argument('--exp_dir', type=str, help='Path to experiment output directory') 11 | 12 | self.parser.add_argument('--batch_size', default=1, type=int, help='Batch size for training') 13 | self.parser.add_argument('--learning_rate', default=0.001, type=float, help='Optimizer learning rate') 14 | self.parser.add_argument('--optim_name', default='ranger', type=str, help='Which optimizer to use') 15 | self.parser.add_argument('--train_decoder', default=False, type=bool, help='Whether to train the decoder model') 16 | 17 | self.parser.add_argument('--lpips_lambda', default=0., type=float, help='LPIPS loss multiplier factor') 18 | self.parser.add_argument('--l2_lambda', default=0, type=float, help='L2 loss multiplier factor') 19 | self.parser.add_argument('--l2latent_lambda', default=1.0, type=float, help='Latent-space L2 loss multiplier factor') 20 | 21 | self.parser.add_argument('--stylegan_weights', default='pretrained_models/stylegan2-cat-config-f.pt', type=str, help='Path to StyleGAN model weights') 22 | self.parser.add_argument('--checkpoint_path', default=None, type=str, help='Path to pSp model checkpoint') 23 | 24 | self.parser.add_argument('--max_steps', default=60100, type=int, help='Maximum number of training steps') 25 | self.parser.add_argument('--image_interval', default=100, type=int, help='Interval for logging train images during training') 26 | self.parser.add_argument('--save_interval', default=10000, type=int, help='Model checkpoint interval') 27 | 28 | self.parser.add_argument('--style_num', default=14, type=int, help='Number of StyleGAN layers that receive latent codes') 29 | self.parser.add_argument('--channel_multiplier', default=2, type=int, help='StyleGAN parameter') 30 | 31 | def parse(self): 32 | opts = self.parser.parse_args() 33 | return opts -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file runs the main training/val loop 3 | """ 4 | import os 5 | import json 6 | import sys 7 | import pprint 8 | 9 | sys.path.append(".") 10 | sys.path.append("..") 11 | 12 | from options.train_options import
TrainOptions 13 | from training.coach import Coach 14 | 15 | 16 | def main(): 17 | opts = TrainOptions().parse() 18 | os.makedirs(opts.exp_dir, exist_ok=True) 19 | 20 | opts_dict = vars(opts) 21 | pprint.pprint(opts_dict) 22 | with open(os.path.join(opts.exp_dir, 'opt.json'), 'w') as f: 23 | json.dump(opts_dict, f, indent=4, sort_keys=True) 24 | 25 | coach = Coach(opts) 26 | coach.train() 27 | 28 | if __name__ == '__main__': 29 | main() 30 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/training/__init__.py -------------------------------------------------------------------------------- /training/coach.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math, random 3 | import numpy as np 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | matplotlib.use('Agg') 7 | 8 | import torch 9 | from torch import nn 10 | from torch.utils.tensorboard import SummaryWriter 11 | import torch.nn.functional as F 12 | 13 | from utils import common 14 | from criteria.lpips.lpips import LPIPS 15 | from models.StyleGANControler import StyleGANControler 16 | from training.ranger import Ranger 17 | 18 | from expansion.submission import Expansion 19 | from expansion.utils.flowlib import point_vec 20 | 21 | class Coach: 22 | def __init__(self, opts): 23 | self.opts = opts 24 | if self.opts.checkpoint_path is None: 25 | self.global_step = 0 26 | else: 27 | self.global_step = int(os.path.splitext(os.path.basename(self.opts.checkpoint_path))[0].split('_')[-1]) 28 | 29 | self.device = 'cuda:0' # TODO: Allow multiple GPU? currently using CUDA_VISIBLE_DEVICES 30 | self.opts.device = self.device 31 | 32 | # Initialize network 33 | self.net = StyleGANControler(self.opts).to(self.device) 34 | 35 | # Initialize loss 36 | if self.opts.lpips_lambda > 0: 37 | self.lpips_loss = LPIPS(net_type='alex').to(self.device).eval() 38 | self.mse_loss = nn.MSELoss().to(self.device).eval() 39 | 40 | # Initialize optimizer 41 | self.optimizer = self.configure_optimizers() 42 | 43 | # Initialize logger 44 | log_dir = os.path.join(opts.exp_dir, 'logs') 45 | os.makedirs(log_dir, exist_ok=True) 46 | self.logger = SummaryWriter(log_dir=log_dir) 47 | 48 | # Initialize checkpoint dir 49 | self.checkpoint_dir = os.path.join(opts.exp_dir, 'checkpoints') 50 | os.makedirs(self.checkpoint_dir, exist_ok=True) 51 | self.best_val_loss = None 52 | if self.opts.save_interval is None: 53 | self.opts.save_interval = self.opts.max_steps 54 | 55 | # Initialize optical flow estimator 56 | self.ex = Expansion() 57 | 58 | # Set flow normalization values 59 | if 'ffhq' in self.opts.stylegan_weights: 60 | self.sigma_f = 4 61 | self.sigma_e = 0.02 62 | elif 'car' in self.opts.stylegan_weights: 63 | self.sigma_f = 5 64 | self.sigma_e = 0.03 65 | elif 'cat' in self.opts.stylegan_weights: 66 | self.sigma_f = 12 67 | self.sigma_e = 0.04 68 | elif 'church' in self.opts.stylegan_weights: 69 | self.sigma_f = 8 70 | self.sigma_e = 0.02 71 | elif 'anime' in self.opts.stylegan_weights: 72 | self.sigma_f = 7 73 | self.sigma_e = 0.025 74 | 75 | def train(self, truncation = 0.3, sigma = 0.1, target_layers = [0,1,2,3,4,5]): 76 | 77 | x = np.array(range(0,256,16)).astype(np.float32)/127.5-1. 78 | y = np.array(range(0,256,16)).astype(np.float32)/127.5-1. 
79 | xx, yy = np.meshgrid(x,y) 80 | grid = np.concatenate([xx[:,:,None],yy[:,:,None]], axis=2) 81 | grid = torch.from_numpy(grid[None,:]).cuda() 82 | grid = grid.repeat(self.opts.batch_size,1,1,1) 83 | 84 | while self.global_step < self.opts.max_steps: 85 | with torch.no_grad(): 86 | z1 = torch.randn(self.opts.batch_size,512).to("cuda") 87 | z2 = torch.randn(self.opts.batch_size,self.net.style_num, 512).to("cuda") 88 | 89 | x1, w1, f1 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_feature_map=True,return_latents=True,truncation=truncation, truncation_latent=self.net.latent_avg[0]) 90 | x1 = self.net.face_pool(x1) 91 | x2, w2 = self.net.decoder([z2],input_is_latent=False,randomize_noise=False,return_latents=True, truncation_latent=self.net.latent_avg[0]) 92 | x2 = self.net.face_pool(x2) 93 | w_mid = w1.clone() 94 | w_mid[:,target_layers] = w_mid[:,target_layers]+sigma*(w2[:,target_layers]-w_mid[:,target_layers]) 95 | x_mid, _ = self.net.decoder([w_mid], input_is_latent=True, randomize_noise=False, return_latents=False) 96 | x_mid = self.net.face_pool(x_mid) 97 | 98 | flow, logexp = self.ex.run(x1.detach(),x_mid.detach()) 99 | flow_feature = torch.cat([flow/self.sigma_f, logexp/self.sigma_e], dim=1) 100 | f1 = F.interpolate(f1, (flow_feature.shape[2:])) 101 | f1 = F.grid_sample(f1, grid, mode='nearest', align_corners=True) 102 | flow_feature = F.grid_sample(flow_feature, grid, mode='nearest', align_corners=True) 103 | flow_feature = flow_feature.view(flow_feature.shape[0], flow_feature.shape[1], -1).permute(0,2,1) 104 | f1 = f1.view(f1.shape[0], f1.shape[1], -1).permute(0,2,1) 105 | 106 | self.net.train() 107 | self.optimizer.zero_grad() 108 | w_hat = self.net.encoder(w1[:,target_layers].detach(), flow_feature.detach(), f1.detach()) 109 | loss, loss_dict, id_logs = self.calc_loss(w_hat, w_mid[:,target_layers].detach()) 110 | loss.backward() 111 | self.optimizer.step() 112 | 113 | w_mid[:,target_layers] = w_hat.detach() 114 | x_hat, _ = self.net.decoder([w_mid], input_is_latent=True, randomize_noise=False) 115 | x_hat = self.net.face_pool(x_hat) 116 | if self.global_step % self.opts.image_interval == 0 or ( 117 | self.global_step < 1000 and self.global_step % 100 == 0): 118 | imgL_o = ((x1.detach()+1.)*127.5)[0].permute(1,2,0).cpu().numpy() 119 | flow = torch.cat((flow,torch.ones_like(flow)[:,:1]), dim=1)[0].permute(1,2,0).cpu().numpy() 120 | flowvis = point_vec(imgL_o, flow) 121 | flowvis = torch.from_numpy(flowvis[:,:,::-1].copy()).permute(2,0,1).unsqueeze(0)/127.5-1. 
122 | self.parse_and_log_images(None, flowvis, x_mid, x_hat, title='trained_images') 123 | print(loss_dict) 124 | 125 | if self.global_step % self.opts.save_interval == 0 or self.global_step == self.opts.max_steps: 126 | self.checkpoint_me(loss_dict, is_best=False) 127 | 128 | if self.global_step == self.opts.max_steps: 129 | print('OMG, finished training!') 130 | break 131 | 132 | self.global_step += 1 133 | 134 | def checkpoint_me(self, loss_dict, is_best): 135 | save_name = 'best_model.pt' if is_best else 'iteration_{}.pt'.format(self.global_step) 136 | save_dict = self.__get_save_dict() 137 | checkpoint_path = os.path.join(self.checkpoint_dir, save_name) 138 | torch.save(save_dict, checkpoint_path) 139 | with open(os.path.join(self.checkpoint_dir, 'timestamp.txt'), 'a') as f: 140 | if is_best: 141 | f.write('**Best**: Step - {}, Loss - {:.3f} \n{}\n'.format(self.global_step, self.best_val_loss, loss_dict)) 142 | else: 143 | f.write('Step - {}, \n{}\n'.format(self.global_step, loss_dict)) 144 | 145 | def configure_optimizers(self): 146 | params = list(self.net.encoder.parameters()) 147 | if self.opts.train_decoder: 148 | params += list(self.net.decoder.parameters()) 149 | if self.opts.optim_name == 'adam': 150 | optimizer = torch.optim.Adam(params, lr=self.opts.learning_rate) 151 | else: 152 | optimizer = Ranger(params, lr=self.opts.learning_rate) 153 | return optimizer 154 | 155 | def calc_loss(self, latent, w, y_hat=None, y=None): 156 | loss_dict = {} 157 | loss = 0.0 158 | id_logs = None 159 | 160 | if self.opts.l2_lambda > 0 and (y_hat is not None) and (y is not None): 161 | loss_l2 = F.mse_loss(y_hat, y) 162 | loss_dict['loss_l2'] = float(loss_l2) 163 | loss += loss_l2 * self.opts.l2_lambda 164 | if self.opts.lpips_lambda > 0 and (y_hat is not None) and (y is not None): 165 | loss_lpips = self.lpips_loss(y_hat, y) 166 | loss_dict['loss_lpips'] = float(loss_lpips) 167 | loss += loss_lpips * self.opts.lpips_lambda 168 | if self.opts.l2latent_lambda > 0: 169 | loss_l2 = F.mse_loss(latent, w) 170 | loss_dict['loss_l2latent'] = float(loss_l2) 171 | loss += loss_l2 * self.opts.l2latent_lambda 172 | 173 | loss_dict['loss'] = float(loss) 174 | return loss, loss_dict, id_logs 175 | 176 | def parse_and_log_images(self, id_logs, x, y, y_hat, title, subscript=None, display_count=1): 177 | im_data = [] 178 | for i in range(display_count): 179 | cur_im_data = { 180 | 'input_face': common.tensor2im(x[i]), 181 | 'target_face': common.tensor2im(y[i]), 182 | 'output_face': common.tensor2im(y_hat[i]), 183 | } 184 | if id_logs is not None: 185 | for key in id_logs[i]: 186 | cur_im_data[key] = id_logs[i][key] 187 | im_data.append(cur_im_data) 188 | self.log_images(title, im_data=im_data, subscript=subscript) 189 | 190 | 191 | def log_images(self, name, im_data, subscript=None, log_latest=False): 192 | fig = common.vis_faces(im_data) 193 | step = self.global_step 194 | if log_latest: 195 | step = 0 196 | if subscript: 197 | path = os.path.join(self.logger.log_dir, name, '{}_{:04d}.jpg'.format(subscript, step)) 198 | else: 199 | path = os.path.join(self.logger.log_dir, name, '{:04d}.jpg'.format(step)) 200 | os.makedirs(os.path.dirname(path), exist_ok=True) 201 | fig.savefig(path) 202 | plt.close(fig) 203 | 204 | def __get_save_dict(self): 205 | save_dict = { 206 | 'state_dict': self.net.state_dict(), 207 | 'opts': vars(self.opts) 208 | } 209 | 210 | save_dict['latent_avg'] = self.net.latent_avg 211 | return save_dict -------------------------------------------------------------------------------- 
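For orientation, the Coach above is driven by scripts/train.py with the flags defined in options/train_options.py. Below is a minimal launch sketch, not part of the repository: it assumes a CUDA device, the default cat weights at pretrained_models/stylegan2-cat-config-f.pt, and a hypothetical output directory experiments/cat.

```python
# Minimal training-launch sketch (assumptions: CUDA available, default StyleGAN2
# cat weights present; "experiments/cat" is a hypothetical output directory).
# Roughly equivalent to:
#   python scripts/train.py --exp_dir experiments/cat --batch_size 2
import sys

sys.path.append(".")
sys.argv = [
    "train.py",
    "--exp_dir", "experiments/cat",   # hypothetical experiment directory
    "--batch_size", "2",
    "--stylegan_weights", "pretrained_models/stylegan2-cat-config-f.pt",
]

from options.train_options import TrainOptions
from training.coach import Coach

opts = TrainOptions().parse()  # parses the flags defined in options/train_options.py
Coach(opts).train()            # runs the latent-transformer training loop shown above
```

The flow normalization constants (sigma_f, sigma_e) are picked inside Coach based on the substring of the weights path ('ffhq', 'car', 'cat', 'church', 'anime'), so the --stylegan_weights filename also selects those values.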
/training/ranger.py: -------------------------------------------------------------------------------- 1 | # Ranger deep learning optimizer - RAdam + Lookahead + Gradient Centralization, combined into one optimizer. 2 | 3 | # https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer 4 | # and/or 5 | # https://github.com/lessw2020/Best-Deep-Learning-Optimizers 6 | 7 | # Ranger has now been used to capture 12 records on the FastAI leaderboard. 8 | 9 | # This version = 20.4.11 10 | 11 | # Credits: 12 | # Gradient Centralization --> https://arxiv.org/abs/2004.01461v2 (a new optimization technique for DNNs), github: https://github.com/Yonghongwei/Gradient-Centralization 13 | # RAdam --> https://github.com/LiyuanLucasLiu/RAdam 14 | # Lookahead --> rewritten by lessw2020, but big thanks to Github @LonePatient and @RWightman for ideas from their code. 15 | # Lookahead paper --> MZhang,G Hinton https://arxiv.org/abs/1907.08610 16 | 17 | # summary of changes: 18 | # 4/11/20 - add gradient centralization option. Set new testing benchmark for accuracy with it, toggle with use_gc flag at init. 19 | # full code integration with all updates at param level instead of group, moves slow weights into state dict (from generic weights), 20 | # supports group learning rates (thanks @SHolderbach), fixes sporadic load from saved model issues. 21 | # changes 8/31/19 - fix references to *self*.N_sma_threshold; 22 | # changed eps to 1e-5 as better default than 1e-8. 23 | 24 | import math 25 | import torch 26 | from torch.optim.optimizer import Optimizer 27 | 28 | 29 | class Ranger(Optimizer): 30 | 31 | def __init__(self, params, lr=1e-3, # lr 32 | alpha=0.5, k=6, N_sma_threshhold=5, # Ranger options 33 | betas=(.95, 0.999), eps=1e-5, weight_decay=0, # Adam options 34 | use_gc=True, gc_conv_only=False 35 | # Gradient centralization on or off, applied to conv layers only or conv + fc layers 36 | ): 37 | 38 | # parameter checks 39 | if not 0.0 <= alpha <= 1.0: 40 | raise ValueError(f'Invalid slow update rate: {alpha}') 41 | if not 1 <= k: 42 | raise ValueError(f'Invalid lookahead steps: {k}') 43 | if not lr > 0: 44 | raise ValueError(f'Invalid Learning Rate: {lr}') 45 | if not eps > 0: 46 | raise ValueError(f'Invalid eps: {eps}') 47 | 48 | # parameter comments: 49 | # beta1 (momentum) of .95 seems to work better than .90... 50 | # N_sma_threshold of 5 seems better in testing than 4. 51 | # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you. 
52 | 53 | # prep defaults and init torch.optim base 54 | defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold, 55 | eps=eps, weight_decay=weight_decay) 56 | super().__init__(params, defaults) 57 | 58 | # adjustable threshold 59 | self.N_sma_threshhold = N_sma_threshhold 60 | 61 | # look ahead params 62 | 63 | self.alpha = alpha 64 | self.k = k 65 | 66 | # radam buffer for state 67 | self.radam_buffer = [[None, None, None] for ind in range(10)] 68 | 69 | # gc on or off 70 | self.use_gc = use_gc 71 | 72 | # level of gradient centralization 73 | self.gc_gradient_threshold = 3 if gc_conv_only else 1 74 | 75 | def __setstate__(self, state): 76 | super(Ranger, self).__setstate__(state) 77 | 78 | def step(self, closure=None): 79 | loss = None 80 | 81 | # Evaluate averages and grad, update param tensors 82 | for group in self.param_groups: 83 | 84 | for p in group['params']: 85 | if p.grad is None: 86 | continue 87 | grad = p.grad.data.float() 88 | 89 | if grad.is_sparse: 90 | raise RuntimeError('Ranger optimizer does not support sparse gradients') 91 | 92 | p_data_fp32 = p.data.float() 93 | 94 | state = self.state[p] # get state dict for this param 95 | 96 | if len(state) == 0: # if first time to run...init dictionary with our desired entries 97 | # if self.first_run_check==0: 98 | # self.first_run_check=1 99 | # print("Initializing slow buffer...should not see this at load from saved model!") 100 | state['step'] = 0 101 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 102 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 103 | 104 | # look ahead weight storage now in state dict 105 | state['slow_buffer'] = torch.empty_like(p.data) 106 | state['slow_buffer'].copy_(p.data) 107 | 108 | else: 109 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 110 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 111 | 112 | # begin computations 113 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 114 | beta1, beta2 = group['betas'] 115 | 116 | # GC operation for Conv layers and FC layers 117 | if grad.dim() > self.gc_gradient_threshold: 118 | grad.add_(-grad.mean(dim=tuple(range(1, grad.dim())), keepdim=True)) 119 | 120 | state['step'] += 1 121 | 122 | # compute variance mov avg 123 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 124 | # compute mean moving avg 125 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 126 | 127 | buffered = self.radam_buffer[int(state['step'] % 10)] 128 | 129 | if state['step'] == buffered[0]: 130 | N_sma, step_size = buffered[1], buffered[2] 131 | else: 132 | buffered[0] = state['step'] 133 | beta2_t = beta2 ** state['step'] 134 | N_sma_max = 2 / (1 - beta2) - 1 135 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 136 | buffered[1] = N_sma 137 | if N_sma > self.N_sma_threshhold: 138 | step_size = math.sqrt( 139 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 140 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 141 | else: 142 | step_size = 1.0 / (1 - beta1 ** state['step']) 143 | buffered[2] = step_size 144 | 145 | if group['weight_decay'] != 0: 146 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 147 | 148 | # apply lr 149 | if N_sma > self.N_sma_threshhold: 150 | denom = exp_avg_sq.sqrt().add_(group['eps']) 151 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size * group['lr']) 152 | else: 153 | p_data_fp32.add_(exp_avg, alpha=-step_size * group['lr']) 154 | 155 | p.data.copy_(p_data_fp32) 156 | 
157 | # integrated look ahead... 158 | # we do it at the param level instead of group level 159 | if state['step'] % group['k'] == 0: 160 | slow_p = state['slow_buffer'] # get access to slow param tensor 161 | slow_p.add_(p.data - slow_p, alpha=self.alpha) # (fast weights - slow weights) * alpha 162 | p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor 163 | 164 | return loss -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/endo-yuki-t/UserControllableLT/a3734753fa2a5421f0e3ff2be98122fb073d6406/utils/__init__.py -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | from PIL import Image 4 | import matplotlib.pyplot as plt 5 | import random 6 | 7 | 8 | # Log images 9 | def log_input_image(x, opts): 10 | if opts.label_nc == 0: 11 | return tensor2im(x) 12 | elif opts.label_nc == 1: 13 | return tensor2sketch(x) 14 | else: 15 | return tensor2map(x) 16 | 17 | 18 | def tensor2im(var): 19 | var = var.cpu().detach().transpose(0, 2).transpose(0, 1).numpy() 20 | var = ((var + 1) / 2) 21 | var[var < 0] = 0 22 | var[var > 1] = 1 23 | var = var * 255 24 | return Image.fromarray(var.astype('uint8')) 25 | 26 | 27 | def tensor2map(var): 28 | mask = np.argmax(var.data.cpu().numpy(), axis=0) 29 | colors = get_colors() 30 | mask_image = np.ones(shape=(mask.shape[0], mask.shape[1], 3)) 31 | for class_idx in np.unique(mask): 32 | mask_image[mask == class_idx] = colors[class_idx] 33 | mask_image = mask_image.astype('uint8') 34 | return Image.fromarray(mask_image) 35 | 36 | 37 | def tensor2sketch(var): 38 | im = var[0].cpu().detach().numpy() 39 | im = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) 40 | im = (im * 255).astype(np.uint8) 41 | return Image.fromarray(im) 42 | 43 | 44 | # Visualization utils 45 | def get_colors(): 46 | # currently supports up to 19 classes (for the CelebAMask-HQ dataset) 47 | colors = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], 48 | [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], 49 | [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] 50 | 51 | # assign random colors to 200 more classes 52 | random.seed(0) 53 | for i in range(200): 54 | colors.append([random.randint(0,255),random.randint(0,255),random.randint(0,255)]) 55 | return colors 56 | 57 | 58 | def vis_faces(log_hooks): 59 | display_count = len(log_hooks) 60 | fig = plt.figure(figsize=(8, 4 * display_count)) 61 | gs = fig.add_gridspec(display_count, 3) 62 | for i in range(display_count): 63 | hooks_dict = log_hooks[i] 64 | fig.add_subplot(gs[i, 0]) 65 | if 'diff_input' in hooks_dict: 66 | vis_faces_with_id(hooks_dict, fig, gs, i) 67 | else: 68 | vis_faces_no_id(hooks_dict, fig, gs, i) 69 | plt.tight_layout() 70 | return fig 71 | 72 | 73 | def vis_faces_with_id(hooks_dict, fig, gs, i): 74 | plt.imshow(hooks_dict['input_face']) 75 | plt.title('Input\nOut Sim={:.2f}'.format(float(hooks_dict['diff_input']))) 76 | fig.add_subplot(gs[i, 1]) 77 | plt.imshow(hooks_dict['target_face']) 78 | plt.title('Target\nIn={:.2f}, Out={:.2f}'.format(float(hooks_dict['diff_views']), 79 | float(hooks_dict['diff_target']))) 80 | fig.add_subplot(gs[i, 2]) 81 |
plt.imshow(hooks_dict['output_face']) 82 | plt.title('Output\n Target Sim={:.2f}'.format(float(hooks_dict['diff_target']))) 83 | 84 | 85 | def vis_faces_no_id(hooks_dict, fig, gs, i): 86 | plt.imshow(hooks_dict['input_face'], cmap="gray") 87 | plt.title('Input') 88 | fig.add_subplot(gs[i, 1]) 89 | plt.imshow(hooks_dict['target_face']) 90 | plt.title('Target') 91 | fig.add_subplot(gs[i, 2]) 92 | plt.imshow(hooks_dict['output_face']) 93 | plt.title('Output') 94 | --------------------------------------------------------------------------------
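The helpers in utils/common.py are what training/coach.py uses for image logging. A small usage sketch follows; the random tensors and the output filename are illustrative stand-ins, not part of the repository.

```python
# Usage sketch for utils/common.py. Assumed inputs: three CHW tensors in [-1, 1];
# the tensor names and 'preview.jpg' are illustrative only.
import torch
from utils.common import tensor2im, vis_faces

x = torch.rand(3, 256, 256) * 2 - 1      # stand-in for the flow visualization
y = torch.rand(3, 256, 256) * 2 - 1      # stand-in for the target image x_mid
y_hat = torch.rand(3, 256, 256) * 2 - 1  # stand-in for the reconstruction x_hat

# Mirrors Coach.parse_and_log_images: one dict per displayed row.
im_data = [{
    'input_face': tensor2im(x),
    'target_face': tensor2im(y),
    'output_face': tensor2im(y_hat),
}]

fig = vis_faces(im_data)  # matplotlib figure with a 3-column input/target/output grid
fig.savefig('preview.jpg')
```

Since none of the dicts contain 'diff_input', vis_faces falls back to vis_faces_no_id, which is the path taken during latent-transformer training.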