├── .gitignore ├── LICENSE ├── README.md ├── files ├── DECV_requirements.txt ├── DVPEnv.yml ├── general.txt ├── instaColorEnv.yml └── instaColorInstallDep.sh ├── models ├── Colorful │ ├── __init__.py │ ├── base_color.py │ ├── eccv16.py │ ├── options │ │ └── test_options.py │ ├── siggraph17.py │ └── util.py ├── DEVC │ ├── __init__.py │ ├── models │ │ ├── ColorVidNet.py │ │ ├── FrameColor.py │ │ ├── GAN_Models.py │ │ ├── NonlocalNet.py │ │ ├── spectral_normalization.py │ │ └── vgg19_gray.py │ ├── options │ │ └── test_options.py │ └── utils │ │ ├── lib │ │ ├── functional.py │ │ └── test_transforms.py │ │ ├── loss.py │ │ ├── util.py │ │ └── util_distortion.py ├── DVP │ ├── README.md │ ├── __init__.py │ ├── environment.yml │ ├── main_IRT.py │ ├── models │ │ ├── __pycache__ │ │ │ ├── network.cpython-36.pyc │ │ │ └── network.cpython-39.pyc │ │ └── network.py │ ├── options │ │ └── test_options.py │ ├── test.sh │ └── vgg.py └── InstaColor │ ├── __init__.py │ ├── data │ ├── __init__.py │ ├── aligned_dataset.py │ ├── base_data_loader.py │ ├── base_dataset.py │ ├── color_dataset.py │ ├── image_folder.py │ └── single_dataset.py │ ├── models │ ├── __init__.py │ ├── base_model.py │ ├── fusion_model.py │ └── networks.py │ ├── options │ ├── base_options.py │ └── test_options.py │ └── utils │ ├── datasets.py │ ├── download.py │ ├── image_utils.py │ └── util.py ├── notebooks ├── DEVC.ipynb ├── DVP.ipynb ├── InstaColor.ipynb ├── colorizer.ipynb ├── video_prepro.ipynb └── webdemo.ipynb ├── sample.mp4 ├── utils ├── color_format.py ├── loss.py ├── metrics.py ├── util.py └── v2i.py └── webdemo.py /.gitignore: -------------------------------------------------------------------------------- 1 | test 2 | 3 | notebooks/.ipynb_checkpoints 4 | notebooks/results 5 | notebooks/checkpoints* 6 | 7 | **/__pycache__ 8 | *.zip 9 | checkpoints -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Vince Ai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Colorizer 2 | A collection of colorization models 3 | 4 | Model curently included: 5 | 6 | - DEVC 7 | 8 | - DVP 9 | 10 | - InstaColor 11 | 12 | - Colorful 13 | 14 | DVP Colab Demo: 15 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/drive/folders/1L6_hY35kvL5EkFCIncQPi7uRoOsNdgvn) 16 | 17 | Interactive Web Demo for other models: 18 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Vince-Ai/Colorization/blob/main/notebooks/webdemo.ipynb) 19 | -------------------------------------------------------------------------------- /files/DECV_requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm>=4.23.0 3 | scipy==1.2 4 | torchsummary>=1.3 5 | matplotlib>=2.2.2 6 | opencv_contrib_python==4.2.0.32 7 | torchvision>=0.3.0 8 | scikit_image>=0.15.0 9 | torchviz>=0.0.1 10 | Pillow>=6.1.0 11 | torch 12 | easydict 13 | prefetch_generator 14 | PyYAML 15 | tensorboard 16 | future 17 | numba 18 | pypng -------------------------------------------------------------------------------- /files/DVPEnv.yml: -------------------------------------------------------------------------------- 1 | name: dvp 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - asttokens=2.0.5=pyhd8ed1ab_0 10 | - backcall=0.2.0=pyh9f0ad1d_0 11 | - backports=1.0=py_2 12 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0 13 | - blas=1.0=mkl 14 | - brotli=1.0.9=he6710b0_2 15 | - bzip2=1.0.8=h7b6447c_0 16 | - ca-certificates=2022.2.1=h06a4308_0 17 | - certifi=2021.10.8=py39h06a4308_2 18 | - cudatoolkit=11.3.1=h2bc3f7f_2 19 | - cycler=0.11.0=pyhd3eb1b0_0 20 | - dbus=1.13.18=hb2f20db_0 21 | - debugpy=1.5.1=py39h295c915_0 22 | - decorator=5.1.1=pyhd8ed1ab_0 23 | - entrypoints=0.4=pyhd8ed1ab_0 24 | - executing=0.8.3=pyhd8ed1ab_0 25 | - expat=2.4.4=h295c915_0 26 | - ffmpeg=4.3=hf484d3e_0 27 | - fontconfig=2.13.1=h6c09931_0 28 | - fonttools=4.25.0=pyhd3eb1b0_0 29 | - freetype=2.11.0=h70c0345_0 30 | - giflib=5.2.1=h7b6447c_0 31 | - glib=2.69.1=h4ff587b_1 32 | - gmp=6.2.1=h2531618_2 33 | - gnutls=3.6.15=he1e5248_0 34 | - gst-plugins-base=1.14.0=h8213a91_2 35 | - gstreamer=1.14.0=h28cd5cc_2 36 | - icu=58.2=he6710b0_3 37 | - imageio=2.16.1=pyhcf75d05_0 38 | - intel-openmp=2021.4.0=h06a4308_3561 39 | - ipykernel=6.9.1=py39hef51801_0 40 | - ipython=8.1.1=py39hf3d152e_0 41 | - jedi=0.18.1=py39hf3d152e_0 42 | - jpeg=9d=h7f8727e_0 43 | - jupyter_client=7.1.2=pyhd8ed1ab_0 44 | - jupyter_core=4.9.2=py39hf3d152e_0 45 | - kiwisolver=1.3.2=py39h295c915_0 46 | - lame=3.100=h7b6447c_0 47 | - lcms2=2.12=h3be6417_0 48 | - ld_impl_linux-64=2.35.1=h7274673_9 49 | - libffi=3.3=he6710b0_2 50 | - libgcc-ng=9.3.0=h5101ec6_17 51 | - libgfortran-ng=7.5.0=ha8ba4b0_17 52 | - libgfortran4=7.5.0=ha8ba4b0_17 53 | - libgomp=9.3.0=h5101ec6_17 54 | - libiconv=1.15=h63c8f33_5 55 | - libidn2=2.3.2=h7f8727e_0 56 | - libpng=1.6.37=hbc83047_0 57 | - libsodium=1.0.18=h36c2ea0_1 58 | - libstdcxx-ng=9.3.0=hd4cf53a_17 59 | - libtasn1=4.16.0=h27cfd23_0 60 | - libtiff=4.2.0=h85742a9_0 61 | - libunistring=0.9.10=h27cfd23_0 62 | - libuuid=1.0.3=h7f8727e_2 63 | - libuv=1.40.0=h7b6447c_0 64 | - libwebp=1.2.2=h55f646e_0 65 | - 
libwebp-base=1.2.2=h7f8727e_0 66 | - libxcb=1.14=h7b6447c_0 67 | - libxml2=2.9.12=h03d6c58_0 68 | - lz4-c=1.9.3=h295c915_1 69 | - matplotlib=3.5.1=py39h06a4308_0 70 | - matplotlib-base=3.5.1=py39ha18d171_0 71 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0 72 | - mkl=2021.4.0=h06a4308_640 73 | - mkl-service=2.4.0=py39h7f8727e_0 74 | - mkl_fft=1.3.1=py39hd3c417c_0 75 | - mkl_random=1.2.2=py39h51133e4_0 76 | - munkres=1.1.4=py_0 77 | - ncurses=6.3=h7f8727e_2 78 | - nest-asyncio=1.5.4=pyhd8ed1ab_0 79 | - nettle=3.7.3=hbbd107a_1 80 | - numpy=1.21.2=py39h20f2e39_0 81 | - numpy-base=1.21.2=py39h79a1101_0 82 | - openh264=2.1.1=h4ff587b_0 83 | - openssl=1.1.1m=h7f8727e_0 84 | - packaging=21.3=pyhd3eb1b0_0 85 | - parso=0.8.3=pyhd8ed1ab_0 86 | - pcre=8.45=h295c915_0 87 | - pexpect=4.8.0=pyh9f0ad1d_2 88 | - pickleshare=0.7.5=py_1003 89 | - pillow=9.0.1=py39h22f2fdc_0 90 | - pip=21.2.4=py39h06a4308_0 91 | - prompt-toolkit=3.0.27=pyha770c72_0 92 | - ptyprocess=0.7.0=pyhd3deb0d_0 93 | - pure_eval=0.2.2=pyhd8ed1ab_0 94 | - pygments=2.11.2=pyhd8ed1ab_0 95 | - pyparsing=3.0.4=pyhd3eb1b0_0 96 | - pyqt=5.9.2=py39h2531618_6 97 | - python=3.9.7=h12debd9_1 98 | - python-dateutil=2.8.2=pyhd8ed1ab_0 99 | - python_abi=3.9=2_cp39 100 | - pytorch=1.10.2=py3.9_cuda11.3_cudnn8.2.0_0 101 | - pytorch-mutex=1.0=cuda 102 | - pyzmq=19.0.2=py39hb69f2a1_2 103 | - qt=5.9.7=h5867ecd_1 104 | - readline=8.1.2=h7f8727e_1 105 | - scipy=1.7.3=py39hc147768_0 106 | - setuptools=58.0.4=py39h06a4308_0 107 | - sip=4.19.13=py39h295c915_0 108 | - six=1.16.0=pyhd3eb1b0_1 109 | - sqlite=3.37.2=hc218d9a_0 110 | - stack_data=0.2.0=pyhd8ed1ab_0 111 | - tk=8.6.11=h1ccaba5_0 112 | - torchaudio=0.10.2=py39_cu113 113 | - torchvision=0.11.3=py39_cu113 114 | - tornado=6.1=py39h3811e60_1 115 | - traitlets=5.1.1=pyhd8ed1ab_0 116 | - typing_extensions=3.10.0.2=pyh06a4308_0 117 | - tzdata=2021e=hda174b7_0 118 | - wcwidth=0.2.5=pyh9f0ad1d_2 119 | - wheel=0.37.1=pyhd3eb1b0_0 120 | - xz=5.2.5=h7b6447c_0 121 | - zeromq=4.3.4=h9c3ff4c_0 122 | - zlib=1.2.11=h7f8727e_4 123 | - zstd=1.4.9=haebb681_0 124 | - pip: 125 | - opencv-python==4.5.5.62 126 | prefix: /home/tonyx/Utils/anaconda3/envs/dvp 127 | -------------------------------------------------------------------------------- /files/general.txt: -------------------------------------------------------------------------------- 1 | av 2 | torchvideo 3 | pillow-simd 4 | lpips -------------------------------------------------------------------------------- /files/instaColorEnv.yml: -------------------------------------------------------------------------------- 1 | name: instacolorization 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - backcall=0.1.0=py37_0 8 | - blas=1.0=mkl 9 | - ca-certificates=2019.11.28=hecc5488_0 10 | - certifi=2019.11.28=py37_0 11 | - cloudpickle=1.3.0=py_0 12 | - cudatoolkit=10.1.243=h6bb024c_0 13 | - cycler=0.10.0=py37_0 14 | - cython=0.29.15=py37he6710b0_0 15 | - cytoolz=0.10.1=py37h7b6447c_0 16 | - dask-core=2.10.1=py_0 17 | - dbus=1.13.12=h746ee38_0 18 | - decorator=4.4.1=py_0 19 | - expat=2.2.6=he6710b0_0 20 | - fontconfig=2.13.0=h9420a91_0 21 | - freetype=2.9.1=h8a8886c_1 22 | - glib=2.63.1=h5a9c865_0 23 | - gst-plugins-base=1.14.0=hbbd80ab_1 24 | - gstreamer=1.14.0=hb453b48_1 25 | - icu=58.2=h9c2bf20_1 26 | - imageio=2.6.1=py37_0 27 | - intel-openmp=2020.0=166 28 | - ipython=7.12.0=py37h5ca1d4c_0 29 | - ipython_genutils=0.2.0=py37_0 30 | - jedi=0.16.0=py37_0 31 | - joblib=0.14.1=py_0 32 | - jpeg=9b=h024ee3a_2 33 | - 
kiwisolver=1.1.0=py37he6710b0_0 34 | - ld_impl_linux-64=2.33.1=h53a641e_7 35 | - libedit=3.1.20181209=hc058e9b_0 36 | - libffi=3.2.1=hd88cf55_4 37 | - libgcc-ng=9.1.0=hdf63c60_0 38 | - libgfortran-ng=7.3.0=hdf63c60_0 39 | - libpng=1.6.37=hbc83047_0 40 | - libstdcxx-ng=9.1.0=hdf63c60_0 41 | - libtiff=4.1.0=h2733197_0 42 | - libuuid=1.0.3=h1bed415_2 43 | - libxcb=1.13=h1bed415_1 44 | - libxml2=2.9.9=hea5a465_1 45 | - matplotlib=3.1.3=py37_0 46 | - matplotlib-base=3.1.3=py37hef1b27d_0 47 | - mkl=2020.0=166 48 | - mkl-service=2.3.0=py37he904b0f_0 49 | - mkl_fft=1.0.15=py37ha843d7b_0 50 | - mkl_random=1.1.0=py37hd6b4f25_0 51 | - ncurses=6.1=he6710b0_1 52 | - networkx=2.4=py_0 53 | - ninja=1.9.0=py37hfd86e86_0 54 | - numpy=1.18.1=py37h4f9e942_0 55 | - numpy-base=1.18.1=py37hde5b4d6_1 56 | - olefile=0.46=py37_0 57 | - openssl=1.1.1d=h516909a_0 58 | - parso=0.6.1=py_0 59 | - pcre=8.43=he6710b0_0 60 | - pexpect=4.8.0=py37_0 61 | - pickleshare=0.7.5=py37_0 62 | - pillow=7.0.0=py37hb39fc2d_0 63 | - pip=20.0.2=py37_1 64 | - prompt_toolkit=3.0.3=py_0 65 | - ptyprocess=0.6.0=py37_0 66 | - pygments=2.5.2=py_0 67 | - pyparsing=2.4.6=py_0 68 | - pyqt=5.9.2=py37h05f1152_2 69 | - python=3.7.6=h0371630_2 70 | - python-dateutil=2.8.1=py_0 71 | - pywavelets=1.1.1=py37h7b6447c_0 72 | - qt=5.9.7=h5867ecd_1 73 | - readline=7.0=h7b6447c_5 74 | - scikit-image=0.16.2=py37h0573a6f_0 75 | - scikit-learn=0.22.1=py37hd81dba3_0 76 | - scipy=1.4.1=py37h0b6359f_0 77 | - setuptools=45.2.0=py37_0 78 | - sip=4.19.8=py37hf484d3e_0 79 | - six=1.14.0=py37_0 80 | - sqlite=3.31.1=h7b6447c_0 81 | - tk=8.6.8=hbc83047_0 82 | - toolz=0.10.0=py_0 83 | - tornado=6.0.3=py37h7b6447c_3 84 | - tqdm=4.43.0=py_0 85 | - traitlets=4.3.3=py37_0 86 | - wcwidth=0.1.8=py_0 87 | - wheel=0.34.2=py37_0 88 | - xz=5.2.4=h14c3975_4 89 | - zlib=1.2.11=h7b6447c_3 90 | - zstd=1.3.7=h0b5b093_0 91 | - pip: 92 | - absl-py==0.9.0 93 | - cachetools==4.1.0 94 | - chardet==3.0.4 95 | - future==0.18.2 96 | - fvcore==0.1.dev200506 97 | - google-auth==1.14.2 98 | - google-auth-oauthlib==0.4.1 99 | - grpcio==1.28.1 100 | - idna==2.9 101 | - importlib-metadata==1.6.0 102 | - jsonpatch==1.25 103 | - jsonpointer==2.0 104 | - markdown==3.2.2 105 | - mock==4.0.2 106 | - oauthlib==3.1.0 107 | - opencv-python==4.2.0.32 108 | - portalocker==1.7.0 109 | - protobuf==3.11.3 110 | - pyasn1==0.4.8 111 | - pyasn1-modules==0.2.8 112 | - pydot==1.4.1 113 | - pyzmq==18.1.1 114 | - requests==2.23.0 115 | - requests-oauthlib==1.3.0 116 | - rsa==4.0 117 | - tabulate==0.8.7 118 | - tensorboard==2.2.1 119 | - tensorboard-plugin-wit==1.6.0.post3 120 | - termcolor==1.1.0 121 | - urllib3==1.25.8 122 | - visdom==0.1.8.9 123 | - websocket-client==0.57.0 124 | - werkzeug==1.0.1 125 | - yacs==0.1.7 126 | - zipp==3.1.0 127 | prefix: /home/user/miniconda3/envs/py37_pt14 128 | 129 | -------------------------------------------------------------------------------- /files/instaColorInstallDep.sh: -------------------------------------------------------------------------------- 1 | pip install -U torch==1.5 torchvision==0.6 -f https://download.pytorch.org/whl/cu101/torch_stable.html 2 | pip install cython pyyaml==5.1 3 | pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' 4 | pip install dominate==2.4.0 5 | pip install detectron2==0.1.2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/index.html 6 | -------------------------------------------------------------------------------- /models/Colorful/__init__.py: 
-------------------------------------------------------------------------------- 1 | 2 | from .base_color import * 3 | from .eccv16 import * 4 | from .siggraph17 import * 5 | from .util import * 6 | 7 | import os 8 | import shutil 9 | import cv2 10 | import torch 11 | from PIL import Image 12 | from tqdm import tqdm 13 | import numpy as np 14 | import torchvision.transforms as transform_lib 15 | import matplotlib.pyplot as plt 16 | 17 | from utils.util import download_zipfile, mkdir 18 | from utils.v2i import convert_frames_to_video 19 | 20 | class OPT(): 21 | pass 22 | 23 | class Colorful(): 24 | def __init__(self, pretrained=True): 25 | self.model = siggraph17(pretrained=True).cuda().eval() 26 | self.opt = OPT() 27 | self.opt.output_frame_path = "./test/results" 28 | 29 | def test(self, input_path, output_path, opt=None): 30 | 31 | if not os.path.isdir(self.opt.output_frame_path): 32 | os.makedirs(self.opt.output_frame_path) 33 | 34 | frames = os.listdir(input_path) 35 | frames.sort() 36 | for frame in frames: 37 | colorized = self.colorize(os.path.join(input_path, frame)) 38 | plt.imsave(os.path.join(self.opt.output_frame_path, frame), colorized) 39 | 40 | convert_frames_to_video(self.opt.output_frame_path, output_path) 41 | 42 | 43 | def colorize(self, path): 44 | img = load_img(path) 45 | (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256,256)) 46 | tens_l_rs = tens_l_rs.cuda() 47 | out_img_siggraph17 = postprocess_tens(tens_l_orig, self.model(tens_l_rs).cpu()) 48 | return out_img_siggraph17 -------------------------------------------------------------------------------- /models/Colorful/base_color.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | 5 | class BaseColor(nn.Module): 6 | def __init__(self): 7 | super(BaseColor, self).__init__() 8 | 9 | self.l_cent = 50. 10 | self.l_norm = 100. 11 | self.ab_norm = 110. 
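# NOTE: these constants recenter and rescale CIE Lab channels for the network:
# normalize_l below maps L in [0, 100] to (L - l_cent) / l_norm = (L - 50) / 100,
# normalize_ab maps ab (roughly [-110, 110]) to ab / 110, and the unnormalize_*
# methods invert those mappings when converting predictions back to Lab.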
12 | 13 | def normalize_l(self, in_l): 14 | return (in_l-self.l_cent)/self.l_norm 15 | 16 | def unnormalize_l(self, in_l): 17 | return in_l*self.l_norm + self.l_cent 18 | 19 | def normalize_ab(self, in_ab): 20 | return in_ab/self.ab_norm 21 | 22 | def unnormalize_ab(self, in_ab): 23 | return in_ab*self.ab_norm 24 | 25 | -------------------------------------------------------------------------------- /models/Colorful/eccv16.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from IPython import embed 6 | 7 | from .base_color import * 8 | 9 | class ECCVGenerator(BaseColor): 10 | def __init__(self, norm_layer=nn.BatchNorm2d): 11 | super(ECCVGenerator, self).__init__() 12 | 13 | model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),] 14 | model1+=[nn.ReLU(True),] 15 | model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),] 16 | model1+=[nn.ReLU(True),] 17 | model1+=[norm_layer(64),] 18 | 19 | model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),] 20 | model2+=[nn.ReLU(True),] 21 | model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),] 22 | model2+=[nn.ReLU(True),] 23 | model2+=[norm_layer(128),] 24 | 25 | model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),] 26 | model3+=[nn.ReLU(True),] 27 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] 28 | model3+=[nn.ReLU(True),] 29 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),] 30 | model3+=[nn.ReLU(True),] 31 | model3+=[norm_layer(256),] 32 | 33 | model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),] 34 | model4+=[nn.ReLU(True),] 35 | model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 36 | model4+=[nn.ReLU(True),] 37 | model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 38 | model4+=[nn.ReLU(True),] 39 | model4+=[norm_layer(512),] 40 | 41 | model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 42 | model5+=[nn.ReLU(True),] 43 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 44 | model5+=[nn.ReLU(True),] 45 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 46 | model5+=[nn.ReLU(True),] 47 | model5+=[norm_layer(512),] 48 | 49 | model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 50 | model6+=[nn.ReLU(True),] 51 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 52 | model6+=[nn.ReLU(True),] 53 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 54 | model6+=[nn.ReLU(True),] 55 | model6+=[norm_layer(512),] 56 | 57 | model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 58 | model7+=[nn.ReLU(True),] 59 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 60 | model7+=[nn.ReLU(True),] 61 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 62 | model7+=[nn.ReLU(True),] 63 | model7+=[norm_layer(512),] 64 | 65 | model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),] 66 | model8+=[nn.ReLU(True),] 67 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] 68 | model8+=[nn.ReLU(True),] 69 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, 
bias=True),] 70 | model8+=[nn.ReLU(True),] 71 | 72 | model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),] 73 | 74 | self.model1 = nn.Sequential(*model1) 75 | self.model2 = nn.Sequential(*model2) 76 | self.model3 = nn.Sequential(*model3) 77 | self.model4 = nn.Sequential(*model4) 78 | self.model5 = nn.Sequential(*model5) 79 | self.model6 = nn.Sequential(*model6) 80 | self.model7 = nn.Sequential(*model7) 81 | self.model8 = nn.Sequential(*model8) 82 | 83 | self.softmax = nn.Softmax(dim=1) 84 | self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False) 85 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear') 86 | 87 | def forward(self, input_l): 88 | conv1_2 = self.model1(self.normalize_l(input_l)) 89 | conv2_2 = self.model2(conv1_2) 90 | conv3_3 = self.model3(conv2_2) 91 | conv4_3 = self.model4(conv3_3) 92 | conv5_3 = self.model5(conv4_3) 93 | conv6_3 = self.model6(conv5_3) 94 | conv7_3 = self.model7(conv6_3) 95 | conv8_3 = self.model8(conv7_3) 96 | out_reg = self.model_out(self.softmax(conv8_3)) 97 | 98 | return self.unnormalize_ab(self.upsample4(out_reg)) 99 | 100 | def eccv16(pretrained=True): 101 | model = ECCVGenerator() 102 | if(pretrained): 103 | import torch.utils.model_zoo as model_zoo 104 | model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True)) 105 | return model 106 | -------------------------------------------------------------------------------- /models/Colorful/options/test_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.backends.cudnn as cudnn 3 | class TestOptions(): 4 | def __init__(self): 5 | pass 6 | 7 | def parse(self): 8 | # initialize parser with basic options 9 | opt = {} 10 | return opt -------------------------------------------------------------------------------- /models/Colorful/siggraph17.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base_color import * 5 | 6 | class SIGGRAPHGenerator(BaseColor): 7 | def __init__(self, norm_layer=nn.BatchNorm2d, classes=529): 8 | super(SIGGRAPHGenerator, self).__init__() 9 | 10 | # Conv1 11 | model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),] 12 | model1+=[nn.ReLU(True),] 13 | model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),] 14 | model1+=[nn.ReLU(True),] 15 | model1+=[norm_layer(64),] 16 | # add a subsampling operation 17 | 18 | # Conv2 19 | model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),] 20 | model2+=[nn.ReLU(True),] 21 | model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),] 22 | model2+=[nn.ReLU(True),] 23 | model2+=[norm_layer(128),] 24 | # add a subsampling layer operation 25 | 26 | # Conv3 27 | model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),] 28 | model3+=[nn.ReLU(True),] 29 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] 30 | model3+=[nn.ReLU(True),] 31 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] 32 | model3+=[nn.ReLU(True),] 33 | model3+=[norm_layer(256),] 34 | # add a subsampling layer operation 35 | 36 | # Conv4 37 | model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),] 38 | model4+=[nn.ReLU(True),] 39 | model4+=[nn.Conv2d(512, 512, kernel_size=3, 
stride=1, padding=1, bias=True),] 40 | model4+=[nn.ReLU(True),] 41 | model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 42 | model4+=[nn.ReLU(True),] 43 | model4+=[norm_layer(512),] 44 | 45 | # Conv5 46 | model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 47 | model5+=[nn.ReLU(True),] 48 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 49 | model5+=[nn.ReLU(True),] 50 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 51 | model5+=[nn.ReLU(True),] 52 | model5+=[norm_layer(512),] 53 | 54 | # Conv6 55 | model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 56 | model6+=[nn.ReLU(True),] 57 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 58 | model6+=[nn.ReLU(True),] 59 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),] 60 | model6+=[nn.ReLU(True),] 61 | model6+=[norm_layer(512),] 62 | 63 | # Conv7 64 | model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 65 | model7+=[nn.ReLU(True),] 66 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 67 | model7+=[nn.ReLU(True),] 68 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),] 69 | model7+=[nn.ReLU(True),] 70 | model7+=[norm_layer(512),] 71 | 72 | # Conv7 73 | model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)] 74 | model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] 75 | 76 | model8=[nn.ReLU(True),] 77 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] 78 | model8+=[nn.ReLU(True),] 79 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),] 80 | model8+=[nn.ReLU(True),] 81 | model8+=[norm_layer(256),] 82 | 83 | # Conv9 84 | model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),] 85 | model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),] 86 | # add the two feature maps above 87 | 88 | model9=[nn.ReLU(True),] 89 | model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),] 90 | model9+=[nn.ReLU(True),] 91 | model9+=[norm_layer(128),] 92 | 93 | # Conv10 94 | model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),] 95 | model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),] 96 | # add the two feature maps above 97 | 98 | model10=[nn.ReLU(True),] 99 | model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),] 100 | model10+=[nn.LeakyReLU(negative_slope=.2),] 101 | 102 | # classification output 103 | model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),] 104 | 105 | # regression output 106 | model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),] 107 | model_out+=[nn.Tanh()] 108 | 109 | self.model1 = nn.Sequential(*model1) 110 | self.model2 = nn.Sequential(*model2) 111 | self.model3 = nn.Sequential(*model3) 112 | self.model4 = nn.Sequential(*model4) 113 | self.model5 = nn.Sequential(*model5) 114 | self.model6 = nn.Sequential(*model6) 115 | self.model7 = nn.Sequential(*model7) 116 | self.model8up = nn.Sequential(*model8up) 117 | self.model8 = nn.Sequential(*model8) 118 | self.model9up = nn.Sequential(*model9up) 119 | self.model9 = nn.Sequential(*model9) 
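# NOTE: model8up/model9up/model10up implement the decoder's upsampling stages, while
# model3short8/model2short9/model1short10 are shortcut convolutions whose outputs are
# added to the corresponding upsampled features in forward(), giving the decoder
# U-Net-style skip connections from conv3, conv2 and conv1.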
120 | self.model10up = nn.Sequential(*model10up) 121 | self.model10 = nn.Sequential(*model10) 122 | self.model3short8 = nn.Sequential(*model3short8) 123 | self.model2short9 = nn.Sequential(*model2short9) 124 | self.model1short10 = nn.Sequential(*model1short10) 125 | 126 | self.model_class = nn.Sequential(*model_class) 127 | self.model_out = nn.Sequential(*model_out) 128 | 129 | self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),]) 130 | self.softmax = nn.Sequential(*[nn.Softmax(dim=1),]) 131 | 132 | def forward(self, input_A, input_B=None, mask_B=None): 133 | if(input_B is None): 134 | input_B = torch.cat((input_A*0, input_A*0), dim=1) 135 | if(mask_B is None): 136 | mask_B = input_A*0 137 | 138 | conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1)) 139 | conv2_2 = self.model2(conv1_2[:,:,::2,::2]) 140 | conv3_3 = self.model3(conv2_2[:,:,::2,::2]) 141 | conv4_3 = self.model4(conv3_3[:,:,::2,::2]) 142 | conv5_3 = self.model5(conv4_3) 143 | conv6_3 = self.model6(conv5_3) 144 | conv7_3 = self.model7(conv6_3) 145 | 146 | conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) 147 | conv8_3 = self.model8(conv8_up) 148 | conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) 149 | conv9_3 = self.model9(conv9_up) 150 | conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) 151 | conv10_2 = self.model10(conv10_up) 152 | out_reg = self.model_out(conv10_2) 153 | 154 | conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) 155 | conv9_3 = self.model9(conv9_up) 156 | conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) 157 | conv10_2 = self.model10(conv10_up) 158 | out_reg = self.model_out(conv10_2) 159 | 160 | return self.unnormalize_ab(out_reg) 161 | 162 | def siggraph17(pretrained=True): 163 | model = SIGGRAPHGenerator() 164 | if(pretrained): 165 | import torch.utils.model_zoo as model_zoo 166 | model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True)) 167 | return model 168 | 169 | -------------------------------------------------------------------------------- /models/Colorful/util.py: -------------------------------------------------------------------------------- 1 | 2 | from PIL import Image 3 | import numpy as np 4 | from skimage import color 5 | import torch 6 | import torch.nn.functional as F 7 | from IPython import embed 8 | 9 | def load_img(img_path): 10 | out_np = np.asarray(Image.open(img_path)) 11 | if(out_np.ndim==2): 12 | out_np = np.tile(out_np[:,:,None],3) 13 | return out_np 14 | 15 | def resize_img(img, HW=(256,256), resample=3): 16 | return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample)) 17 | 18 | def preprocess_img(img_rgb_orig, HW=(256,256), resample=3): 19 | # return original size L and resized L as torch Tensors 20 | img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample) 21 | 22 | img_lab_orig = color.rgb2lab(img_rgb_orig) 23 | img_lab_rs = color.rgb2lab(img_rgb_rs) 24 | 25 | img_l_orig = img_lab_orig[:,:,0] 26 | img_l_rs = img_lab_rs[:,:,0] 27 | 28 | tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:] 29 | tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:] 30 | 31 | return (tens_orig_l, tens_rs_l) 32 | 33 | def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'): 34 | # tens_orig_l 1 x 1 x H_orig x W_orig 35 | # out_ab 1 x 2 x H x W 36 | 37 | HW_orig = tens_orig_l.shape[2:] 38 | HW = out_ab.shape[2:] 39 | 40 | # call resize 
function if needed 41 | if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]): 42 | out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear') 43 | else: 44 | out_ab_orig = out_ab 45 | 46 | out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1) 47 | return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0))) 48 | -------------------------------------------------------------------------------- /models/DEVC/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import cv2 4 | import torch 5 | from PIL import Image 6 | from tqdm import tqdm 7 | import numpy as np 8 | import torchvision.transforms as transform_lib 9 | 10 | from utils.util import download_zipfile, mkdir 11 | from utils.v2i import convert_frames_to_video 12 | import models.DEVC.utils.lib.test_transforms as transforms 13 | from models.DEVC.utils.util_distortion import CenterPad, Normalize, RGB2Lab, ToTensor 14 | from models.DEVC.utils.util import batch_lab2rgb_transpose_mc, save_frames, tensor_lab2rgb, uncenter_l 15 | from models.DEVC.models.ColorVidNet import ColorVidNet 16 | from models.DEVC.models.FrameColor import frame_colorization 17 | from models.DEVC.models.NonlocalNet import VGG19_pytorch, WarpNet 18 | 19 | class DEVC(): 20 | def __init__(self, pretrained=True): 21 | self.nonlocal_net = WarpNet(1) 22 | self.colornet = ColorVidNet(7) 23 | self.vggnet = VGG19_pytorch() 24 | 25 | if pretrained is True: 26 | download_zipfile("https://facevc.blob.core.windows.net/zhanbo/old_photo/colorization_checkpoint.zip", "DEVC_checkpoints.zip") 27 | self.vggnet.load_state_dict(torch.load("data/vgg19_conv.pth")) 28 | self.nonlocal_net.load_state_dict(torch.load("checkpoints/video_moredata_l1/nonlocal_net_iter_76000.pth")) 29 | self.colornet.load_state_dict(torch.load("checkpoints/video_moredata_l1/colornet_iter_76000.pth")) 30 | 31 | def test(self, input_path, output_path, opt): 32 | mkdir(opt.output_frame_path) 33 | # parameters for wls filter 34 | wls_filter_on = True 35 | lambda_value = 500 36 | sigma_color = 4 37 | 38 | # net 39 | self.nonlocal_net.eval() 40 | self.colornet.eval() 41 | self.vggnet.eval() 42 | self.nonlocal_net.cuda() 43 | self.colornet.cuda() 44 | self.vggnet.cuda() 45 | for param in self.vggnet.parameters(): 46 | param.requires_grad = False 47 | 48 | # processing folders 49 | print("processing the folder:", input_path) 50 | _, _, filenames = os.walk(input_path).__next__() 51 | filenames.sort(key=lambda f: int("".join(filter(str.isdigit, f) or -1))) 52 | 53 | # NOTE: resize frames to 216*384 54 | transform = transforms.Compose( 55 | [CenterPad(opt.image_size), transform_lib.CenterCrop(opt.image_size), RGB2Lab(), ToTensor(), Normalize()] 56 | ) 57 | 58 | # if frame propagation: use the first frame as reference 59 | # otherwise, use the specified reference image 60 | ref_name = os.path.join(input_path , filenames[0]) if opt.frame_propagate else opt.ref_path 61 | print("reference name:", ref_name) 62 | frame_ref = Image.open(ref_name) 63 | 64 | I_last_lab_predict = None 65 | 66 | IB_lab_large = transform(frame_ref).unsqueeze(0).cuda() 67 | IB_lab = torch.nn.functional.interpolate(IB_lab_large, scale_factor=0.5, mode="bilinear") 68 | IB_l = IB_lab[:, 0:1, :, :] 69 | IB_ab = IB_lab[:, 1:3, :, :] 70 | with torch.no_grad(): 71 | I_reference_lab = IB_lab 72 | I_reference_l = I_reference_lab[:, 0:1, :, :] 73 | I_reference_ab = I_reference_lab[:, 1:3, :, :] 74 | I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), 
I_reference_ab), dim=1)) 75 | features_B = self.vggnet(I_reference_rgb, ["r12", "r22", "r32", "r42", "r52"], preprocess=True) 76 | 77 | for index, frame_name in enumerate(tqdm(filenames)): 78 | frame1 = Image.open(os.path.join(input_path, frame_name)) 79 | IA_lab_large = transform(frame1).unsqueeze(0).cuda() 80 | IA_lab = torch.nn.functional.interpolate(IA_lab_large, scale_factor=0.5, mode="bilinear") 81 | 82 | IA_l = IA_lab[:, 0:1, :, :] 83 | IA_ab = IA_lab[:, 1:3, :, :] 84 | 85 | if I_last_lab_predict is None: 86 | if opt.frame_propagate: 87 | I_last_lab_predict = IB_lab 88 | else: 89 | I_last_lab_predict = torch.zeros_like(IA_lab).cuda() 90 | 91 | # start the frame colorization 92 | with torch.no_grad(): 93 | I_current_lab = IA_lab 94 | I_current_ab_predict, I_current_nonlocal_lab_predict, features_current_gray = frame_colorization( 95 | I_current_lab, 96 | I_reference_lab, 97 | I_last_lab_predict, 98 | features_B, 99 | self.vggnet, 100 | self.nonlocal_net, 101 | self.colornet, 102 | feature_noise=0, 103 | temperature=1e-10, 104 | ) 105 | I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1) 106 | 107 | # upsampling 108 | curr_bs_l = IA_lab_large[:, 0:1, :, :] 109 | curr_predict = ( 110 | torch.nn.functional.interpolate(I_current_ab_predict.data.cpu(), scale_factor=2, mode="bilinear") * 1.25 111 | ) 112 | 113 | # filtering 114 | if wls_filter_on: 115 | guide_image = uncenter_l(curr_bs_l) * 255 / 100 116 | wls_filter = cv2.ximgproc.createFastGlobalSmootherFilter( 117 | guide_image[0, 0, :, :].cpu().numpy().astype(np.uint8), lambda_value, sigma_color 118 | ) 119 | curr_predict_a = wls_filter.filter(curr_predict[0, 0, :, :].cpu().numpy()) 120 | curr_predict_b = wls_filter.filter(curr_predict[0, 1, :, :].cpu().numpy()) 121 | curr_predict_a = torch.from_numpy(curr_predict_a).unsqueeze(0).unsqueeze(0) 122 | curr_predict_b = torch.from_numpy(curr_predict_b).unsqueeze(0).unsqueeze(0) 123 | curr_predict_filter = torch.cat((curr_predict_a, curr_predict_b), dim=1) 124 | IA_predict_rgb = batch_lab2rgb_transpose_mc(curr_bs_l[:32], curr_predict_filter[:32, ...]) 125 | else: 126 | IA_predict_rgb = batch_lab2rgb_transpose_mc(curr_bs_l[:32], curr_predict[:32, ...]) 127 | 128 | # save the frames 129 | save_frames(IA_predict_rgb, opt.output_frame_path, index) 130 | 131 | # output video 132 | convert_frames_to_video(opt.output_frame_path, output_path) 133 | 134 | shutil.rmtree("data") 135 | shutil.rmtree("checkpoints") 136 | shutil.rmtree(opt.output_frame_path) 137 | 138 | print("Task Complete!") -------------------------------------------------------------------------------- /models/DEVC/models/ColorVidNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | 5 | 6 | class ColorVidNet(nn.Module): 7 | def __init__(self, ic): 8 | super(ColorVidNet, self).__init__() 9 | self.conv1_1 = nn.Sequential(nn.Conv2d(ic, 32, 3, 1, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1, 1)) 10 | self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1) 11 | self.conv1_2norm = nn.BatchNorm2d(64, affine=False) 12 | self.conv1_2norm_ss = nn.Conv2d(64, 64, 1, 2, bias=False, groups=64) 13 | self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1) 14 | self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1) 15 | self.conv2_2norm = nn.BatchNorm2d(128, affine=False) 16 | self.conv2_2norm_ss = nn.Conv2d(128, 128, 1, 2, bias=False, groups=128) 17 | self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1) 18 | self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1) 19 | self.conv3_3 = 
nn.Conv2d(256, 256, 3, 1, 1) 20 | self.conv3_3norm = nn.BatchNorm2d(256, affine=False) 21 | self.conv3_3norm_ss = nn.Conv2d(256, 256, 1, 2, bias=False, groups=256) 22 | self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1) 23 | self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1) 24 | self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1) 25 | self.conv4_3norm = nn.BatchNorm2d(512, affine=False) 26 | self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 2, 2) 27 | self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 2, 2) 28 | self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 2, 2) 29 | self.conv5_3norm = nn.BatchNorm2d(512, affine=False) 30 | self.conv6_1 = nn.Conv2d(512, 512, 3, 1, 2, 2) 31 | self.conv6_2 = nn.Conv2d(512, 512, 3, 1, 2, 2) 32 | self.conv6_3 = nn.Conv2d(512, 512, 3, 1, 2, 2) 33 | self.conv6_3norm = nn.BatchNorm2d(512, affine=False) 34 | self.conv7_1 = nn.Conv2d(512, 512, 3, 1, 1) 35 | self.conv7_2 = nn.Conv2d(512, 512, 3, 1, 1) 36 | self.conv7_3 = nn.Conv2d(512, 512, 3, 1, 1) 37 | self.conv7_3norm = nn.BatchNorm2d(512, affine=False) 38 | self.conv8_1 = nn.ConvTranspose2d(512, 256, 4, 2, 1) 39 | self.conv3_3_short = nn.Conv2d(256, 256, 3, 1, 1) 40 | self.conv8_2 = nn.Conv2d(256, 256, 3, 1, 1) 41 | self.conv8_3 = nn.Conv2d(256, 256, 3, 1, 1) 42 | self.conv8_3norm = nn.BatchNorm2d(256, affine=False) 43 | self.conv9_1 = nn.ConvTranspose2d(256, 128, 4, 2, 1) 44 | self.conv2_2_short = nn.Conv2d(128, 128, 3, 1, 1) 45 | self.conv9_2 = nn.Conv2d(128, 128, 3, 1, 1) 46 | self.conv9_2norm = nn.BatchNorm2d(128, affine=False) 47 | self.conv10_1 = nn.ConvTranspose2d(128, 128, 4, 2, 1) 48 | self.conv1_2_short = nn.Conv2d(64, 128, 3, 1, 1) 49 | self.conv10_2 = nn.Conv2d(128, 128, 3, 1, 1) 50 | self.conv10_ab = nn.Conv2d(128, 2, 1, 1) 51 | 52 | # add self.relux_x 53 | self.relu1_1 = nn.ReLU() 54 | self.relu1_2 = nn.ReLU() 55 | self.relu2_1 = nn.ReLU() 56 | self.relu2_2 = nn.ReLU() 57 | self.relu3_1 = nn.ReLU() 58 | self.relu3_2 = nn.ReLU() 59 | self.relu3_3 = nn.ReLU() 60 | self.relu4_1 = nn.ReLU() 61 | self.relu4_2 = nn.ReLU() 62 | self.relu4_3 = nn.ReLU() 63 | self.relu5_1 = nn.ReLU() 64 | self.relu5_2 = nn.ReLU() 65 | self.relu5_3 = nn.ReLU() 66 | self.relu6_1 = nn.ReLU() 67 | self.relu6_2 = nn.ReLU() 68 | self.relu6_3 = nn.ReLU() 69 | self.relu7_1 = nn.ReLU() 70 | self.relu7_2 = nn.ReLU() 71 | self.relu7_3 = nn.ReLU() 72 | self.relu8_1_comb = nn.ReLU() 73 | self.relu8_2 = nn.ReLU() 74 | self.relu8_3 = nn.ReLU() 75 | self.relu9_1_comb = nn.ReLU() 76 | self.relu9_2 = nn.ReLU() 77 | self.relu10_1_comb = nn.ReLU() 78 | self.relu10_2 = nn.LeakyReLU(0.2, True) 79 | 80 | print("replace all deconv with [nearest + conv]") 81 | self.conv8_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(512, 256, 3, 1, 1)) 82 | self.conv9_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(256, 128, 3, 1, 1)) 83 | self.conv10_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(128, 128, 3, 1, 1)) 84 | 85 | print("replace all batchnorm with instancenorm") 86 | self.conv1_2norm = nn.InstanceNorm2d(64) 87 | self.conv2_2norm = nn.InstanceNorm2d(128) 88 | self.conv3_3norm = nn.InstanceNorm2d(256) 89 | self.conv4_3norm = nn.InstanceNorm2d(512) 90 | self.conv5_3norm = nn.InstanceNorm2d(512) 91 | self.conv6_3norm = nn.InstanceNorm2d(512) 92 | self.conv7_3norm = nn.InstanceNorm2d(512) 93 | self.conv8_3norm = nn.InstanceNorm2d(256) 94 | self.conv9_2norm = nn.InstanceNorm2d(128) 95 | 96 | def forward(self, x): 97 | """ x: gray image (1 channel), ab(2 channel), ab_err, ba_err""" 98 | conv1_1 = self.relu1_1(self.conv1_1(x)) 
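# NOTE: the remainder of forward() is an encoder-decoder pass: conv1_*-conv7_* encode the
# input (the conv*norm_ss layers are learned stride-2 downsampling convolutions), and the
# conv8_*-conv10_* stages upsample while adding shortcut features from conv3_3norm,
# conv2_2norm and conv1_2norm; conv10_ab is finally passed through tanh and scaled by 128
# to produce the predicted ab channels.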
99 | conv1_2 = self.relu1_2(self.conv1_2(conv1_1)) 100 | conv1_2norm = self.conv1_2norm(conv1_2) 101 | conv1_2norm_ss = self.conv1_2norm_ss(conv1_2norm) 102 | conv2_1 = self.relu2_1(self.conv2_1(conv1_2norm_ss)) 103 | conv2_2 = self.relu2_2(self.conv2_2(conv2_1)) 104 | conv2_2norm = self.conv2_2norm(conv2_2) 105 | conv2_2norm_ss = self.conv2_2norm_ss(conv2_2norm) 106 | conv3_1 = self.relu3_1(self.conv3_1(conv2_2norm_ss)) 107 | conv3_2 = self.relu3_2(self.conv3_2(conv3_1)) 108 | conv3_3 = self.relu3_3(self.conv3_3(conv3_2)) 109 | conv3_3norm = self.conv3_3norm(conv3_3) 110 | conv3_3norm_ss = self.conv3_3norm_ss(conv3_3norm) 111 | conv4_1 = self.relu4_1(self.conv4_1(conv3_3norm_ss)) 112 | conv4_2 = self.relu4_2(self.conv4_2(conv4_1)) 113 | conv4_3 = self.relu4_3(self.conv4_3(conv4_2)) 114 | conv4_3norm = self.conv4_3norm(conv4_3) 115 | conv5_1 = self.relu5_1(self.conv5_1(conv4_3norm)) 116 | conv5_2 = self.relu5_2(self.conv5_2(conv5_1)) 117 | conv5_3 = self.relu5_3(self.conv5_3(conv5_2)) 118 | conv5_3norm = self.conv5_3norm(conv5_3) 119 | conv6_1 = self.relu6_1(self.conv6_1(conv5_3norm)) 120 | conv6_2 = self.relu6_2(self.conv6_2(conv6_1)) 121 | conv6_3 = self.relu6_3(self.conv6_3(conv6_2)) 122 | conv6_3norm = self.conv6_3norm(conv6_3) 123 | conv7_1 = self.relu7_1(self.conv7_1(conv6_3norm)) 124 | conv7_2 = self.relu7_2(self.conv7_2(conv7_1)) 125 | conv7_3 = self.relu7_3(self.conv7_3(conv7_2)) 126 | conv7_3norm = self.conv7_3norm(conv7_3) 127 | conv8_1 = self.conv8_1(conv7_3norm) 128 | conv3_3_short = self.conv3_3_short(conv3_3norm) 129 | conv8_1_comb = self.relu8_1_comb(conv8_1 + conv3_3_short) 130 | conv8_2 = self.relu8_2(self.conv8_2(conv8_1_comb)) 131 | conv8_3 = self.relu8_3(self.conv8_3(conv8_2)) 132 | conv8_3norm = self.conv8_3norm(conv8_3) 133 | conv9_1 = self.conv9_1(conv8_3norm) 134 | conv2_2_short = self.conv2_2_short(conv2_2norm) 135 | conv9_1_comb = self.relu9_1_comb(conv9_1 + conv2_2_short) 136 | conv9_2 = self.relu9_2(self.conv9_2(conv9_1_comb)) 137 | conv9_2norm = self.conv9_2norm(conv9_2) 138 | conv10_1 = self.conv10_1(conv9_2norm) 139 | conv1_2_short = self.conv1_2_short(conv1_2norm) 140 | conv10_1_comb = self.relu10_1_comb(conv10_1 + conv1_2_short) 141 | conv10_2 = self.relu10_2(self.conv10_2(conv10_1_comb)) 142 | conv10_ab = self.conv10_ab(conv10_2) 143 | 144 | return torch.tanh(conv10_ab) * 128 145 | -------------------------------------------------------------------------------- /models/DEVC/models/FrameColor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from models.DEVC.utils.util import * 3 | 4 | 5 | def warp_color(IA_l, IB_lab, features_B, vggnet, nonlocal_net, colornet, feature_noise=0, temperature=0.01): 6 | IA_rgb_from_gray = gray2rgb_batch(IA_l) 7 | with torch.no_grad(): 8 | A_relu1_1, A_relu2_1, A_relu3_1, A_relu4_1, A_relu5_1 = vggnet( 9 | IA_rgb_from_gray, ["r12", "r22", "r32", "r42", "r52"], preprocess=True 10 | ) 11 | B_relu1_1, B_relu2_1, B_relu3_1, B_relu4_1, B_relu5_1 = features_B 12 | 13 | # NOTE: output the feature before normalization 14 | features_A = [A_relu1_1, A_relu2_1, A_relu3_1, A_relu4_1, A_relu5_1] 15 | 16 | A_relu2_1 = feature_normalize(A_relu2_1) 17 | A_relu3_1 = feature_normalize(A_relu3_1) 18 | A_relu4_1 = feature_normalize(A_relu4_1) 19 | A_relu5_1 = feature_normalize(A_relu5_1) 20 | B_relu2_1 = feature_normalize(B_relu2_1) 21 | B_relu3_1 = feature_normalize(B_relu3_1) 22 | B_relu4_1 = feature_normalize(B_relu4_1) 23 | B_relu5_1 = feature_normalize(B_relu5_1) 24 | 25 | 
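# NOTE: the normalized VGG features of the grayscale frame (A_relu*) and of the color
# reference (B_relu*) are passed to the non-local correspondence network below, which
# warps the reference's Lab colors onto the current frame (nonlocal_BA_lab) and returns a
# per-pixel similarity map that is concatenated into the colorization network's input downstream.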
nonlocal_BA_lab, similarity_map = nonlocal_net( 26 | IB_lab, 27 | A_relu2_1, 28 | A_relu3_1, 29 | A_relu4_1, 30 | A_relu5_1, 31 | B_relu2_1, 32 | B_relu3_1, 33 | B_relu4_1, 34 | B_relu5_1, 35 | temperature=temperature, 36 | ) 37 | 38 | return nonlocal_BA_lab, similarity_map, features_A 39 | 40 | 41 | def frame_colorization( 42 | IA_lab, 43 | IB_lab, 44 | IA_last_lab, 45 | features_B, 46 | vggnet, 47 | nonlocal_net, 48 | colornet, 49 | joint_training=True, 50 | feature_noise=0, 51 | luminance_noise=0, 52 | temperature=0.01, 53 | ): 54 | 55 | IA_l = IA_lab[:, 0:1, :, :] 56 | if luminance_noise: 57 | IA_l = IA_l + torch.randn_like(IA_l, requires_grad=False) * luminance_noise 58 | 59 | with torch.autograd.set_grad_enabled(joint_training): 60 | nonlocal_BA_lab, similarity_map, features_A_gray = warp_color( 61 | IA_l, IB_lab, features_B, vggnet, nonlocal_net, colornet, feature_noise, temperature=temperature 62 | ) 63 | nonlocal_BA_ab = nonlocal_BA_lab[:, 1:3, :, :] 64 | color_input = torch.cat((IA_l, nonlocal_BA_ab, similarity_map, IA_last_lab), dim=1) 65 | IA_ab_predict = colornet(color_input) 66 | 67 | return IA_ab_predict, nonlocal_BA_lab, features_A_gray 68 | -------------------------------------------------------------------------------- /models/DEVC/models/GAN_Models.py: -------------------------------------------------------------------------------- 1 | # DCGAN-like generator and discriminator 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn 5 | 6 | from models.DEVC.models.spectral_normalization import SpectralNorm 7 | 8 | 9 | class Generator(nn.Module): 10 | def __init__(self, z_dim): 11 | super(Generator, self).__init__() 12 | self.z_dim = z_dim 13 | 14 | self.model = nn.Sequential( 15 | nn.ConvTranspose2d(z_dim, 512, 4, stride=1), 16 | nn.InstanceNorm2d(512), 17 | nn.ReLU(), 18 | nn.ConvTranspose2d(512, 256, 4, stride=2, padding=(1, 1)), 19 | nn.InstanceNorm2d(256), 20 | nn.ReLU(), 21 | nn.ConvTranspose2d(256, 128, 4, stride=2, padding=(1, 1)), 22 | nn.InstanceNorm2d(128), 23 | nn.ReLU(), 24 | nn.ConvTranspose2d(128, 64, 4, stride=2, padding=(1, 1)), 25 | nn.InstanceNorm2d(64), 26 | nn.ReLU(), 27 | nn.ConvTranspose2d(64, channels, 3, stride=1, padding=(1, 1)), 28 | nn.Tanh(), 29 | ) 30 | 31 | def forward(self, z): 32 | return self.model(z.view(-1, self.z_dim, 1, 1)) 33 | 34 | 35 | channels = 3 36 | leak = 0.1 37 | w_g = 4 38 | 39 | 40 | class Discriminator(nn.Module): 41 | def __init__(self): 42 | super(Discriminator, self).__init__() 43 | 44 | self.conv1 = SpectralNorm(nn.Conv2d(channels, 64, 3, stride=1, padding=(1, 1))) 45 | self.conv2 = SpectralNorm(nn.Conv2d(64, 64, 4, stride=2, padding=(1, 1))) 46 | self.conv3 = SpectralNorm(nn.Conv2d(64, 128, 3, stride=1, padding=(1, 1))) 47 | self.conv4 = SpectralNorm(nn.Conv2d(128, 128, 4, stride=2, padding=(1, 1))) 48 | self.conv5 = SpectralNorm(nn.Conv2d(128, 256, 3, stride=1, padding=(1, 1))) 49 | self.conv6 = SpectralNorm(nn.Conv2d(256, 256, 4, stride=2, padding=(1, 1))) 50 | self.conv7 = SpectralNorm(nn.Conv2d(256, 256, 3, stride=1, padding=(1, 1))) 51 | self.conv8 = SpectralNorm(nn.Conv2d(256, 512, 4, stride=2, padding=(1, 1))) 52 | self.fc = SpectralNorm(nn.Linear(w_g * w_g * 512, 1)) 53 | 54 | def forward(self, x): 55 | m = x 56 | m = nn.LeakyReLU(leak)(self.conv1(m)) 57 | m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(64)(self.conv2(m))) 58 | m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv3(m))) 59 | m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv4(m))) 60 | m = 
nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv5(m))) 61 | m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv6(m))) 62 | m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv7(m))) 63 | m = nn.LeakyReLU(leak)(self.conv8(m)) 64 | 65 | return self.fc(m.view(-1, w_g * w_g * 512)) 66 | 67 | 68 | class Self_Attention(nn.Module): 69 | """ Self attention Layer""" 70 | 71 | def __init__(self, in_dim): 72 | super(Self_Attention, self).__init__() 73 | self.chanel_in = in_dim 74 | 75 | self.query_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1)) 76 | self.key_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1)) 77 | self.value_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)) 78 | self.gamma = nn.Parameter(torch.zeros(1)) 79 | 80 | self.softmax = nn.Softmax(dim=-1) # 81 | 82 | def forward(self, x): 83 | """ 84 | inputs : 85 | x : input feature maps( B X C X W X H) 86 | returns : 87 | out : self attention value + input feature 88 | attention: B X N X N (N is Width*Height) 89 | """ 90 | m_batchsize, C, width, height = x.size() 91 | proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X CX(N) 92 | proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H) 93 | energy = torch.bmm(proj_query, proj_key) # transpose check 94 | attention = self.softmax(energy) # BX (N) X (N) 95 | proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N 96 | 97 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 98 | out = out.view(m_batchsize, C, width, height) 99 | 100 | out = self.gamma * out + x 101 | return out 102 | 103 | 104 | class Discriminator_x64(nn.Module): 105 | """ 106 | Discriminative Network 107 | """ 108 | 109 | def __init__(self, in_size=6, ndf=64): 110 | super(Discriminator_x64, self).__init__() 111 | self.in_size = in_size 112 | self.ndf = ndf 113 | 114 | self.layer1 = nn.Sequential( 115 | SpectralNorm(nn.Conv2d(self.in_size, self.ndf, 4, 2, 1)), nn.LeakyReLU(0.2, inplace=True) 116 | ) 117 | self.layer2 = nn.Sequential( 118 | SpectralNorm(nn.Conv2d(self.ndf, self.ndf, 4, 2, 1)), 119 | nn.InstanceNorm2d(self.ndf), 120 | nn.LeakyReLU(0.2, inplace=True), 121 | ) 122 | self.attention = Self_Attention(self.ndf) 123 | self.layer3 = nn.Sequential( 124 | SpectralNorm(nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1)), 125 | nn.InstanceNorm2d(self.ndf * 2), 126 | nn.LeakyReLU(0.2, inplace=True), 127 | ) 128 | self.layer4 = nn.Sequential( 129 | SpectralNorm(nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1)), 130 | nn.InstanceNorm2d(self.ndf * 4), 131 | nn.LeakyReLU(0.2, inplace=True), 132 | ) 133 | self.layer5 = nn.Sequential( 134 | SpectralNorm(nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1)), 135 | nn.InstanceNorm2d(self.ndf * 8), 136 | nn.LeakyReLU(0.2, inplace=True), 137 | ) 138 | self.layer6 = nn.Sequential( 139 | SpectralNorm(nn.Conv2d(self.ndf * 8, self.ndf * 16, 4, 2, 1)), 140 | nn.InstanceNorm2d(self.ndf * 16), 141 | nn.LeakyReLU(0.2, inplace=True), 142 | ) 143 | 144 | self.last = SpectralNorm(nn.Conv2d(self.ndf * 16, 1, [3, 6], 1, 0)) 145 | 146 | def forward(self, input): 147 | feature1 = self.layer1(input) 148 | feature2 = self.layer2(feature1) 149 | feature_attention = self.attention(feature2) 150 | feature3 = self.layer3(feature_attention) 151 | feature4 = self.layer4(feature3) 152 | feature5 = self.layer5(feature4) 153 | feature6 = self.layer6(feature5) 154 | output = self.last(feature6) 155 | 
output = F.avg_pool2d(output, output.size()[2:]).view(output.size()[0], -1) 156 | 157 | return output, feature4 158 | -------------------------------------------------------------------------------- /models/DEVC/models/spectral_normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Parameter 4 | 5 | 6 | def l2normalize(v, eps=1e-12): 7 | return v / (v.norm() + eps) 8 | 9 | 10 | class SpectralNorm(nn.Module): 11 | def __init__(self, module, name="weight", power_iterations=1): 12 | super(SpectralNorm, self).__init__() 13 | self.module = module 14 | self.name = name 15 | self.power_iterations = power_iterations 16 | if not self._made_params(): 17 | self._make_params() 18 | 19 | def _update_u_v(self): 20 | u = getattr(self.module, self.name + "_u") 21 | v = getattr(self.module, self.name + "_v") 22 | w = getattr(self.module, self.name + "_bar") 23 | 24 | height = w.data.shape[0] 25 | for _ in range(self.power_iterations): 26 | v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data)) 27 | u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data)) 28 | 29 | sigma = u.dot(w.view(height, -1).mv(v)) 30 | setattr(self.module, self.name, w / sigma.expand_as(w)) 31 | 32 | def _made_params(self): 33 | try: 34 | u = getattr(self.module, self.name + "_u") 35 | v = getattr(self.module, self.name + "_v") 36 | w = getattr(self.module, self.name + "_bar") 37 | return True 38 | except AttributeError: 39 | return False 40 | 41 | def _make_params(self): 42 | w = getattr(self.module, self.name) 43 | 44 | height = w.data.shape[0] 45 | width = w.view(height, -1).data.shape[1] 46 | 47 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 48 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 49 | u.data = l2normalize(u.data) 50 | v.data = l2normalize(v.data) 51 | w_bar = Parameter(w.data) 52 | 53 | del self.module._parameters[self.name] 54 | 55 | self.module.register_parameter(self.name + "_u", u) 56 | self.module.register_parameter(self.name + "_v", v) 57 | self.module.register_parameter(self.name + "_bar", w_bar) 58 | 59 | def forward(self, *args): 60 | self._update_u_v() 61 | return self.module.forward(*args) 62 | -------------------------------------------------------------------------------- /models/DEVC/models/vgg19_gray.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | class LambdaBase(nn.Sequential): 8 | def __init__(self, fn, *args): 9 | super(LambdaBase, self).__init__(*args) 10 | self.lambda_func = fn 11 | 12 | def forward_prepare(self, input): 13 | output = [] 14 | for module in self._modules.values(): 15 | output.append(module(input)) 16 | 17 | return output if output else input 18 | 19 | 20 | class Lambda(LambdaBase): 21 | def forward(self, input): 22 | return self.lambda_func(self.forward_prepare(input)) 23 | 24 | 25 | class LambdaMap(LambdaBase): 26 | def forward(self, input): 27 | return list(map(self.lambda_func, self.forward_prepare(input))) 28 | 29 | 30 | class LambdaReduce(LambdaBase): 31 | def forward(self, input): 32 | return reduce(self.lambda_func, self.forward_prepare(input)) 33 | 34 | layer_names = [ 35 | "conv1_1", 36 | "relu1_1", 37 | "conv1_2", 38 | "relu1_2", 39 | "pool1", 40 | "conv2_1", 41 | "relu2_1", 42 | "conv2_2", 43 | "relu2_2", 44 | "pool2", 45 | "conv3_1", 46 | "relu3_1", 47 | "conv3_2", 
48 | "relu3_2", 49 | "conv3_3", 50 | "relu3_3", 51 | "conv3_4", 52 | "relu3_4", 53 | "pool3", 54 | "conv4_1", 55 | "relu4_1", 56 | "conv4_2", 57 | "relu4_2", 58 | "conv4_3", 59 | "relu4_3", 60 | "conv4_4", 61 | "relu4_4", 62 | "pool4", 63 | "conv5_1", 64 | "relu5_1", 65 | "conv5_2", 66 | "relu5_2", 67 | "conv5_3", 68 | "relu5_3", 69 | "conv5_4", 70 | "relu5_4", 71 | "pool5", 72 | "view1", 73 | "fc6", 74 | "fc6_relu", 75 | "fc7", 76 | "fc7_relu", 77 | "fc8", 78 | ] 79 | 80 | def get_pretrained_vgg(): 81 | model = nn.Sequential( # Sequential, 82 | nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1)), 83 | nn.ReLU(), 84 | nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)), 85 | nn.ReLU(), 86 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 87 | nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1)), 88 | nn.ReLU(), 89 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)), 90 | nn.ReLU(), 91 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 92 | nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1)), 93 | nn.ReLU(), 94 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), 95 | nn.ReLU(), 96 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), 97 | nn.ReLU(), 98 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)), 99 | nn.ReLU(), 100 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 101 | nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1)), 102 | nn.ReLU(), 103 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)), 104 | nn.ReLU(), 105 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)), 106 | nn.ReLU(), 107 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)), 108 | nn.ReLU(), 109 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 110 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)), 111 | nn.ReLU(), 112 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)), 113 | nn.ReLU(), 114 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)), 115 | nn.ReLU(), 116 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)), 117 | nn.ReLU(), 118 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True), 119 | Lambda(lambda x: x.view(x.size(0), -1)), # View, 120 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(25088, 4096)), # Linear, 121 | nn.ReLU(), 122 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(4096, 4096)), # Linear, 123 | nn.ReLU(), 124 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(4096, 1000)), # Linear, 125 | ) 126 | 127 | 128 | model.load_state_dict(torch.load("data/vgg19_gray.pth")) 129 | vgg19_gray_net = torch.nn.Sequential() 130 | for (name, layer) in model._modules.items(): 131 | vgg19_gray_net.add_module(layer_names[int(name)], model[int(name)]) 132 | 133 | for param in vgg19_gray_net.parameters(): 134 | param.requires_grad = False 135 | vgg19_gray_net.eval() 136 | return vgg19_gray_net 137 | 138 | 139 | class vgg19_gray(torch.nn.Module): 140 | def __init__(self, requires_grad=False): 141 | super(vgg19_gray, self).__init__() 142 | vgg_pretrained_features = get_pretrained_vgg() 143 | self.slice1 = torch.nn.Sequential() 144 | self.slice2 = torch.nn.Sequential() 145 | self.slice3 = torch.nn.Sequential() 146 | for x in range(12): 147 | self.slice1.add_module(layer_names[x], vgg_pretrained_features[x]) 148 | for x in range(12, 21): 149 | self.slice2.add_module(layer_names[x], vgg_pretrained_features[x]) 150 | for x in range(21, 30): 151 | self.slice3.add_module(layer_names[x], vgg_pretrained_features[x]) 152 | if not requires_grad: 153 | for param in self.parameters(): 154 | param.requires_grad = False 155 | 156 | def forward(self, X): 157 | h = self.slice1(X) 158 | h_relu3_1 = h 159 
| h = self.slice2(h) 160 | h_relu4_1 = h 161 | h = self.slice3(h) 162 | h_relu5_1 = h 163 | return h_relu3_1, h_relu4_1, h_relu5_1 164 | 165 | 166 | class vgg19_gray_new(torch.nn.Module): 167 | def __init__(self, requires_grad=False): 168 | super(vgg19_gray_new, self).__init__() 169 | vgg_pretrained_features = get_pretrained_vgg() 170 | self.slice0 = torch.nn.Sequential() 171 | self.slice1 = torch.nn.Sequential() 172 | self.slice2 = torch.nn.Sequential() 173 | self.slice3 = torch.nn.Sequential() 174 | for x in range(7): 175 | self.slice0.add_module(layer_names[x], vgg_pretrained_features[x]) 176 | for x in range(7, 12): 177 | self.slice1.add_module(layer_names[x], vgg_pretrained_features[x]) 178 | for x in range(12, 21): 179 | self.slice2.add_module(layer_names[x], vgg_pretrained_features[x]) 180 | for x in range(21, 30): 181 | self.slice3.add_module(layer_names[x], vgg_pretrained_features[x]) 182 | if not requires_grad: 183 | for param in self.parameters(): 184 | param.requires_grad = False 185 | 186 | def forward(self, X): 187 | h = self.slice0(X) 188 | h_relu2_1 = h 189 | h = self.slice1(h) 190 | h_relu3_1 = h 191 | h = self.slice2(h) 192 | h_relu4_1 = h 193 | h = self.slice3(h) 194 | h_relu5_1 = h 195 | return h_relu2_1, h_relu3_1, h_relu4_1, h_relu5_1 196 | -------------------------------------------------------------------------------- /models/DEVC/options/test_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.backends.cudnn as cudnn 3 | class TestOptions(): 4 | def __init__(self): 5 | self.initialized = False 6 | 7 | def initialize(self, parser): 8 | parser.add_argument( 9 | "--frame_propagate", default=True, type=bool, help="propagation mode, , please check the paper" 10 | ) 11 | parser.add_argument("--image_size", type=int, default=[216 * 2, 384 * 2], help="the image size, eg. 
[216,384]") 12 | parser.add_argument("--cuda", action="store_false") 13 | parser.add_argument("--gpu_ids", type=str, default="0", help="separate by comma") 14 | parser.add_argument("--clip_path", type=str, default="../test/input.mp4", help="path of input clips") 15 | parser.add_argument("--ref_path", type=str, default="../test/frame00000.jpg", help="path of refernce images") 16 | parser.add_argument("--output_frame_path", type=str, default="../test/results", help="path of output colorized frames") 17 | 18 | self.initialized = True 19 | return parser 20 | 21 | def parse(self): 22 | # initialize parser with basic options 23 | if not self.initialized: 24 | parser = argparse.ArgumentParser( 25 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 26 | parser = self.initialize(parser) 27 | opt = parser.parse_args() 28 | 29 | opt.gpu_ids = [int(x) for x in opt.gpu_ids.split(",")] 30 | cudnn.benchmark = True 31 | print("running on GPU", opt.gpu_ids) 32 | return opt -------------------------------------------------------------------------------- /models/DEVC/utils/lib/test_transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import collections 4 | import numbers 5 | import random 6 | 7 | import torch 8 | from PIL import Image 9 | from skimage import color 10 | 11 | import models.DEVC.utils.lib.functional as F 12 | 13 | __all__ = [ 14 | "Compose", 15 | "Concatenate", 16 | "ToTensor", 17 | "Normalize", 18 | "Resize", 19 | "Scale", 20 | "CenterCrop", 21 | "Pad", 22 | "RandomCrop", 23 | "RandomHorizontalFlip", 24 | "RandomVerticalFlip", 25 | "RandomResizedCrop", 26 | "RandomSizedCrop", 27 | "FiveCrop", 28 | "TenCrop", 29 | "RGB2Lab", 30 | ] 31 | 32 | 33 | def CustomFunc(inputs, func, *args, **kwargs): 34 | im_l = func(inputs[0], *args, **kwargs) 35 | im_ab = func(inputs[1], *args, **kwargs) 36 | warp_ba = func(inputs[2], *args, **kwargs) 37 | warp_aba = func(inputs[3], *args, **kwargs) 38 | # im_gbl_ab = func(inputs[4], *args, **kwargs) 39 | # bgr_mc_im = func(inputs[5], *args, **kwargs) 40 | layer_data = [im_l, im_ab, warp_ba, warp_aba] 41 | # layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im] 42 | for l in range(5): 43 | layer = inputs[4 + l] 44 | err_ba = func(layer[0], *args, **kwargs) 45 | err_ab = func(layer[1], *args, **kwargs) 46 | 47 | layer_data.append([err_ba, err_ab]) 48 | 49 | return layer_data 50 | 51 | 52 | class Compose(object): 53 | """Composes several transforms together. 54 | 55 | Args: 56 | transforms (list of ``Transform`` objects): list of transforms to compose. 57 | 58 | Example: 59 | >>> transforms.Compose([ 60 | >>> transforms.CenterCrop(10), 61 | >>> transforms.ToTensor(), 62 | >>> ]) 63 | """ 64 | 65 | def __init__(self, transforms): 66 | self.transforms = transforms 67 | 68 | def __call__(self, inputs): 69 | for t in self.transforms: 70 | inputs = t(inputs) 71 | return inputs 72 | 73 | 74 | class Concatenate(object): 75 | """ 76 | Input: [im_l, im_ab, inputs] 77 | inputs = [warp_ba_l, warp_ba_ab, warp_aba, err_pm, err_aba] 78 | 79 | Output:[im_l, err_pm, warp_ba, warp_aba, im_ab, err_aba] 80 | """ 81 | 82 | def __call__(self, inputs): 83 | im_l = inputs[0] 84 | im_ab = inputs[1] 85 | warp_ba = inputs[2] 86 | warp_aba = inputs[3] 87 | # im_glb_ab = inputs[4] 88 | # bgr_mc_im = inputs[5] 89 | # bgr_mc_im = bgr_mc_im[[2, 1, 0], ...] 
90 | 91 | err_ba = [] 92 | err_ab = [] 93 | 94 | for l in range(5): 95 | layer = inputs[4 + l] 96 | err_ba.append(layer[0]) 97 | err_ab.append(layer[1]) 98 | 99 | cerr_ba = torch.cat(err_ba, 0) 100 | cerr_ab = torch.cat(err_ab, 0) 101 | 102 | return (im_l, cerr_ba, warp_ba, warp_aba, im_ab, cerr_ab) 103 | # return (im_l, cerr_ba, warp_ba, warp_aba, im_glb_ab, bgr_mc_im, im_ab, cerr_ab) 104 | 105 | 106 | class ToTensor(object): 107 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 108 | 109 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range 110 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 111 | """ 112 | 113 | def __call__(self, inputs): 114 | """ 115 | Args: 116 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 117 | 118 | Returns: 119 | Tensor: Converted image. 120 | """ 121 | inputs = CustomFunc(inputs, F.to_mytensor) 122 | return inputs 123 | 124 | 125 | class Normalize(object): 126 | """Normalize an tensor image with mean and standard deviation. 127 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 128 | will normalize each channel of the input ``torch.*Tensor`` i.e. 129 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 130 | 131 | Args: 132 | mean (sequence): Sequence of means for each channel. 133 | std (sequence): Sequence of standard deviations for each channel. 134 | """ 135 | 136 | def __call__(self, inputs): 137 | """ 138 | Args: 139 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 140 | 141 | Returns: 142 | Tensor: Normalized Tensor image. 143 | """ 144 | 145 | im_l = F.normalize(inputs[0], 50, 1) # [0, 100] 146 | im_ab = F.normalize(inputs[1], (0, 0), (1, 1)) # [-100, 100] 147 | 148 | inputs[2][0:1, :, :] = F.normalize(inputs[2][0:1, :, :], 50, 1) 149 | inputs[2][1:3, :, :] = F.normalize(inputs[2][1:3, :, :], (0, 0), (1, 1)) 150 | warp_ba = inputs[2] 151 | 152 | inputs[3][0:1, :, :] = F.normalize(inputs[3][0:1, :, :], 50, 1) 153 | inputs[3][1:3, :, :] = F.normalize(inputs[3][1:3, :, :], (0, 0), (1, 1)) 154 | warp_aba = inputs[3] 155 | 156 | # im_gbl_ab = F.normalize(inputs[4], (0, 0), (1, 1)) # [-100, 100] 157 | # 158 | # bgr_mc_im = F.normalize(inputs[5], (123.68, 116.78, 103.938), (1, 1, 1)) 159 | 160 | # layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im] 161 | layer_data = [im_l, im_ab, warp_ba, warp_aba] 162 | 163 | for l in range(5): 164 | layer = inputs[4 + l] 165 | err_ba = F.normalize(layer[0], 127, 2) # [0, 255] 166 | err_ab = F.normalize(layer[1], 127, 2) # [0, 255] 167 | layer_data.append([err_ba, err_ab]) 168 | 169 | return layer_data 170 | 171 | 172 | class Resize(object): 173 | """Resize the input PIL Image to the given size. 174 | 175 | Args: 176 | size (sequence or int): Desired output size. If size is a sequence like 177 | (h, w), output size will be matched to this. If size is an int, 178 | smaller edge of the image will be matched to this number. 179 | i.e, if height > width, then image will be rescaled to 180 | (size * height / width, size) 181 | interpolation (int, optional): Desired interpolation. Default is 182 | ``PIL.Image.BILINEAR`` 183 | """ 184 | 185 | def __init__(self, size, interpolation=Image.BILINEAR): 186 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 187 | self.size = size 188 | self.interpolation = interpolation 189 | 190 | def __call__(self, inputs): 191 | """ 192 | Args: 193 | img (PIL Image): Image to be scaled. 
194 | 195 | Returns: 196 | PIL Image: Rescaled image. 197 | """ 198 | return CustomFunc(inputs, F.resize, self.size, self.interpolation) 199 | 200 | 201 | class RandomCrop(object): 202 | """Crop the given PIL Image at a random location. 203 | 204 | Args: 205 | size (sequence or int): Desired output size of the crop. If size is an 206 | int instead of sequence like (h, w), a square crop (size, size) is 207 | made. 208 | padding (int or sequence, optional): Optional padding on each border 209 | of the image. Default is 0, i.e no padding. If a sequence of length 210 | 4 is provided, it is used to pad left, top, right, bottom borders 211 | respectively. 212 | """ 213 | 214 | def __init__(self, size, padding=0): 215 | if isinstance(size, numbers.Number): 216 | self.size = (int(size), int(size)) 217 | else: 218 | self.size = size 219 | self.padding = padding 220 | 221 | @staticmethod 222 | def get_params(img, output_size): 223 | """Get parameters for ``crop`` for a random crop. 224 | 225 | Args: 226 | img (PIL Image): Image to be cropped. 227 | output_size (tuple): Expected output size of the crop. 228 | 229 | Returns: 230 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 231 | """ 232 | w, h = img.size 233 | th, tw = output_size 234 | if w == tw and h == th: 235 | return 0, 0, h, w 236 | 237 | i = random.randint(0, h - th) 238 | j = random.randint(0, w - tw) 239 | return i, j, th, tw 240 | 241 | def __call__(self, inputs): 242 | """ 243 | Args: 244 | img (PIL Image): Image to be cropped. 245 | 246 | Returns: 247 | PIL Image: Cropped image. 248 | """ 249 | if self.padding > 0: 250 | inputs = CustomFunc(inputs, F.pad, self.padding) 251 | 252 | i, j, h, w = self.get_params(inputs[0], self.size) 253 | return CustomFunc(inputs, F.crop, i, j, h, w) 254 | 255 | 256 | class CenterCrop(object): 257 | """Crop the given PIL Image at a random location. 258 | 259 | Args: 260 | size (sequence or int): Desired output size of the crop. If size is an 261 | int instead of sequence like (h, w), a square crop (size, size) is 262 | made. 263 | padding (int or sequence, optional): Optional padding on each border 264 | of the image. Default is 0, i.e no padding. If a sequence of length 265 | 4 is provided, it is used to pad left, top, right, bottom borders 266 | respectively. 267 | """ 268 | 269 | def __init__(self, size, padding=0): 270 | if isinstance(size, numbers.Number): 271 | self.size = (int(size), int(size)) 272 | else: 273 | self.size = size 274 | self.padding = padding 275 | 276 | @staticmethod 277 | def get_params(img, output_size): 278 | """Get parameters for ``crop`` for a random crop. 279 | 280 | Args: 281 | img (PIL Image): Image to be cropped. 282 | output_size (tuple): Expected output size of the crop. 283 | 284 | Returns: 285 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop. 286 | """ 287 | w, h = img.size 288 | th, tw = output_size 289 | if w == tw and h == th: 290 | return 0, 0, h, w 291 | 292 | i = (h - th) // 2 293 | j = (w - tw) // 2 294 | return i, j, th, tw 295 | 296 | def __call__(self, inputs): 297 | """ 298 | Args: 299 | img (PIL Image): Image to be cropped. 300 | 301 | Returns: 302 | PIL Image: Cropped image. 
303 | """ 304 | if self.padding > 0: 305 | inputs = CustomFunc(inputs, F.pad, self.padding) 306 | 307 | if type(inputs) is list: 308 | i, j, h, w = self.get_params(inputs[0], self.size) 309 | else: 310 | i, j, h, w = self.get_params(inputs, self.size) 311 | return CustomFunc(inputs, F.crop, i, j, h, w) 312 | 313 | 314 | class RandomHorizontalFlip(object): 315 | """Horizontally flip the given PIL Image randomly with a probability of 0.5.""" 316 | 317 | def __call__(self, inputs): 318 | """ 319 | Args: 320 | img (PIL Image): Image to be flipped. 321 | 322 | Returns: 323 | PIL Image: Randomly flipped image. 324 | """ 325 | 326 | if random.random() < 0.5: 327 | return CustomFunc(inputs, F.hflip) 328 | return inputs 329 | 330 | 331 | class RGB2Lab(object): 332 | def __call__(self, inputs): 333 | """ 334 | Args: 335 | img (PIL Image): Image to be flipped. 336 | 337 | Returns: 338 | PIL Image: Randomly flipped image. 339 | """ 340 | 341 | def __call__(self, inputs): 342 | image_lab = color.rgb2lab(inputs[0]) 343 | warp_ba_lab = color.rgb2lab(inputs[2]) 344 | warp_aba_lab = color.rgb2lab(inputs[3]) 345 | # im_gbl_lab = color.rgb2lab(inputs[4]) 346 | 347 | inputs[0] = image_lab[:, :, :1] # l channel 348 | inputs[1] = image_lab[:, :, 1:] # ab channel 349 | inputs[2] = warp_ba_lab # lab channel 350 | inputs[3] = warp_aba_lab # lab channel 351 | # inputs[4] = im_gbl_lab[:, :, 1:] # ab channel 352 | 353 | return inputs 354 | -------------------------------------------------------------------------------- /models/DEVC/utils/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision 6 | from utils.util import feature_normalize 7 | 8 | postpa = torchvision.transforms.Compose( 9 | [ 10 | torchvision.transforms.Lambda(lambda x: x.mul_(1.0 / 255)), 11 | torchvision.transforms.Normalize( 12 | mean=[-0.40760392, -0.45795686, -0.48501961], std=[1, 1, 1] # add imagenet mean 13 | ), 14 | torchvision.transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]), # turn to RGB 15 | ] 16 | ) 17 | postpb = torchvision.transforms.Compose([torchvision.transforms.ToPILImage()]) 18 | 19 | 20 | def post_processing(tensor): 21 | t = postpa(tensor) # denormalize the image since the optimized tensor is the normalized one 22 | t[t > 1] = 1 23 | t[t < 0] = 0 24 | img = postpb(t) 25 | img = np.array(img) 26 | return img 27 | 28 | 29 | class ContextualLoss(nn.Module): 30 | """ 31 | input is Al, Bl, channel = 1, range ~ [0, 255] 32 | """ 33 | 34 | def __init__(self): 35 | super(ContextualLoss, self).__init__() 36 | return None 37 | 38 | def forward(self, X_features, Y_features, h=0.1, feature_centering=True): 39 | """ 40 | X_features&Y_features are are feature vectors or feature 2d array 41 | h: bandwidth 42 | return the per-sample loss 43 | """ 44 | batch_size = X_features.shape[0] 45 | feature_depth = X_features.shape[1] 46 | feature_size = X_features.shape[2] 47 | 48 | # to normalized feature vectors 49 | if feature_centering: 50 | X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze( 51 | dim=-1 52 | ).unsqueeze(dim=-1) 53 | Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze( 54 | dim=-1 55 | ).unsqueeze(dim=-1) 56 | X_features = feature_normalize(X_features).view( 57 | batch_size, feature_depth, -1 58 | ) # batch_size * feature_depth * feature_size^2 59 | Y_features = 
feature_normalize(Y_features).view( 60 | batch_size, feature_depth, -1 61 | ) # batch_size * feature_depth * feature_size^2 62 | 63 | # conine distance = 1 - similarity 64 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth 65 | d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2 66 | 67 | # normalized distance: dij_bar 68 | d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2 69 | 70 | # pairwise affinity 71 | w = torch.exp((1 - d_norm) / h) 72 | A_ij = w / torch.sum(w, dim=-1, keepdim=True) 73 | 74 | # contextual loss per sample 75 | CX = torch.mean(torch.max(A_ij, dim=1)[0], dim=-1) 76 | return -torch.log(CX) 77 | 78 | 79 | class ContextualLoss_forward(nn.Module): 80 | """ 81 | input is Al, Bl, channel = 1, range ~ [0, 255] 82 | """ 83 | 84 | def __init__(self): 85 | super(ContextualLoss_forward, self).__init__() 86 | return None 87 | 88 | def forward(self, X_features, Y_features, h=0.1, feature_centering=True): 89 | """ 90 | X_features&Y_features are are feature vectors or feature 2d array 91 | h: bandwidth 92 | return the per-sample loss 93 | """ 94 | batch_size = X_features.shape[0] 95 | feature_depth = X_features.shape[1] 96 | feature_size = X_features.shape[2] 97 | 98 | # to normalized feature vectors 99 | if feature_centering: 100 | X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze( 101 | dim=-1 102 | ).unsqueeze(dim=-1) 103 | Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze( 104 | dim=-1 105 | ).unsqueeze(dim=-1) 106 | X_features = feature_normalize(X_features).view( 107 | batch_size, feature_depth, -1 108 | ) # batch_size * feature_depth * feature_size^2 109 | Y_features = feature_normalize(Y_features).view( 110 | batch_size, feature_depth, -1 111 | ) # batch_size * feature_depth * feature_size^2 112 | 113 | # conine distance = 1 - similarity 114 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth 115 | d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2 116 | 117 | # normalized distance: dij_bar 118 | d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2 119 | 120 | # pairwise affinity 121 | w = torch.exp((1 - d_norm) / h) 122 | A_ij = w / torch.sum(w, dim=-1, keepdim=True) 123 | 124 | # contextual loss per sample 125 | CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1) 126 | return -torch.log(CX) 127 | 128 | 129 | class ContextualLoss_complex(nn.Module): 130 | """ 131 | input is Al, Bl, channel = 1, range ~ [0, 255] 132 | """ 133 | 134 | def __init__(self): 135 | super(ContextualLoss_complex, self).__init__() 136 | return None 137 | 138 | def forward(self, X_features, Y_features, h=0.1, patch_size=1, direction="forward"): 139 | """ 140 | X_features&Y_features are are feature vectors or feature 2d array 141 | h: bandwidth 142 | return the per-sample loss 143 | """ 144 | batch_size = X_features.shape[0] 145 | feature_depth = X_features.shape[1] 146 | feature_size = X_features.shape[2] 147 | 148 | # to normalized feature vectors 149 | X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze( 150 | dim=-1 151 | ).unsqueeze(dim=-1) 152 | Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze( 153 | dim=-1 154 | 
).unsqueeze(dim=-1) 155 | X_features = feature_normalize(X_features) # batch_size * feature_depth * feature_size^2 156 | Y_features = feature_normalize(Y_features) # batch_size * feature_depth * feature_size^2 157 | 158 | # to normalized feature vectors 159 | X_features = F.unfold( 160 | X_features, kernel_size=(patch_size, patch_size), stride=(1, 1), padding=(patch_size // 2, patch_size // 2) 161 | ) # batch_size * feature_depth_new * feature_size^2 162 | Y_features = F.unfold( 163 | Y_features, kernel_size=(patch_size, patch_size), stride=(1, 1), padding=(patch_size // 2, patch_size // 2) 164 | ) # batch_size * feature_depth_new * feature_size^2 165 | 166 | # conine distance = 1 - similarity 167 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth 168 | d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2 169 | 170 | # normalized distance: dij_bar 171 | d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2 172 | 173 | # pairwise affinity 174 | w = torch.exp((1 - d_norm) / h) 175 | A_ij = w / torch.sum(w, dim=-1, keepdim=True) 176 | 177 | # contextual loss per sample 178 | if direction == "forward": 179 | CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1) 180 | else: 181 | CX = torch.mean(torch.max(A_ij, dim=1)[0], dim=-1) 182 | 183 | return -torch.log(CX) 184 | 185 | 186 | class ChamferDistance_patch_loss(nn.Module): 187 | """ 188 | input is Al, Bl, channel = 1, range ~ [0, 255] 189 | """ 190 | 191 | def __init__(self): 192 | super(ChamferDistance_patch_loss, self).__init__() 193 | return None 194 | 195 | def forward(self, X_features, Y_features, patch_size=3, image_x=None, image_y=None, h=0.1, Y_features_in=None): 196 | """ 197 | X_features&Y_features are are feature vectors or feature 2d array 198 | h: bandwidth 199 | return the per-sample loss 200 | """ 201 | batch_size = X_features.shape[0] 202 | feature_depth = X_features.shape[1] 203 | feature_size = X_features.shape[2] 204 | 205 | # to normalized feature vectors 206 | X_features = F.unfold( 207 | X_features, kernel_size=(patch_size, patch_size), stride=(1, 1), padding=(patch_size // 2, patch_size // 2) 208 | ) # batch_size, feature_depth_new * feature_size^2 209 | Y_features = F.unfold( 210 | Y_features, kernel_size=(patch_size, patch_size), stride=(1, 1), padding=(patch_size // 2, patch_size // 2) 211 | ) # batch_size, feature_depth_new * feature_size^2 212 | 213 | if image_x is not None and image_y is not None: 214 | image_x = torch.nn.functional.interpolate(image_x, size=(feature_size, feature_size), mode="bilinear").view( 215 | batch_size, 3, -1 216 | ) 217 | image_y = torch.nn.functional.interpolate(image_y, size=(feature_size, feature_size), mode="bilinear").view( 218 | batch_size, 3, -1 219 | ) 220 | 221 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth 222 | similarity_matrix = torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2 223 | NN_index = similarity_matrix.max(dim=-1, keepdim=True)[1].squeeze() 224 | 225 | if Y_features_in is not None: 226 | loss = torch.mean((X_features - Y_features_in.detach()) ** 2) 227 | Y_features_in = Y_features_in.detach() 228 | else: 229 | loss = torch.mean((X_features - Y_features[:, :, NN_index].detach()) ** 2) 230 | Y_features_in = Y_features[:, :, NN_index].detach() 231 | 232 | # re-arrange image 233 | if image_x is not None and image_y is not None: 234 | 
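            # index the flattened reference image with the nearest-neighbour map so a
            # warped reference is available for inspection; it does not enter the loss
            # value computed above.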
image_y_rearrange = image_y[:, :, NN_index] 235 | image_y_rearrange = image_y_rearrange.view(batch_size, 3, feature_size, feature_size) 236 | image_x = image_x.view(batch_size, 3, feature_size, feature_size) 237 | image_y = image_y.view(batch_size, 3, feature_size, feature_size) 238 | 239 | return loss 240 | 241 | 242 | class ChamferDistance_loss(nn.Module): 243 | """ 244 | input is Al, Bl, channel = 1, range ~ [0, 255] 245 | """ 246 | 247 | def __init__(self): 248 | super(ChamferDistance_loss, self).__init__() 249 | return None 250 | 251 | def forward(self, X_features, Y_features, image_x, image_y, h=0.1, Y_features_in=None): 252 | """ 253 | X_features&Y_features are are feature vectors or feature 2d array 254 | h: bandwidth 255 | return the per-sample loss 256 | """ 257 | batch_size = X_features.shape[0] 258 | feature_depth = X_features.shape[1] 259 | feature_size = X_features.shape[2] 260 | 261 | # to normalized feature vectors 262 | X_features = feature_normalize(X_features).view( 263 | batch_size, feature_depth, -1 264 | ) # batch_size * feature_depth * feature_size^2 265 | Y_features = feature_normalize(Y_features).view( 266 | batch_size, feature_depth, -1 267 | ) # batch_size * feature_depth * feature_size^2 268 | image_x = torch.nn.functional.interpolate(image_x, size=(feature_size, feature_size), mode="bilinear").view( 269 | batch_size, 3, -1 270 | ) 271 | image_y = torch.nn.functional.interpolate(image_y, size=(feature_size, feature_size), mode="bilinear").view( 272 | batch_size, 3, -1 273 | ) 274 | 275 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth 276 | similarity_matrix = torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2 277 | NN_index = similarity_matrix.max(dim=-1, keepdim=True)[1].squeeze() 278 | if Y_features_in is not None: 279 | loss = torch.mean((X_features - Y_features_in.detach()) ** 2) 280 | Y_features_in = Y_features_in.detach() 281 | else: 282 | loss = torch.mean((X_features - Y_features[:, :, NN_index].detach()) ** 2) 283 | Y_features_in = Y_features[:, :, NN_index].detach() 284 | 285 | # re-arrange image 286 | image_y_rearrange = image_y[:, :, NN_index] 287 | image_y_rearrange = image_y_rearrange.view(batch_size, 3, feature_size, feature_size) 288 | image_x = image_x.view(batch_size, 3, feature_size, feature_size) 289 | image_y = image_y.view(batch_size, 3, feature_size, feature_size) 290 | 291 | return loss, Y_features_in, X_features 292 | 293 | 294 | if __name__ == "__main__": 295 | contextual_loss = ContextualLoss() 296 | batch_size = 32 297 | feature_depth = 8 298 | feature_size = 16 299 | X_features = torch.zeros(batch_size, feature_depth, feature_size, feature_size) 300 | Y_features = torch.zeros(batch_size, feature_depth, feature_size, feature_size) 301 | 302 | cx_loss = contextual_loss(X_features, Y_features, 1) 303 | print(cx_loss) 304 | -------------------------------------------------------------------------------- /models/DEVC/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import sys 4 | 5 | import cv2 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | import torch 9 | import torchvision.utils as vutils 10 | from skimage import color, io 11 | from torch.autograd import Variable 12 | 13 | 14 | from utils.loss import mse_loss 15 | 16 | cv2.setNumThreads(0) 17 | 18 | # l: [-50,50] 19 | # ab: [-128, 128] 20 | l_norm, ab_norm = 1.0, 1.0 21 | l_mean, ab_mean = 50.0, 0 22 
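# With the defaults above, centering reduces to a plain shift of the L channel, e.g.
#   center_l(75.0) == 25.0 and uncenter_l(25.0) == 75.0, while center_ab(20.0) == 20.0
# (ab_mean is 0, so the ab channels pass through unchanged).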
| 23 | 24 | ###### utility ###### 25 | def to_np(x): 26 | return x.data.cpu().numpy() 27 | 28 | 29 | def utf8_str(in_str): 30 | try: 31 | in_str = in_str.decode("UTF-8") 32 | except Exception: 33 | in_str = in_str.encode("UTF-8").decode("UTF-8") 34 | return in_str 35 | 36 | 37 | class MovingAvg(object): 38 | def __init__(self, pool_size=100): 39 | from queue import Queue 40 | 41 | self.pool = Queue(maxsize=pool_size) 42 | self.sum = 0 43 | self.curr_pool_size = 0 44 | 45 | def set_curr_val(self, val): 46 | if not self.pool.full(): 47 | self.curr_pool_size += 1 48 | self.pool.put_nowait(val) 49 | else: 50 | last_first_val = self.pool.get_nowait() 51 | self.pool.put_nowait(val) 52 | self.sum -= last_first_val 53 | 54 | self.sum += val 55 | return self.sum / self.curr_pool_size 56 | 57 | 58 | ###### image normalization ###### 59 | def center_l(l): 60 | # normalization for l 61 | l_mc = (l - l_mean) / l_norm 62 | return l_mc 63 | 64 | 65 | # denormalization for l 66 | def uncenter_l(l): 67 | return l * l_norm + l_mean 68 | 69 | 70 | # normalization for ab 71 | def center_ab(ab): 72 | return (ab - ab_mean) / ab_norm 73 | 74 | 75 | # normalization for lab image 76 | def center_lab_img(img_lab): 77 | return ( 78 | img_lab / np.array((l_norm, ab_norm, ab_norm))[:, np.newaxis, np.newaxis] 79 | - np.array((l_mean / l_norm, ab_mean / ab_norm, ab_mean / ab_norm))[:, np.newaxis, np.newaxis] 80 | ) 81 | 82 | 83 | ###### color space transformation ###### 84 | def rgb2lab_transpose(img_rgb): 85 | return color.rgb2lab(img_rgb).transpose((2, 0, 1)) 86 | 87 | 88 | def lab2rgb(img_l, img_ab): 89 | """INPUTS 90 | img_l XxXx1 [0,100] 91 | img_ab XxXx2 [-100,100] 92 | OUTPUTS 93 | returned value is XxXx3""" 94 | pred_lab = np.concatenate((img_l, img_ab), axis=2).astype("float64") 95 | pred_rgb = color.lab2rgb(pred_lab) 96 | pred_rgb = (np.clip(pred_rgb, 0, 1) * 255).astype("uint8") 97 | return pred_rgb 98 | 99 | 100 | def gray2rgb_batch(l): 101 | # gray image tensor to rgb image tensor 102 | l_uncenter = uncenter_l(l) 103 | l_uncenter = l_uncenter / (2 * l_mean) 104 | return torch.cat((l_uncenter, l_uncenter, l_uncenter), dim=1) 105 | 106 | 107 | def lab2rgb_transpose(img_l, img_ab): 108 | """INPUTS 109 | img_l 1xXxX [0,100] 110 | img_ab 2xXxX [-100,100] 111 | OUTPUTS 112 | returned value is XxXx3""" 113 | pred_lab = np.concatenate((img_l, img_ab), axis=0).transpose((1, 2, 0)) 114 | return (np.clip(color.lab2rgb(pred_lab), 0, 1) * 255).astype("uint8") 115 | 116 | 117 | def lab2rgb_transpose_mc(img_l_mc, img_ab_mc): 118 | if isinstance(img_l_mc, Variable): 119 | img_l_mc = img_l_mc.data.cpu() 120 | if isinstance(img_ab_mc, Variable): 121 | img_ab_mc = img_ab_mc.data.cpu() 122 | 123 | if img_l_mc.is_cuda: 124 | img_l_mc = img_l_mc.cpu() 125 | if img_ab_mc.is_cuda: 126 | img_ab_mc = img_ab_mc.cpu() 127 | 128 | assert img_l_mc.dim() == 3 and img_ab_mc.dim() == 3, "only for batch input" 129 | 130 | img_l = img_l_mc * l_norm + l_mean 131 | img_ab = img_ab_mc * ab_norm + ab_mean 132 | pred_lab = torch.cat((img_l, img_ab), dim=0) 133 | grid_lab = pred_lab.numpy().astype("float64") 134 | return (np.clip(color.lab2rgb(grid_lab.transpose((1, 2, 0))), 0, 1) * 255).astype("uint8") 135 | 136 | 137 | def batch_lab2rgb_transpose_mc(img_l_mc, img_ab_mc, nrow=8): 138 | if isinstance(img_l_mc, Variable): 139 | img_l_mc = img_l_mc.data.cpu() 140 | if isinstance(img_ab_mc, Variable): 141 | img_ab_mc = img_ab_mc.data.cpu() 142 | 143 | if img_l_mc.is_cuda: 144 | img_l_mc = img_l_mc.cpu() 145 | if img_ab_mc.is_cuda: 146 | img_ab_mc = 
img_ab_mc.cpu() 147 | 148 | assert img_l_mc.dim() == 4 and img_ab_mc.dim() == 4, "only for batch input" 149 | 150 | img_l = img_l_mc * l_norm + l_mean 151 | img_ab = img_ab_mc * ab_norm + ab_mean 152 | pred_lab = torch.cat((img_l, img_ab), dim=1) 153 | grid_lab = vutils.make_grid(pred_lab, nrow=nrow).numpy().astype("float64") 154 | return (np.clip(color.lab2rgb(grid_lab.transpose((1, 2, 0))), 0, 1) * 255).astype("uint8") 155 | 156 | 157 | ###### loss functions ###### 158 | def feature_normalize(feature_in): 159 | feature_in_norm = torch.norm(feature_in, 2, 1, keepdim=True) + sys.float_info.epsilon 160 | feature_in_norm = torch.div(feature_in, feature_in_norm) 161 | return feature_in_norm 162 | 163 | 164 | def statistics_matching(feature1, feature2): 165 | N, C, H, W = feature1.shape 166 | feature1 = feature1.view(N, C, -1) 167 | feature2 = feature2.view(N, C, -1) 168 | 169 | mean1 = feature1.mean(dim=-1) 170 | mean2 = feature2.mean(dim=-1) 171 | std1 = feature1.var(dim=-1).sqrt() 172 | std2 = feature2.var(dim=-1).sqrt() 173 | 174 | return mse_loss(mean1, mean2) + mse_loss(std1, std2) 175 | 176 | 177 | def calc_ab_gradient(input_ab): 178 | x_grad = input_ab[:, :, :, 1:] - input_ab[:, :, :, :-1] 179 | y_grad = input_ab[:, :, 1:, :] - input_ab[:, :, :-1, :] 180 | return x_grad, y_grad 181 | 182 | 183 | def calc_tv_loss(input): 184 | x_grad = input[:, :, :, 1:] - input[:, :, :, :-1] 185 | y_grad = input[:, :, 1:, :] - input[:, :, :-1, :] 186 | return torch.sum(x_grad ** 2) / x_grad.nelement() + torch.sum(y_grad ** 2) / y_grad.nelement() 187 | 188 | 189 | def calc_cosine_dist_loss(input, target): 190 | input_norm = torch.norm(input, 2, 1, keepdim=True) + sys.float_info.epsilon 191 | target_norm = torch.norm(target, 2, 1, keepdim=True) + sys.float_info.epsilon 192 | normalized_input = torch.div(input, input_norm) 193 | normalized_target = torch.div(target, target_norm) 194 | cos_dist = torch.mul(normalized_input, normalized_target) 195 | return torch.mean(1 - torch.sum(cos_dist, dim=1)) 196 | 197 | 198 | ###### video related ####### 199 | def save_frames(image, image_folder, index=None): 200 | if image is not None: 201 | image = np.clip(image, 0, 255).astype(np.uint8) 202 | io.imsave(os.path.join(image_folder, "frame" + str(index).zfill(5) + ".jpg"), image) 203 | 204 | 205 | 206 | ###### file system ###### 207 | def get_size(start_path="."): 208 | total_size = 0 209 | for dirpath, dirnames, filenames in os.walk(start_path): 210 | for f in filenames: 211 | fp = os.path.join(dirpath, f) 212 | total_size += os.path.getsize(fp) 213 | return total_size 214 | 215 | 216 | def parse(parser, save=True): 217 | opt = parser.parse_args(args=[]) 218 | args = vars(opt) 219 | 220 | from time import gmtime, strftime 221 | 222 | print("------------ Options -------------") 223 | for k, v in sorted(args.items()): 224 | print("%s: %s" % (str(k), str(v))) 225 | print("-------------- End ----------------") 226 | 227 | # save to the disk 228 | if save: 229 | file_name = os.path.join("opt.txt") 230 | with open(file_name, "wt") as opt_file: 231 | opt_file.write(os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n") 232 | opt_file.write("------------ Options -------------\n") 233 | for k, v in sorted(args.items()): 234 | opt_file.write("%s: %s\n" % (str(k), str(v))) 235 | opt_file.write("-------------- End ----------------\n") 236 | return opt 237 | 238 | 239 | ###### interactive ###### 240 | def clean_tensorboard(directory): 241 | folder_list = os.walk(directory).__next__()[1] 242 | for 
folder in folder_list: 243 | folder = directory + folder 244 | if get_size(folder) < 10000000: 245 | print("delete the folder of " + folder) 246 | shutil.rmtree(folder) 247 | 248 | 249 | def imshow(input_image, title=None, type_conversion=False): 250 | inp = input_image 251 | if type_conversion or type(input_image) is torch.Tensor: 252 | inp = input_image.numpy() 253 | else: 254 | inp = input_image 255 | fig = plt.figure() 256 | if inp.ndim == 2: 257 | fig = plt.imshow(inp, cmap="gray", clim=[0, 255]) 258 | else: 259 | fig = plt.imshow(np.transpose(inp, [1, 2, 0]).astype(np.uint8)) 260 | plt.axis("off") 261 | fig.axes.get_xaxis().set_visible(False) 262 | fig.axes.get_yaxis().set_visible(False) 263 | plt.title(title) 264 | 265 | 266 | def imshow_lab(input_lab): 267 | plt.imshow((batch_lab2rgb_transpose_mc(input_lab[:32, 0:1, :, :], input_lab[:32, 1:3, :, :])).astype(np.uint8)) 268 | 269 | 270 | ###### vgg preprocessing ###### 271 | def vgg_preprocess(tensor): 272 | # input is RGB tensor which ranges in [0,1] 273 | # output is BGR tensor which ranges in [0,255] 274 | tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1) 275 | tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(1, 3, 1, 1) 276 | return tensor_bgr_ml * 255 277 | 278 | 279 | def torch_vgg_preprocess(tensor): 280 | # pytorch version normalization 281 | # note that both input and output are RGB tensors; 282 | # input and output ranges in [0,1] 283 | # normalize the tensor with mean and variance 284 | tensor_mc = tensor - torch.Tensor([0.485, 0.456, 0.406]).type_as(tensor).view(1, 3, 1, 1) 285 | return tensor_mc / torch.Tensor([0.229, 0.224, 0.225]).type_as(tensor_mc).view(1, 3, 1, 1) 286 | 287 | 288 | def network_gradient(net, gradient_on=True): 289 | for param in net.parameters(): 290 | param.requires_grad = bool(gradient_on) 291 | return net 292 | 293 | 294 | ##### color space 295 | xyz_from_rgb = np.array( 296 | [[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169], [0.019334, 0.119193, 0.950227]] 297 | ) 298 | rgb_from_xyz = np.array( 299 | [[3.24048134, -0.96925495, 0.05564664], [-1.53715152, 1.87599, -0.20404134], [-0.49853633, 0.04155593, 1.05731107]] 300 | ) 301 | 302 | 303 | def tensor_lab2rgb(input): 304 | """ 305 | n * 3* h *w 306 | """ 307 | input_trans = input.transpose(1, 2).transpose(2, 3) # n * h * w * 3 308 | L, a, b = input_trans[:, :, :, 0:1], input_trans[:, :, :, 1:2], input_trans[:, :, :, 2:] 309 | y = (L + 16.0) / 116.0 310 | x = (a / 500.0) + y 311 | z = y - (b / 200.0) 312 | 313 | neg_mask = z.data < 0 314 | z[neg_mask] = 0 315 | xyz = torch.cat((x, y, z), dim=3) 316 | 317 | mask = xyz.data > 0.2068966 318 | mask_xyz = xyz.clone() 319 | mask_xyz[mask] = torch.pow(xyz[mask], 3.0) 320 | mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.0) / 7.787 321 | mask_xyz[:, :, :, 0] = mask_xyz[:, :, :, 0] * 0.95047 322 | mask_xyz[:, :, :, 2] = mask_xyz[:, :, :, 2] * 1.08883 323 | 324 | rgb_trans = torch.mm(mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)).view( 325 | input.size(0), input.size(2), input.size(3), 3 326 | ) 327 | rgb = rgb_trans.transpose(2, 3).transpose(1, 2) 328 | 329 | mask = rgb > 0.0031308 330 | mask_rgb = rgb.clone() 331 | mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055 332 | mask_rgb[~mask] = rgb[~mask] * 12.92 333 | 334 | neg_mask = mask_rgb.data < 0 335 | large_mask = mask_rgb.data > 1 336 | mask_rgb[neg_mask] = 0 337 | mask_rgb[large_mask] = 1 338 | return mask_rgb 
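
# Minimal smoke-test sketch (not part of the original module); it relies only on the
# helpers defined above and illustrates the expected layout for tensor_lab2rgb:
# a batched n x 3 x h x w Lab tensor with L in [0, 100] and ab roughly in [-100, 100].
if __name__ == "__main__":
    lab = torch.zeros(1, 3, 8, 8)
    lab[:, 0] = 50.0  # mid-gray luminance, zero chroma
    rgb = tensor_lab2rgb(lab)
    print(rgb.shape, float(rgb.min()), float(rgb.max()))  # torch.Size([1, 3, 8, 8]), values clamped to [0, 1]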
339 | -------------------------------------------------------------------------------- /models/DVP/README.md: -------------------------------------------------------------------------------- 1 | # pytorch-deep-video-prior (DVP) 2 | Official PyTorch implementation for NeurIPS 2020 paper: Blind Video Temporal Consistency via Deep Video Prior 3 | 4 | [TensorFlow implementation](https://github.com/ChenyangLEI/deep-video-prior) 5 | | [paper](https://arxiv.org/abs/2010.11838) 6 | | [project website](https://chenyanglei.github.io/DVP/index.html) 7 | 8 | 9 | ## Introduction 10 | Our method is a general framework to improve the temporal consistency of video processed by image algorithms. 11 | 12 | For example, combining single image colorization or single image dehazing algorithm with our framework, we can achieve the goal of video colorization or video dehazing. 13 | 14 | 15 | 16 | 17 | 18 | 19 | ## Dependency 20 | 21 | ### Environment 22 | This code is based on PyTorch. It has been tested on Ubuntu 18.04 LTS. 23 | 24 | Anaconda is recommended: [Ubuntu 18.04](https://www.digitalocean.com/community/tutorials/how-to-install-the-anaconda-python-distribution-on-ubuntu-18-04) 25 | | [Ubuntu 16.04](https://www.digitalocean.com/community/tutorials/how-to-install-the-anaconda-python-distribution-on-ubuntu-16-04) 26 | 27 | After installing Anaconda, you can setup the environment simply by 28 | 29 | ``` 30 | conda env create -f environment.yml 31 | ``` 32 | 33 | 34 | ## Inference 35 | 36 | ### Demo 37 | ``` 38 | bash test.sh 39 | ``` 40 | The results will be saved in ./result 41 | 42 | ### Use your own data 43 | For the video with unimodal inconsistency: 44 | 45 | ``` 46 | python main_IRT.py --max_epoch 25 --input PATH_TO_YOUR_INPUT_FOLDER --processed PATH_TO_YOUR_PROCESSED_FOLDER --model NAME_OF_YOUR_MODEL --with_IRT 0 --IRT_initialization 0 --output ./result/OWN_DATA 47 | ``` 48 | 49 | For the video with multimodal inconsistency: 50 | 51 | ``` 52 | python main_IRT.py --max_epoch 25 --input PATH_TO_YOUR_INPUT_FOLDER --processed PATH_TO_YOUR_PROCESSED_FOLDER --model NAME_OF_YOUR_MODEL --with_IRT 1 --IRT_initialization 1 --output ./result/OWN_DATA 53 | ``` 54 | 55 | 56 | ## Citation 57 | If you find this work useful for your research, please cite: 58 | ``` 59 | @inproceedings{lei2020dvp, 60 | title={Blind Video Temporal Consistency via Deep Video Prior}, 61 | author={Lei, Chenyang and Xing, Yazhou and Chen, Qifeng}, 62 | booktitle={Advances in Neural Information Processing Systems}, 63 | year={2020} 64 | } 65 | ``` 66 | 67 | 68 | ## Contact 69 | Feel free to contact me if there is any question. 
(Yazhou Xing, yzxing87@gmail.com) 70 | -------------------------------------------------------------------------------- /models/DVP/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import cv2 4 | import torch 5 | from PIL import Image 6 | from tqdm import tqdm 7 | import numpy as np 8 | import torchvision.transforms as transform_lib 9 | import matplotlib.pyplot as plt 10 | 11 | from utils.util import download_zipfile, mkdir 12 | from utils.v2i import convert_frames_to_video 13 | 14 | class OPT(): 15 | pass 16 | 17 | class DVP(): 18 | def __init__(self): 19 | self.small = (320, 180) 20 | self.in_size = (0, 0) 21 | 22 | def test(self, black_white_path, colorized_path, output_path, opt=None): 23 | assert os.path.exists(black_white_path) and os.path.exists(colorized_path) 24 | 25 | self.downscale(black_white_path) 26 | self.downscale(colorized_path) 27 | 28 | os.system(f'python3 ./models/DVP/main_IRT.py --save_freq {opt.sf} --max_epoch {opt.me} --input {black_white_path} --processed {colorized_path} --model temp --with_IRT 1 --IRT_initialization 1 --output {opt.op}') 29 | 30 | frames_path = f"{opt.op}/temp_IRT1_initial1/{os.path.basename(black_white_path)}/00{opt.me}" 31 | 32 | self.upscale(frames_path) 33 | 34 | length = len(os.listdir(frames_path)) 35 | frames = [f"out_main_{str(i).zfill(5)}.jpg" for i in range(length)] 36 | convert_frames_to_video(frames_path, output_path, frames) 37 | 38 | def downscale(self, path): 39 | frames = os.listdir(path) 40 | 41 | frame = Image.open(os.path.join(path, frames[0])) 42 | self.in_size = frame.size 43 | 44 | for each in frames: 45 | img = Image.open(os.path.join(path, each)) 46 | img = img.resize(self.small, Image.ANTIALIAS) 47 | img.save(os.path.join(path, each)) 48 | 49 | def upscale(self, path): 50 | frames = os.listdir(path) 51 | for each in frames: 52 | img = Image.open(os.path.join(path, each)) 53 | img = img.resize(self.in_size, Image.ANTIALIAS) 54 | img.save(os.path.join(path, each)) -------------------------------------------------------------------------------- /models/DVP/environment.yml: -------------------------------------------------------------------------------- 1 | name: pytorch-DVP 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - blas=1.0=mkl 7 | - ca-certificates=2020.10.14=0 8 | - certifi=2020.11.8=py36h06a4308_0 9 | - cffi=1.14.4=py36h261ae71_0 10 | - cudatoolkit=10.0.130=0 11 | - cudnn=7.6.5=cuda10.0_0 12 | - freetype=2.10.4=h5ab3b9f_0 13 | - intel-openmp=2020.2=254 14 | - jpeg=9b=h024ee3a_2 15 | - lcms2=2.11=h396b838_0 16 | - ld_impl_linux-64=2.33.1=h53a641e_7 17 | - libedit=3.1.20191231=h14c3975_1 18 | - libffi=3.3=he6710b0_2 19 | - libgcc-ng=9.1.0=hdf63c60_0 20 | - libpng=1.6.37=hbc83047_0 21 | - libstdcxx-ng=9.1.0=hdf63c60_0 22 | - libtiff=4.1.0=h2733197_1 23 | - lz4-c=1.9.2=heb0550a_3 24 | - mkl=2020.2=256 25 | - mkl-service=2.3.0=py36he904b0f_0 26 | - mkl_fft=1.2.0=py36h23d657b_0 27 | - mkl_random=1.1.1=py36h0573a6f_0 28 | - ncurses=6.2=he6710b0_1 29 | - ninja=1.10.2=py36hff7bd54_0 30 | - numpy=1.19.2=py36h54aff64_0 31 | - numpy-base=1.19.2=py36hfa32c7d_0 32 | - olefile=0.46=py36_0 33 | - openssl=1.1.1h=h7b6447c_0 34 | - pillow=8.0.1=py36he98fc37_0 35 | - pip=20.3=py36h06a4308_0 36 | - pycparser=2.20=py_2 37 | - python=3.6.12=hcff3b4d_2 38 | - pytorch=1.1.0=cuda100py36he554f03_0 39 | - readline=8.0=h7b6447c_0 40 | - setuptools=50.3.2=py36h06a4308_2 41 | - six=1.15.0=py36h06a4308_0 42 | - sqlite=3.33.0=h62c20be_0 
43 | - tk=8.6.10=hbc83047_0 44 | - torchvision=0.3.0=cuda100py36h72fc40a_0 45 | - wheel=0.36.0=pyhd3eb1b0_0 46 | - xz=5.2.5=h7b6447c_0 47 | - zlib=1.2.11=h7b6447c_3 48 | - zstd=1.4.5=h9ceee32_0 49 | - pip: 50 | - scipy==1.2.0 51 | prefix: /home/yazhou/anaconda3/envs/pytorch-DVP 52 | 53 | -------------------------------------------------------------------------------- /models/DVP/main_IRT.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | import os 5 | import time 6 | import scipy.io 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | from glob import glob 11 | import scipy.misc as sic 12 | import subprocess 13 | import models.network as net 14 | import argparse 15 | import random 16 | import imageio 17 | from vgg import VGG19 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--model", default='Test', type=str, help="Name of model") 21 | parser.add_argument("--save_freq", default=5, type=int, help="save frequency of epochs") 22 | parser.add_argument("--use_gpu", default=1, type=int, help="use gpu or not") 23 | parser.add_argument("--with_IRT", default=0, type=int, help="use IRT or not") 24 | parser.add_argument("--IRT_initialization", default=0, type=int, help="use initialization for IRT or not") 25 | parser.add_argument("--max_epoch", default=25, type=int, help="The max number of epochs for training") 26 | parser.add_argument("--input", default='./demo/colorization/goat_input', type=str, help="dir of input video") 27 | parser.add_argument("--processed", default='./demo/colorization/goat_processed', type=str, help="dir of processed video") 28 | parser.add_argument("--output", default='None', type=str, help="dir of output video") 29 | 30 | # set random seed 31 | seed = 2020 32 | np.random.seed(seed) 33 | torch.manual_seed(seed) 34 | random.seed(seed) 35 | 36 | # process arguments 37 | ARGS = parser.parse_args() 38 | print(ARGS) 39 | save_freq = ARGS.save_freq 40 | input_folder = ARGS.input 41 | processed_folder = ARGS.processed 42 | with_IRT = ARGS.with_IRT 43 | maxepoch = ARGS.max_epoch + 1 44 | model= ARGS.model 45 | task = "/{}_IRT{}_initial{}".format(model, with_IRT, ARGS.IRT_initialization) #Colorization, HDR, StyleTransfer, Dehazing 46 | 47 | # set gpu 48 | if ARGS.use_gpu: 49 | os.environ["CUDA_VISIBLE_DEVICES"]=str(np.argmax([int(x.split()[2]) 50 | for x in subprocess.Popen("nvidia-smi -q -d Memory | grep -A4 GPU | grep Free", shell=True, stdout=subprocess.PIPE).stdout.readlines()])) 51 | else: 52 | os.environ["CUDA_VISIBLE_DEVICES"] = '' 53 | 54 | device = torch.device("cuda:0" if ARGS.use_gpu else "cpu") 55 | 56 | # clear cuda cache 57 | torch.cuda.empty_cache() 58 | 59 | # define loss function 60 | def compute_error(real,fake): 61 | # return tf.reduce_mean(tf.abs(fake-real)) 62 | return torch.mean(torch.abs(fake-real)) 63 | 64 | def Lp_loss(x, y): 65 | vgg_real = VGG_19(normalize_batch(x)) 66 | vgg_fake = VGG_19(normalize_batch(y)) 67 | p0 = compute_error(normalize_batch(x), normalize_batch(y)) 68 | 69 | content_loss_list = [] 70 | content_loss_list.append(p0) 71 | feat_layers = {'conv1_2' : 1.0/2.6, 'conv2_2' : 1.0/4.8, 'conv3_2': 1.0/3.7, 'conv4_2':1.0/5.6, 'conv5_2':10.0/1.5} 72 | 73 | for layer, w in feat_layers.items(): 74 | pi = compute_error(vgg_real[layer], vgg_fake[layer]) 75 | content_loss_list.append(w * pi) 76 | 77 | content_loss = torch.sum(torch.stack(content_loss_list)) 
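    # content_loss sums the pixel-space L1 term p0 (on ImageNet-normalized images) with the
    # per-layer VGG feature L1 terms; the conv5_2 term (weight 10.0/1.5) dominates the total.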
78 | 79 | return content_loss 80 | 81 | loss_L2 = torch.nn.MSELoss() 82 | loss_L1 = torch.nn.L1Loss() 83 | 84 | 85 | # Define model . 86 | out_channels = 6 if with_IRT else 3 87 | net = net.UNet(in_channels=3, out_channels=out_channels, init_features=32) 88 | net.to(device) 89 | optimizer = torch.optim.Adam(net.parameters(), lr=0.0001) 90 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3000,8000], gamma=0.5) 91 | 92 | VGG_19 = VGG19(requires_grad=False).to(device) 93 | 94 | # prepare data 95 | input_folders = [input_folder] 96 | processed_folders = [processed_folder] 97 | 98 | 99 | def prepare_paired_input(task, id, input_names, processed_names, is_train=0): 100 | net_in = np.float32(imageio.imread(input_names[id]))/255.0 101 | if len(net_in.shape) == 2: 102 | net_in = np.tile(net_in[:,:,np.newaxis], [1,1,3]) 103 | net_gt = np.float32(imageio.imread(processed_names[id]))/255.0 104 | org_h,org_w = net_in.shape[:2] 105 | h = org_h // 32 * 32 106 | w = org_w // 32 * 32 107 | print(net_in.shape, net_gt.shape) 108 | return net_in[np.newaxis, :h, :w, :], net_gt[np.newaxis, :h, :w, :] 109 | 110 | # some functions 111 | def initialize_weights(model): 112 | for module in model.modules(): 113 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear): 114 | # nn.init.kaiming_normal_(module.weight) 115 | nn.init.xavier_normal_(module.weight) 116 | if module.bias is not None: 117 | module.bias.data.zero_() 118 | elif isinstance(module, nn.BatchNorm2d): 119 | module.weight.data.fill_(1) 120 | module.bias.data.zero_() 121 | 122 | def normalize_batch(batch): 123 | # Normalize batch using ImageNet mean and std 124 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1) 125 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1) 126 | return (batch - mean) / std 127 | 128 | 129 | 130 | # start to train 131 | for folder_idx, input_folder in enumerate(input_folders): 132 | # -----------load data------------- 133 | input_names = sorted(glob(input_folders[folder_idx] + "/*")) 134 | processed_names = sorted(glob(processed_folders[folder_idx] + "/*")) 135 | if ARGS.output == "None": 136 | output_folder = "./result/{}".format(task + '/' + input_folder.split("/")[-2] + '/' + input_folder.split("/")[-1]) 137 | else: 138 | output_folder = ARGS.output + "/" + task + '/' + input_folder.split("/")[-1] 139 | print(output_folder, input_folders[folder_idx], processed_folders[folder_idx] ) 140 | 141 | num_of_sample = min(len(input_names), len(processed_names)) 142 | data_in_memory = [None] * num_of_sample #Speedup 143 | for id in range(min(len(input_names), len(processed_names))): #Speedup 144 | net_in,net_gt = prepare_paired_input(task, id, input_names, processed_names) #Speedup 145 | net_in = torch.from_numpy(net_in).permute(0,3,1,2).float().to(device) 146 | net_gt = torch.from_numpy(net_gt).permute(0,3,1,2).float().to(device) 147 | data_in_memory[id] = [net_in,net_gt] #Speedup 148 | 149 | # model re-initialization 150 | initialize_weights(net) 151 | 152 | step = 0 153 | for epoch in range(1,maxepoch): 154 | # -----------start to train------------- 155 | print("Processing epoch {}".format(epoch)) 156 | frame_id = 0 157 | if os.path.isdir("{}/{:04d}".format(output_folder, epoch)): 158 | continue 159 | else: 160 | os.makedirs("{}/{:04d}".format(output_folder, epoch)) 161 | if not os.path.isdir("{}/training".format(output_folder)): 162 | os.makedirs("{}/training".format(output_folder)) 163 | 164 | print(len(input_names), len(processed_names)) 165 | for id in 
range(num_of_sample): 166 | if with_IRT: 167 | if epoch < 6 and ARGS.IRT_initialization: 168 | net_in,net_gt = data_in_memory[0] #Option: 169 | prediction = net(net_in) 170 | 171 | crt_loss = loss_L1(prediction[:,:3,:,:], net_gt) + 0.9*loss_L1(prediction[:,3:,:,:], net_gt) 172 | 173 | else: 174 | net_in,net_gt = data_in_memory[id] 175 | prediction = net(net_in) 176 | 177 | prediction_main = prediction[:,:3,:,:] 178 | prediction_minor = prediction[:,3:,:,:] 179 | diff_map_main,_ = torch.max(torch.abs(prediction_main - net_gt) / (net_in+1e-1), dim=1, keepdim=True) 180 | diff_map_minor,_ = torch.max(torch.abs(prediction_minor - net_gt) / (net_in+1e-1), dim=1, keepdim=True) 181 | confidence_map = torch.lt(diff_map_main, diff_map_minor).repeat(1,3,1,1).float() 182 | crt_loss = loss_L1(prediction_main*confidence_map, net_gt*confidence_map) \ 183 | + loss_L1(prediction_minor*(1-confidence_map), net_gt*(1-confidence_map)) 184 | else: 185 | net_in,net_gt = data_in_memory[id] 186 | prediction = net(net_in) 187 | crt_loss = Lp_loss(prediction, net_gt) 188 | 189 | optimizer.zero_grad() 190 | crt_loss.backward() 191 | optimizer.step() 192 | 193 | frame_id+=1 194 | step+=1 195 | if step % 10 == 0: 196 | print("Image iter: {} {} {} || Loss: {:.4f} ".format(epoch, frame_id, step, crt_loss)) 197 | if step % 100 == 0 : 198 | net_in = net_in.permute(0,2,3,1).cpu().numpy() 199 | net_gt = net_gt.permute(0,2,3,1).cpu().numpy() 200 | prediction = prediction.detach().permute(0,2,3,1).cpu().numpy() 201 | if with_IRT: 202 | prediction = prediction[...,:3] 203 | imageio.imsave("{}/training/step{:06d}_{:06d}.jpg".format(output_folder, step, id), 204 | np.uint8(np.concatenate([net_in[0], prediction[0], net_gt[0]], axis=1).clip(0,1) * 255.0)) 205 | 206 | # # -----------save intermidiate results------------- 207 | if epoch % save_freq == 0: 208 | for id in range(num_of_sample): 209 | st=time.time() 210 | net_in,net_gt = data_in_memory[id] 211 | print("Test: {}-{} \r".format(id, num_of_sample)) 212 | 213 | with torch.no_grad(): 214 | prediction = net(net_in) 215 | net_in = net_in.permute(0,2,3,1).cpu().numpy() 216 | net_gt = net_gt.permute(0,2,3,1).cpu().numpy() 217 | prediction = prediction.detach().permute(0,2,3,1).cpu().numpy() 218 | 219 | if with_IRT: 220 | prediction_main = prediction[...,:3] 221 | prediction_minor = prediction[...,3:] 222 | diff_map_main = np.amax(np.absolute(prediction_main - net_gt) / (net_in+1e-1), axis=3, keepdims=True) 223 | diff_map_minor = np.amax(np.absolute(prediction_minor - net_gt) / (net_in+1e-1), axis=3, keepdims=True) 224 | confidence_map = np.tile(np.less(diff_map_main, diff_map_minor), (1,1,1,3)).astype('float32') 225 | 226 | imageio.imsave("{}/{:04d}/predictions_{:05d}.jpg".format(output_folder, epoch, id), 227 | np.uint8(np.concatenate([net_in[0,:,:,:3],prediction_main[0], prediction_minor[0],net_gt[0], confidence_map[0]], axis=1).clip(0,1) * 255.0)) 228 | imageio.imsave("{}/{:04d}/out_main_{:05d}.jpg".format(output_folder, epoch, id),np.uint8(prediction_main[0].clip(0,1) * 255.0)) 229 | imageio.imsave("{}/{:04d}/out_minor_{:05d}.jpg".format(output_folder, epoch, id),np.uint8(prediction_minor[0].clip(0,1) * 255.0)) 230 | 231 | else: 232 | 233 | imageio.imsave("{}/{:04d}/predictions_{:05d}.jpg".format(output_folder, epoch, id), 234 | np.uint8(np.concatenate([net_in[0,:,:,:3], prediction[0], net_gt[0]],axis=1).clip(0,1) * 255.0)) 235 | imageio.imsave("{}/{:04d}/out_main_{:05d}.jpg".format(output_folder, epoch, id), 236 | np.uint8(prediction[0].clip(0,1) * 255.0)) 237 | 238 | 
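
# Example invocation, mirroring test.sh (folders are placeholders for your own frames):
#   python main_IRT.py --max_epoch 25 --input demo/colorization/goat_input \
#       --processed demo/colorization/goat_processed --model colorization \
#       --with_IRT 1 --IRT_initialization 1 --output ./result/colorization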
-------------------------------------------------------------------------------- /models/DVP/models/__pycache__/network.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wensi-ai/Colorizer/ef382f44885f9d083c02a83952db9b4a6b4149c1/models/DVP/models/__pycache__/network.cpython-36.pyc -------------------------------------------------------------------------------- /models/DVP/models/__pycache__/network.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wensi-ai/Colorizer/ef382f44885f9d083c02a83952db9b4a6b4149c1/models/DVP/models/__pycache__/network.cpython-39.pyc -------------------------------------------------------------------------------- /models/DVP/models/network.py: -------------------------------------------------------------------------------- 1 | 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class UNet(nn.Module): 9 | 10 | def __init__(self, in_channels=3, out_channels=1, init_features=32): 11 | super(UNet, self).__init__() 12 | 13 | features = init_features 14 | self.encoder1 = UNet._block(in_channels, features, name="enc1") 15 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 16 | self.encoder2 = UNet._block(features, features * 2, name="enc2") 17 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 18 | self.encoder3 = UNet._block(features * 2, features * 4, name="enc3") 19 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 20 | self.encoder4 = UNet._block(features * 4, features * 8, name="enc4") 21 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 22 | 23 | self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck") 24 | 25 | self.upconv4 = nn.Sequential( 26 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 27 | nn.Conv2d(features * 16, features * 8, kernel_size=3, padding=1) 28 | ) 29 | self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4") 30 | self.upconv3 = nn.Sequential( 31 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 32 | nn.Conv2d(features * 8, features * 4, kernel_size=3, padding=1) 33 | ) 34 | 35 | self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3") 36 | self.upconv2 = nn.Sequential( 37 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 38 | nn.Conv2d(features * 4, features * 2, kernel_size=3, padding=1) 39 | ) 40 | 41 | self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2") 42 | self.upconv1 = nn.Sequential( 43 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True), 44 | nn.Conv2d(features * 2, features, kernel_size=3, padding=1) 45 | ) 46 | 47 | self.decoder1 = UNet._block(features * 2, features, name="dec1") 48 | 49 | self.conv = nn.Conv2d( 50 | in_channels=features, out_channels=out_channels, kernel_size=1 51 | ) 52 | 53 | def forward(self, x): 54 | enc1 = self.encoder1(x) 55 | enc2 = self.encoder2(self.pool1(enc1)) 56 | enc3 = self.encoder3(self.pool2(enc2)) 57 | enc4 = self.encoder4(self.pool3(enc3)) 58 | 59 | bottleneck = self.bottleneck(self.pool4(enc4)) 60 | 61 | dec4 = self.upconv4(bottleneck) 62 | dec4 = torch.cat((dec4, enc4), dim=1) 63 | dec4 = self.decoder4(dec4) 64 | dec3 = self.upconv3(dec4) 65 | dec3 = torch.cat((dec3, enc3), dim=1) 66 | dec3 = self.decoder3(dec3) 67 | dec2 = self.upconv2(dec3) 68 | dec2 = torch.cat((dec2, enc2), dim=1) 69 | dec2 = self.decoder2(dec2) 70 | dec1 = 
self.upconv1(dec2) 71 | dec1 = torch.cat((dec1, enc1), dim=1) 72 | dec1 = self.decoder1(dec1) 73 | out = self.conv(dec1) 74 | return out 75 | 76 | @staticmethod 77 | def _block(in_channels, features, name): 78 | return nn.Sequential( 79 | OrderedDict( 80 | [ 81 | ( 82 | name + "conv1", 83 | nn.Conv2d( 84 | in_channels=in_channels, 85 | out_channels=features, 86 | kernel_size=3, 87 | padding=1, 88 | bias=False, 89 | ), 90 | ), 91 | # (name + "norm1", nn.BatchNorm2d(num_features=features)), 92 | (name + "relu1", nn.ReLU(inplace=True)), 93 | ( 94 | name + "conv2", 95 | nn.Conv2d( 96 | in_channels=features, 97 | out_channels=features, 98 | kernel_size=3, 99 | padding=1, 100 | bias=False, 101 | ), 102 | ), 103 | # (name + "norm2", nn.BatchNorm2d(num_features=features)), 104 | (name + "relu2", nn.ReLU(inplace=True)), 105 | ] 106 | ) 107 | ) 108 | 109 | 110 | 111 | -------------------------------------------------------------------------------- /models/DVP/options/test_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch.backends.cudnn as cudnn 3 | class TestOptions(): 4 | def __init__(self): 5 | self.sf = 10 # save frequency 6 | self.op = "test/results" 7 | self.me = 30 # max epoch 8 | 9 | def parse(self): 10 | return self -------------------------------------------------------------------------------- /models/DVP/test.sh: -------------------------------------------------------------------------------- 1 | python main_IRT.py --max_epoch 25 --input demo/colorization/goat_input --processed demo/colorization/goat_processed --model colorization --with_IRT 1 --IRT_initialization 1 --output ./result/colorization -------------------------------------------------------------------------------- /models/DVP/vgg.py: -------------------------------------------------------------------------------- 1 | from torchvision import models 2 | import torch.nn as nn 3 | import torch 4 | 5 | 6 | class VGG19(torch.nn.Module): 7 | """docstring for Vgg19""" 8 | def __init__(self, requires_grad=False): 9 | super(VGG19, self).__init__() 10 | vgg_pretrained_features = models.vgg19(pretrained=True).features 11 | self.slice1 = torch.nn.Sequential() 12 | self.slice2 = torch.nn.Sequential() 13 | self.slice3 = torch.nn.Sequential() 14 | self.slice4 = torch.nn.Sequential() 15 | self.slice5 = torch.nn.Sequential() 16 | for x in range(4): # conv1_2 17 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 18 | for x in range(4, 9): # conv2_2 19 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 20 | for x in range(9, 14): # conv3_2 21 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 22 | for x in range(14, 23): # conv4_2 23 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(23, 32): # conv5_2 25 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 26 | 27 | if not requires_grad: 28 | for param in self.parameters(): 29 | param.requires_grad = False 30 | 31 | def forward(self, X): 32 | h = self.slice1(X) 33 | h_relu1_2 = h 34 | h = self.slice2(h) 35 | h_relu2_2 = h 36 | h = self.slice3(h) 37 | h_relu3_2 = h 38 | h = self.slice4(h) 39 | h_relu4_2 = h 40 | h = self.slice5(h) 41 | h_relu5_2 = h 42 | 43 | out = {} 44 | out['conv1_2'] = h_relu1_2 45 | out['conv2_2'] = h_relu2_2 46 | out['conv3_2'] = h_relu3_2 47 | out['conv4_2'] = h_relu4_2 48 | out['conv5_2'] = h_relu5_2 49 | 50 | return out 51 | 52 | 53 | ''' 54 | Sequential( 55 | (0): Sequential( 56 | (0): Conv2d(3, 64, kernel_size=(3, 3), 
stride=(1, 1), padding=(1, 1)) 57 | (1): ReLU(inplace) 58 | (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 59 | (3): ReLU(inplace) ##con1_2 60 | (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 61 | (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 62 | (6): ReLU(inplace) 63 | (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 64 | (8): ReLU(inplace) ##con2_2 65 | (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 66 | (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 67 | (11): ReLU(inplace) 68 | (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 69 | (13): ReLU(inplace) ##con3_2 70 | (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 71 | (15): ReLU(inplace) 72 | (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 73 | (17): ReLU(inplace) 74 | (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 75 | (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 76 | (20): ReLU(inplace) 77 | (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 78 | (22): ReLU(inplace) ##conv4_2 79 | (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 80 | (24): ReLU(inplace) 81 | (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 82 | (26): ReLU(inplace) 83 | (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 84 | (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 85 | (29): ReLU(inplace) 86 | (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 87 | (31): ReLU(inplace) ##conv5_2 88 | (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 89 | (33): ReLU(inplace) 90 | (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 91 | (35): ReLU(inplace) 92 | (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 93 | ) 94 | (1): AdaptiveAvgPool2d(output_size=(7, 7)) 95 | (2): Sequential( 96 | (0): Linear(in_features=25088, out_features=4096, bias=True) 97 | (1): ReLU(inplace) 98 | (2): Dropout(p=0.5) 99 | (3): Linear(in_features=4096, out_features=4096, bias=True) 100 | (4): ReLU(inplace) 101 | (5): Dropout(p=0.5) 102 | (6): Linear(in_features=4096, out_features=1000, bias=True) 103 | ) 104 | ) 105 | ''' 106 | -------------------------------------------------------------------------------- /models/InstaColor/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import importlib 4 | import numpy as np 5 | import torch 6 | import cv2 7 | from typing import List 8 | from tqdm import tqdm 9 | from utils.util import download_zipfile, mkdir 10 | from utils.v2i import convert_frames_to_video 11 | 12 | from models.InstaColor.models.base_model import BaseModel 13 | from models.InstaColor.utils import util 14 | from models.InstaColor.utils.datasets import * 15 | from utils.util import download_zipfile 16 | 17 | from detectron2 import model_zoo 18 | from detectron2.engine import DefaultPredictor 19 | from detectron2.config import get_cfg 20 | 21 | class InstaColor: 22 | def __init__(self, pretrained=True): 23 | # Bounding box predictor 24 | cfg = get_cfg() 25 | cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")) 26 | 
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7 27 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml") 28 | self.predictor = DefaultPredictor(cfg) 29 | self.model = None 30 | 31 | # download pretrained model 32 | if pretrained is True: 33 | download_zipfile("https://docs.google.com/uc?export=download&id=1Xb-DKAA9ibCVLqm8teKd1MWk6imjwTBh&confirm=t", "InstaColor_checkpoints.zip") 34 | 35 | def find_model_using_name(self, model_name): 36 | # Given the option --model [modelname], 37 | # the file "models/modelname_model.py" 38 | # will be imported. 39 | model_filename = "models.InstaColor.models." + model_name + "_model" 40 | modellib = importlib.import_module(model_filename) 41 | 42 | # In the file, the class called ModelNameModel() will 43 | # be instantiated. It has to be a subclass of BaseModel, 44 | # and it is case-insensitive. 45 | model = None 46 | target_model_name = model_name.replace('_', '') + 'model' 47 | for name, cls in modellib.__dict__.items(): 48 | if name.lower() == target_model_name.lower() \ 49 | and issubclass(cls, BaseModel): 50 | model = cls 51 | 52 | if model is None: 53 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 54 | exit(0) 55 | 56 | return model 57 | 58 | 59 | def create_model(self, opt): 60 | model = self.find_model_using_name(opt.model) 61 | self.model = model() 62 | self.model.initialize(opt) 63 | print("model [%s] was created" % (self.model.name())) 64 | 65 | def test_images(self, input_dir: str, file_names: List[str], opt): 66 | """ 67 | Testing function 68 | """ 69 | # get bounding box for each image 70 | print("Getting bounding boxes...") 71 | with torch.no_grad(): 72 | for image_name in tqdm(file_names): 73 | img = cv2.imread(f"{input_dir}/{image_name}") 74 | lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB) 75 | l_channel, a_channel, b_channel = cv2.split(lab_image) 76 | l_stack = np.stack([l_channel, l_channel, l_channel], axis=2) 77 | outputs = self.predictor(l_stack) 78 | pred_bbox = outputs["instances"].pred_boxes.to(torch.device('cpu')).tensor.numpy() 79 | pred_scores = outputs["instances"].scores.cpu().data.numpy() 80 | np.savez(f"{opt.output_npz_dir}/{image_name.split('.')[0]}", bbox = pred_bbox, scores = pred_scores) 81 | 82 | # setup dataset loader 83 | opt.batch_size = 1 84 | opt.test_img_dir = input_dir 85 | dataset = Fusion_Testing_Dataset(opt, file_names, -1) 86 | dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size) 87 | 88 | # setup model to test 89 | self.create_model(opt) 90 | self.model.setup_to_test('coco_finetuned_mask_256_ffs') 91 | 92 | print("Colorizing...") 93 | # colorize image 94 | with torch.no_grad(): 95 | for data_raw in tqdm(dataset_loader): 96 | data_raw['full_img'][0] = data_raw['full_img'][0].cuda() 97 | if data_raw['empty_box'][0] == 0: 98 | data_raw['cropped_img'][0] = data_raw['cropped_img'][0].cuda() 99 | box_info = data_raw['box_info'][0] 100 | box_info_2x = data_raw['box_info_2x'][0] 101 | box_info_4x = data_raw['box_info_4x'][0] 102 | box_info_8x = data_raw['box_info_8x'][0] 103 | cropped_data = util.get_colorization_data(data_raw['cropped_img'], opt, ab_thresh=0, p=opt.sample_p) 104 | full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p) 105 | self.model.set_input(cropped_data) 106 | self.model.set_fusion_input(full_img_data, [box_info, box_info_2x, box_info_4x, box_info_8x]) 107 | 
self.model.forward() 108 | else: 109 | full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p) 110 | self.model.set_forward_without_box(full_img_data) 111 | self.model.save_current_imgs(os.path.join(opt.results_img_dir, data_raw['file_id'][0] + '.png')) 112 | 113 | def test(self, input_dir: str, output_path: str, opt): 114 | # get file names 115 | input_files = os.listdir(input_dir) 116 | #create colorized output save folder 117 | mkdir(opt.results_img_dir) 118 | #create bounding box save folder 119 | output_npz_dir = "{0}_bbox".format(input_dir) 120 | mkdir(output_npz_dir) 121 | opt.output_npz_dir = output_npz_dir 122 | 123 | for i in range(0, len(input_files), 5): 124 | self.test_images(input_dir, input_files[i : min(len(input_files), i + 5)], opt) 125 | 126 | output_frames = sorted(os.listdir(opt.results_img_dir)) 127 | # resize back 128 | for output in output_frames: 129 | img = Image.open(os.path.join(opt.results_img_dir, output)) 130 | img = img.resize((320, 180), Image.ANTIALIAS) 131 | img.save(os.path.join(opt.results_img_dir, output)) 132 | 133 | convert_frames_to_video(opt.results_img_dir, output_path) 134 | # remove bounding box dir 135 | shutil.rmtree(output_npz_dir) 136 | # shutil.rmtree(opt.results_img_dir) 137 | print("Task Complete") -------------------------------------------------------------------------------- /models/InstaColor/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.utils.data 3 | from models.InstaColor.data.base_data_loader import BaseDataLoader 4 | from models.InstaColor.data.base_dataset import BaseDataset 5 | 6 | 7 | def find_dataset_using_name(dataset_name): 8 | # Given the option --dataset_mode [datasetname], 9 | # the file "data/datasetname_dataset.py" 10 | # will be imported. 11 | dataset_filename = "models.InstaColor.data." + dataset_name + "_dataset" 12 | datasetlib = importlib.import_module(dataset_filename) 13 | 14 | # In the file, the class called DatasetNameDataset() will 15 | # be instantiated. It has to be a subclass of BaseDataset, 16 | # and it is case-insensitive. 17 | dataset = None 18 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 19 | for name, cls in datasetlib.__dict__.items(): 20 | if name.lower() == target_dataset_name.lower() \ 21 | and issubclass(cls, BaseDataset): 22 | dataset = cls 23 | 24 | if dataset is None: 25 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." 
% (dataset_filename, target_dataset_name)) 26 | exit(0) 27 | 28 | return dataset 29 | 30 | 31 | def get_option_setter(dataset_name): 32 | dataset_class = find_dataset_using_name(dataset_name) 33 | return dataset_class.modify_commandline_options 34 | 35 | 36 | def create_dataset(opt): 37 | dataset = find_dataset_using_name(opt.dataset_mode) 38 | instance = dataset() 39 | instance.initialize(opt) 40 | print("dataset [%s] was created" % (instance.name())) 41 | return instance 42 | 43 | 44 | def CreateDataLoader(opt): 45 | data_loader = CustomDatasetDataLoader() 46 | data_loader.initialize(opt) 47 | return data_loader 48 | 49 | 50 | # Wrapper class of Dataset class that performs 51 | # multi-threaded data loading 52 | class CustomDatasetDataLoader(BaseDataLoader): 53 | def name(self): 54 | return 'CustomDatasetDataLoader' 55 | 56 | def initialize(self, opt): 57 | BaseDataLoader.initialize(self, opt) 58 | self.dataset = create_dataset(opt) 59 | self.dataloader = torch.utils.data.DataLoader( 60 | self.dataset, 61 | batch_size=opt.batch_size, 62 | shuffle=not opt.serial_batches, 63 | num_workers=int(opt.num_threads)) 64 | 65 | def load_data(self): 66 | return self 67 | 68 | def __len__(self): 69 | return min(len(self.dataset), self.opt.max_dataset_size) 70 | 71 | def __iter__(self): 72 | for i, data in enumerate(self.dataloader): 73 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 74 | break 75 | yield data 76 | -------------------------------------------------------------------------------- /models/InstaColor/data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import torchvision.transforms as transforms 4 | import torch 5 | from models.InstaColor.data.base_dataset import BaseDataset 6 | from models.InstaColor.data.image_folder import make_dataset 7 | from PIL import Image 8 | 9 | 10 | class AlignedDataset(BaseDataset): 11 | @staticmethod 12 | def modify_commandline_options(parser, is_train): 13 | return parser 14 | 15 | def initialize(self, opt): 16 | self.opt = opt 17 | self.root = opt.dataroot 18 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) 19 | self.AB_paths = sorted(make_dataset(self.dir_AB)) 20 | assert(opt.resize_or_crop == 'resize_and_crop') 21 | 22 | def __getitem__(self, index): 23 | AB_path = self.AB_paths[index] 24 | AB = Image.open(AB_path).convert('RGB') 25 | w, h = AB.size 26 | w2 = int(w / 2) 27 | A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 28 | B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 29 | A = transforms.ToTensor()(A) 30 | B = transforms.ToTensor()(B) 31 | w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) 32 | h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) 33 | 34 | A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] 35 | B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] 36 | 37 | A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A) 38 | B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B) 39 | 40 | if self.opt.which_direction == 'BtoA': 41 | input_nc = self.opt.output_nc 42 | output_nc = self.opt.input_nc 43 | else: 44 | input_nc = self.opt.input_nc 45 | output_nc = self.opt.output_nc 46 | 47 | if (not self.opt.no_flip) and random.random() < 0.5: 48 | idx = [i for i in range(A.size(2) - 1, -1, -1)] 49 | idx = 
torch.LongTensor(idx) 50 | A = A.index_select(2, idx) 51 | B = B.index_select(2, idx) 52 | 53 | if input_nc == 1: # RGB to gray 54 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 55 | A = tmp.unsqueeze(0) 56 | 57 | if output_nc == 1: # RGB to gray 58 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114 59 | B = tmp.unsqueeze(0) 60 | 61 | return {'A': A, 'B': B, 62 | 'A_paths': AB_path, 'B_paths': AB_path} 63 | 64 | def __len__(self): 65 | return len(self.AB_paths) 66 | 67 | def name(self): 68 | return 'AlignedDataset' 69 | -------------------------------------------------------------------------------- /models/InstaColor/data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | class BaseDataLoader(): 2 | def __init__(self): 3 | pass 4 | 5 | def initialize(self, opt): 6 | self.opt = opt 7 | pass 8 | 9 | def load_data(): 10 | return None 11 | -------------------------------------------------------------------------------- /models/InstaColor/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | 6 | class BaseDataset(data.Dataset): 7 | def __init__(self): 8 | super(BaseDataset, self).__init__() 9 | 10 | def name(self): 11 | return 'BaseDataset' 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | return parser 16 | 17 | def initialize(self, opt): 18 | pass 19 | 20 | def __len__(self): 21 | return 0 22 | 23 | 24 | def get_transform(opt): 25 | transform_list = [] 26 | if opt.resize_or_crop == 'resize_and_crop': 27 | osize = [opt.loadSize, opt.loadSize] 28 | transform_list.append(transforms.Resize(osize, Image.BICUBIC)) 29 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 30 | elif opt.resize_or_crop == 'crop': 31 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 32 | elif opt.resize_or_crop == 'scale_width': 33 | transform_list.append(transforms.Lambda( 34 | lambda img: __scale_width(img, opt.fineSize))) 35 | elif opt.resize_or_crop == 'scale_width_and_crop': 36 | transform_list.append(transforms.Lambda( 37 | lambda img: __scale_width(img, opt.loadSize))) 38 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 39 | elif opt.resize_or_crop == 'none': 40 | transform_list.append(transforms.Lambda( 41 | lambda img: __adjust(img))) 42 | else: 43 | raise ValueError('--resize_or_crop %s is not a valid option.' 
% opt.resize_or_crop) 44 | 45 | if opt.isTrain and not opt.no_flip: 46 | transform_list.append(transforms.RandomHorizontalFlip()) 47 | 48 | transform_list += [transforms.ToTensor(), 49 | transforms.Normalize((0.5, 0.5, 0.5), 50 | (0.5, 0.5, 0.5))] 51 | return transforms.Compose(transform_list) 52 | 53 | # just modify the width and height to be multiple of 4 54 | 55 | 56 | def __adjust(img): 57 | ow, oh = img.size 58 | 59 | # the size needs to be a multiple of this number, 60 | # because going through generator network may change img size 61 | # and eventually cause size mismatch error 62 | mult = 4 63 | if ow % mult == 0 and oh % mult == 0: 64 | return img 65 | w = (ow - 1) // mult 66 | w = (w + 1) * mult 67 | h = (oh - 1) // mult 68 | h = (h + 1) * mult 69 | 70 | if ow != w or oh != h: 71 | __print_size_warning(ow, oh, w, h) 72 | 73 | return img.resize((w, h), Image.BICUBIC) 74 | 75 | 76 | def __scale_width(img, target_width): 77 | ow, oh = img.size 78 | 79 | # the size needs to be a multiple of this number, 80 | # because going through generator network may change img size 81 | # and eventually cause size mismatch error 82 | mult = 4 83 | assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult 84 | if (ow == target_width and oh % mult == 0): 85 | return img 86 | w = target_width 87 | target_height = int(target_width * oh / ow) 88 | m = (target_height - 1) // mult 89 | h = (m + 1) * mult 90 | 91 | if target_height != h: 92 | __print_size_warning(target_width, target_height, w, h) 93 | 94 | return img.resize((w, h), Image.BICUBIC) 95 | 96 | 97 | def __print_size_warning(ow, oh, w, h): 98 | if not hasattr(__print_size_warning, 'has_printed'): 99 | print("The image size needs to be a multiple of 4. " 100 | "The loaded image size was (%d, %d), so it was adjusted to " 101 | "(%d, %d). This adjustment will be done to all images " 102 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 103 | __print_size_warning.has_printed = True 104 | -------------------------------------------------------------------------------- /models/InstaColor/data/color_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from models.InstaColor.data.base_dataset import BaseDataset, get_transform 3 | from models.InstaColor.data.image_folder import make_dataset 4 | from PIL import Image 5 | 6 | 7 | class ColorDataset(BaseDataset): 8 | @staticmethod 9 | def modify_commandline_options(parser, is_train): 10 | return parser 11 | 12 | def initialize(self, opt): 13 | self.opt = opt 14 | self.root = opt.dataroot 15 | self.dir_A = os.path.join(opt.dataroot) 16 | 17 | self.A_paths = make_dataset(self.dir_A) 18 | 19 | self.A_paths = sorted(self.A_paths) 20 | 21 | self.transform = get_transform(opt) 22 | 23 | def __getitem__(self, index): 24 | A_path = self.A_paths[index] 25 | A_img = Image.open(A_path).convert('RGB') 26 | A = self.transform(A_img) 27 | if self.opt.which_direction == 'BtoA': 28 | input_nc = self.opt.output_nc 29 | else: 30 | input_nc = self.opt.input_nc 31 | 32 | # convert to Lab 33 | # rgb2lab(A_img) 34 | 35 | if input_nc == 1: # RGB to gray 36 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] 
* 0.114 37 | A = tmp.unsqueeze(0) 38 | 39 | return {'A': A, 'A_paths': A_path} 40 | 41 | def __len__(self): 42 | return len(self.A_paths) 43 | 44 | def name(self): 45 | return 'ColorImageDataset' 46 | -------------------------------------------------------------------------------- /models/InstaColor/data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /models/InstaColor/data/single_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from models.InstaColor.data.base_dataset import BaseDataset, get_transform 3 | from models.InstaColor.data.image_folder import make_dataset 4 | from PIL import Image 5 | 6 | 7 | class SingleDataset(BaseDataset): 8 | @staticmethod 9 | def modify_commandline_options(parser, is_train): 10 | return parser 11 | 12 | def initialize(self, opt): 13 | self.opt = opt 14 | self.root = opt.dataroot 15 | self.dir_A = os.path.join(opt.dataroot) 16 | 17 | self.A_paths = make_dataset(self.dir_A) 18 | 19 | self.A_paths = sorted(self.A_paths) 20 | 21 | self.transform = get_transform(opt) 22 | 23 | def __getitem__(self, index): 24 | A_path = self.A_paths[index] 25 | A_img = Image.open(A_path).convert('RGB') 26 | A = self.transform(A_img) 27 | if self.opt.which_direction == 'BtoA': 28 | input_nc = self.opt.output_nc 29 | else: 30 | input_nc = self.opt.input_nc 31 | 32 | if input_nc == 1: # RGB to gray 33 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] 
* 0.114 34 | A = tmp.unsqueeze(0) 35 | 36 | return {'A': A, 'A_paths': A_path} 37 | 38 | def __len__(self): 39 | return len(self.A_paths) 40 | 41 | def name(self): 42 | return 'SingleImageDataset' 43 | -------------------------------------------------------------------------------- /models/InstaColor/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from models.InstaColor.models.base_model import BaseModel 3 | 4 | def find_model_using_name(model_name): 5 | # Given the option --model [modelname], 6 | # the file "models/modelname_model.py" 7 | # will be imported. 8 | model_filename = "models.InstaColor.models." + model_name + "_model" 9 | modellib = importlib.import_module(model_filename) 10 | 11 | # In the file, the class called ModelNameModel() will 12 | # be instantiated. It has to be a subclass of BaseModel, 13 | # and it is case-insensitive. 14 | model = None 15 | target_model_name = model_name.replace('_', '') + 'model' 16 | for name, cls in modellib.__dict__.items(): 17 | if name.lower() == target_model_name.lower() \ 18 | and issubclass(cls, BaseModel): 19 | model = cls 20 | 21 | if model is None: 22 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 23 | exit(0) 24 | 25 | return model 26 | 27 | def get_option_setter(model_name): 28 | model_class = find_model_using_name(model_name) 29 | return model_class.modify_commandline_options -------------------------------------------------------------------------------- /models/InstaColor/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from . 
import networks 5 | 6 | 7 | class BaseModel(): 8 | 9 | # modify parser to add command line options, 10 | # and also change the default values if needed 11 | @staticmethod 12 | def modify_commandline_options(parser, is_train): 13 | return parser 14 | 15 | def name(self): 16 | return 'BaseModel' 17 | 18 | def initialize(self, opt): 19 | self.opt = opt 20 | self.gpu_ids = opt.gpu_ids 21 | self.isTrain = opt.isTrain 22 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') 23 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 24 | if opt.resize_or_crop != 'scale_width': 25 | torch.backends.cudnn.benchmark = True 26 | self.loss_names = [] 27 | self.model_names = [] 28 | self.visual_names = [] 29 | self.image_paths = [] 30 | 31 | def set_input(self, input): 32 | self.input = input 33 | 34 | def forward(self): 35 | pass 36 | 37 | # load and print networks; create schedulers 38 | def setup(self, opt, parser=None): 39 | if self.isTrain: 40 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 41 | 42 | if not self.isTrain or opt.load_model: 43 | self.load_networks(opt.which_epoch) 44 | 45 | # make models eval mode during test time 46 | def eval(self): 47 | for name in self.model_names: 48 | if isinstance(name, str): 49 | net = getattr(self, 'net' + name) 50 | net.eval() 51 | 52 | # used in test time, wrapping `forward` in no_grad() so we don't save 53 | # intermediate steps for backprop 54 | def test(self, compute_losses=False): 55 | with torch.no_grad(): 56 | self.forward() 57 | if(compute_losses): 58 | self.compute_losses_G() 59 | 60 | # get image paths 61 | def get_image_paths(self): 62 | return self.image_paths 63 | 64 | def optimize_parameters(self): 65 | pass 66 | 67 | # update learning rate (called once every epoch) 68 | def update_learning_rate(self): 69 | for scheduler in self.schedulers: 70 | scheduler.step() 71 | lr = self.optimizers[0].param_groups[0]['lr'] 72 | # print('learning rate = %.7f' % lr) 73 | 74 | # return visualization images. train.py will display these images, and save the images to a html 75 | def get_current_visuals(self): 76 | visual_ret = OrderedDict() 77 | for name in self.visual_names: 78 | if isinstance(name, str): 79 | visual_ret[name] = getattr(self, name) 80 | return visual_ret 81 | 82 | # return traning losses/errors. train.py will print out these errors as debugging information 83 | def get_current_losses(self): 84 | errors_ret = OrderedDict() 85 | for name in self.loss_names: 86 | if isinstance(name, str): 87 | # float(...) 
works for both scalar tensor and float number 88 | errors_ret[name] = float(getattr(self, 'loss_' + name)) 89 | return errors_ret 90 | 91 | # save models to the disk 92 | def save_networks(self, which_epoch): 93 | for name in self.model_names: 94 | if isinstance(name, str): 95 | save_filename = '%s_net_%s.pth' % (which_epoch, name) 96 | save_path = os.path.join(self.save_dir, save_filename) 97 | net = getattr(self, 'net' + name) 98 | 99 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 100 | torch.save(net.module.cpu().state_dict(), save_path) 101 | net.cuda(self.gpu_ids[0]) 102 | else: 103 | torch.save(net.cpu().state_dict(), save_path) 104 | 105 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 106 | key = keys[i] 107 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 108 | if module.__class__.__name__.startswith('InstanceNorm') and \ 109 | (key == 'running_mean' or key == 'running_var'): 110 | if getattr(module, key) is None: 111 | state_dict.pop('.'.join(keys)) 112 | if module.__class__.__name__.startswith('InstanceNorm') and \ 113 | (key == 'num_batches_tracked'): 114 | state_dict.pop('.'.join(keys)) 115 | else: 116 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 117 | 118 | # load models from the disk 119 | def load_networks(self, which_epoch): 120 | for name in self.model_names: 121 | if isinstance(name, str): 122 | load_filename = '%s_net_%s.pth' % (which_epoch, name) 123 | load_path = os.path.join(self.save_dir, load_filename) 124 | if os.path.isfile(load_path) is False: 125 | continue 126 | net = getattr(self, 'net' + name) 127 | if isinstance(net, torch.nn.DataParallel): 128 | net = net.module 129 | print('loading the model from %s' % load_path) 130 | # if you are using PyTorch newer than 0.4 (e.g., built from 131 | # GitHub source), you can remove str() on self.device 132 | state_dict = torch.load(load_path, map_location=str(self.device)) 133 | if hasattr(state_dict, '_metadata'): 134 | del state_dict._metadata 135 | 136 | # patch InstanceNorm checkpoints prior to 0.4 137 | # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 138 | # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 139 | net.load_state_dict(state_dict, strict=False) 140 | 141 | # print network information 142 | def print_networks(self, verbose): 143 | print('---------- Networks initialized -------------') 144 | for name in self.model_names: 145 | if isinstance(name, str): 146 | net = getattr(self, 'net' + name) 147 | num_params = 0 148 | for param in net.parameters(): 149 | num_params += param.numel() 150 | if verbose: 151 | print(net) 152 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 153 | print('-----------------------------------------------') 154 | 155 | # set requies_grad=Fasle to avoid computation 156 | def set_requires_grad(self, nets, requires_grad=False): 157 | if not isinstance(nets, list): 158 | nets = [nets] 159 | for net in nets: 160 | if net is not None: 161 | for param in net.parameters(): 162 | param.requires_grad = requires_grad 163 | -------------------------------------------------------------------------------- /models/InstaColor/models/fusion_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("..") 3 | 4 | import torch 5 | from utils import color_format 6 | from models.InstaColor.utils import util 7 | from .base_model import 
BaseModel 8 | from . import networks 9 | import numpy as np 10 | from skimage import io 11 | from skimage import img_as_ubyte 12 | 13 | 14 | 15 | class FusionModel(BaseModel): 16 | def name(self): 17 | return 'FusionModel' 18 | 19 | @staticmethod 20 | def modify_commandline_options(parser, is_train=True): 21 | return parser 22 | 23 | def initialize(self, opt): 24 | BaseModel.initialize(self, opt) 25 | self.model_names = ['G', 'GF'] 26 | 27 | # load/define networks 28 | num_in = opt.input_nc + opt.output_nc + 1 29 | 30 | self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf, 31 | 'instance', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, 32 | use_tanh=True, classification=False) 33 | self.netG.eval() 34 | 35 | self.netGF = networks.define_G(num_in, opt.output_nc, opt.ngf, 36 | 'fusion', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, 37 | use_tanh=True, classification=False) 38 | self.netGF.eval() 39 | 40 | self.netGComp = networks.define_G(num_in, opt.output_nc, opt.ngf, 41 | 'siggraph', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, 42 | use_tanh=True, classification=opt.classification) 43 | self.netGComp.eval() 44 | 45 | 46 | def set_input(self, input): 47 | AtoB = self.opt.which_direction == 'AtoB' 48 | self.real_A = input['A' if AtoB else 'B'].to(self.device) 49 | self.real_B = input['B' if AtoB else 'A'].to(self.device) 50 | self.hint_B = input['hint_B'].to(self.device) 51 | 52 | self.mask_B = input['mask_B'].to(self.device) 53 | self.mask_B_nc = self.mask_B + self.opt.mask_cent 54 | 55 | self.real_B_enc = util.encode_ab_ind(self.real_B[:, :, ::4, ::4], self.opt) 56 | 57 | def set_fusion_input(self, input, box_info): 58 | AtoB = self.opt.which_direction == 'AtoB' 59 | self.full_real_A = input['A' if AtoB else 'B'].to(self.device) 60 | self.full_real_B = input['B' if AtoB else 'A'].to(self.device) 61 | 62 | self.full_hint_B = input['hint_B'].to(self.device) 63 | self.full_mask_B = input['mask_B'].to(self.device) 64 | 65 | self.full_mask_B_nc = self.full_mask_B + self.opt.mask_cent 66 | self.full_real_B_enc = util.encode_ab_ind(self.full_real_B[:, :, ::4, ::4], self.opt) 67 | self.box_info_list = box_info 68 | 69 | def set_forward_without_box(self, input): 70 | AtoB = self.opt.which_direction == 'AtoB' 71 | self.full_real_A = input['A' if AtoB else 'B'].to(self.device) 72 | self.full_real_B = input['B' if AtoB else 'A'].to(self.device) 73 | # self.image_paths = input['A_paths' if AtoB else 'B_paths'] 74 | self.full_hint_B = input['hint_B'].to(self.device) 75 | self.full_mask_B = input['mask_B'].to(self.device) 76 | self.full_mask_B_nc = self.full_mask_B + self.opt.mask_cent 77 | self.full_real_B_enc = util.encode_ab_ind(self.full_real_B[:, :, ::4, ::4], self.opt) 78 | 79 | (_, self.comp_B_reg) = self.netGComp(self.full_real_A, self.full_hint_B, self.full_mask_B) 80 | self.fake_B_reg = self.comp_B_reg 81 | 82 | def forward(self): 83 | (_, feature_map) = self.netG(self.real_A, self.hint_B, self.mask_B) 84 | self.fake_B_reg = self.netGF(self.full_real_A, self.full_hint_B, self.full_mask_B, feature_map, self.box_info_list) 85 | 86 | def save_current_imgs(self, path): 87 | out_img = torch.clamp(color_format.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt), 0.0, 1.0) 88 | out_img = np.transpose(out_img.cpu().data.numpy()[0], (1, 2, 0)) 89 | io.imsave(path, img_as_ubyte(out_img)) 90 | 91 | def setup_to_test(self, fusion_weight_path): 92 | GF_path = 
'checkpoints/{0}/latest_net_GF.pth'.format(fusion_weight_path) 93 | print('load Fusion model from %s' % GF_path) 94 | GF_state_dict = torch.load(GF_path) 95 | 96 | # G_path = 'checkpoints/coco_finetuned_mask_256/latest_net_G.pth' # fine tuned on cocostuff 97 | G_path = 'checkpoints/{0}/latest_net_G.pth'.format(fusion_weight_path) 98 | G_state_dict = torch.load(G_path) 99 | 100 | # GComp_path = 'checkpoints/siggraph_retrained/latest_net_G.pth' # original net 101 | # GComp_path = 'checkpoints/coco_finetuned_mask_256/latest_net_GComp.pth' # fine tuned on cocostuff 102 | GComp_path = 'checkpoints/{0}/latest_net_GComp.pth'.format(fusion_weight_path) 103 | GComp_state_dict = torch.load(GComp_path) 104 | 105 | self.netGF.load_state_dict(GF_state_dict, strict=False) 106 | self.netG.module.load_state_dict(G_state_dict, strict=False) 107 | self.netGComp.module.load_state_dict(GComp_state_dict, strict=False) 108 | self.netGF.eval() 109 | self.netG.eval() 110 | self.netGComp.eval() -------------------------------------------------------------------------------- /models/InstaColor/options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from models.InstaColor.utils import util 4 | import torch 5 | import models.InstaColor.models as models 6 | import models.InstaColor.data as data 7 | 8 | 9 | class BaseOptions(): 10 | def __init__(self): 11 | self.initialized = False 12 | 13 | def initialize(self, parser): 14 | parser.add_argument('--batch_size', type=int, default=25, help='input batch size') 15 | parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size') 16 | parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size') 17 | parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels') 18 | parser.add_argument('--output_nc', type=int, default=2, help='# of output image channels') 19 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 20 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 21 | parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD') 22 | parser.add_argument('--which_model_netG', type=str, default='siggraph', help='selects model to use for netG') 23 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') 24 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 25 | parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. 
[unaligned | aligned | single]') 26 | parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA') 27 | parser.add_argument('--nThreads', default=4, type=int, help='# threads for loading data') 28 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 29 | parser.add_argument('--norm', type=str, default='batch', help='instance normalization or batch normalization') 30 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 31 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 32 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 33 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 34 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 35 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 36 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), 37 | help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 38 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 39 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 40 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 41 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 42 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}') 43 | parser.add_argument('--ab_norm', type=float, default=110., help='colorization normalization factor') 44 | parser.add_argument('--ab_max', type=float, default=110., help='maximimum ab value') 45 | parser.add_argument('--ab_quant', type=float, default=10., help='quantization factor') 46 | parser.add_argument('--l_norm', type=float, default=100., help='colorization normalization factor') 47 | parser.add_argument('--l_cent', type=float, default=50., help='colorization centering factor') 48 | parser.add_argument('--mask_cent', type=float, default=.5, help='mask centering factor') 49 | parser.add_argument('--sample_p', type=float, default=1.0, help='sampling geometric distribution, 1.0 means no hints') 50 | parser.add_argument('--sample_Ps', type=int, nargs='+', default=[1, 2, 3, 4, 5, 6, 7, 8, 9, ], help='patch sizes') 51 | 52 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 53 | parser.add_argument('--classification', action='store_true', help='backprop trunk using classification, otherwise use regression') 54 | parser.add_argument('--phase', type=str, default='val', help='train_small, train, val, test, etc') 55 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') 56 | parser.add_argument('--how_many', type=int, default=5, help='how many test images to run') 57 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 58 | 59 | parser.add_argument('--load_model', action='store_true', help='load the latest model') 60 | parser.add_argument('--half', action='store_true', help='half precision model') 61 | 62 | self.initialized = True 63 | return parser 64 | 65 | def gather_options(self): 66 | # initialize parser with basic options 67 | if not self.initialized: 68 | parser = argparse.ArgumentParser( 69 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 70 | parser = self.initialize(parser) 71 | 72 | # get the basic options 73 | opt, _ = parser.parse_known_args() 74 | 75 | # modify model-related parser options 76 | model_name = opt.model 77 | model_option_setter = models.get_option_setter(model_name) 78 | parser = model_option_setter(parser, self.isTrain) 79 | opt, _ = parser.parse_known_args() # parse again with the new defaults 80 | 81 | # modify dataset-related parser options 82 | dataset_name = opt.dataset_mode 83 | dataset_option_setter = data.get_option_setter(dataset_name) 84 | parser = dataset_option_setter(parser, self.isTrain) 85 | 86 | self.parser = parser 87 | 88 | return parser.parse_args() 89 | 90 | def print_options(self, opt): 91 | message = '' 92 | message += '----------------- Options ---------------\n' 93 | for k, v in sorted(vars(opt).items()): 94 | comment = '' 95 | default = self.parser.get_default(k) 96 | if v != default: 97 | comment = '\t[default: %s]' % str(default) 98 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 99 | message += '----------------- End -------------------' 100 | print(message) 101 | 102 | # save to the disk 103 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 104 | util.mkdirs(expr_dir) 105 | file_name = os.path.join(expr_dir, 'opt.txt') 106 | with open(file_name, 'wt') as opt_file: 107 | opt_file.write(message) 108 | opt_file.write('\n') 109 | 110 | def parse(self): 111 | 112 | opt = self.gather_options() 113 | opt.isTrain = self.isTrain # train or test 114 | 115 | # process opt.suffix 116 | if opt.suffix: 117 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 118 | opt.name = opt.name + suffix 119 | 120 | # self.print_options(opt) 121 | 122 | # set gpu ids 123 | str_ids = opt.gpu_ids.split(',') 124 | opt.gpu_ids = [] 125 | for str_id in str_ids: 126 | id = int(str_id) 127 | if id >= 0: 128 | opt.gpu_ids.append(id) 129 | if len(opt.gpu_ids) > 0: 130 | torch.cuda.set_device(opt.gpu_ids[0]) 131 | opt.A = 2 * opt.ab_max / opt.ab_quant + 1 132 | opt.B = opt.A 133 | 134 | self.opt = opt 135 | return self.opt 136 | -------------------------------------------------------------------------------- /models/InstaColor/options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TestOptions(BaseOptions): 4 | def initialize(self, parser): 5 | BaseOptions.initialize(self, parser) 6 | parser.add_argument('--test_img_dir', type=str, default='example', help='testing images folder') 7 | parser.add_argument('--results_img_dir', type=str, default='test/results', help='save the results image folder') 8 | parser.add_argument('--name', type=str, default='test_fusion', help='name of the experiment. 
It decides where to store samples and models') 9 | parser.add_argument('--model', type=str, default='fusion', 10 | help='chooses which model to use. cycle_gan, pix2pix, test') 11 | parser.add_argument('--display_freq', type=int, default=2000, help='frequency of showing training results on screen') 12 | parser.add_argument('--display_ncols', type=int, default=5, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 13 | parser.add_argument('--update_html_freq', type=int, default=10000, help='frequency of saving training results to html') 14 | parser.add_argument('--print_freq', type=int, default=2000, help='frequency of showing training results on console') 15 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 16 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs') 17 | parser.add_argument('--epoch_count', type=int, default=0, help='the starting epoch count, we save the model by , +, ...') 18 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 19 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') 20 | parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam') 21 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam') 22 | parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') 23 | parser.add_argument('--lambda_GAN', type=float, default=0., help='weight for GAN loss') 24 | parser.add_argument('--lambda_A', type=float, default=1., help='weight for cycle loss (A -> B -> A)') 25 | parser.add_argument('--lambda_B', type=float, default=1., help='weight for cycle loss (B -> A -> B)') 26 | parser.add_argument('--lambda_identity', type=float, default=0.5, 27 | help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss.' 
28 | 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1') 29 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 30 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 31 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') 32 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 33 | parser.add_argument('--avg_loss_alpha', type=float, default=.986, help='exponential averaging weight for displaying loss') 34 | self.isTrain = False 35 | return parser -------------------------------------------------------------------------------- /models/InstaColor/utils/datasets.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from os.path import isfile, join 3 | from random import sample 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data as Data 8 | import torchvision.transforms as transforms 9 | 10 | from models.InstaColor.utils.image_utils import * 11 | 12 | 13 | class Fusion_Testing_Dataset(Data.Dataset): 14 | def __init__(self, opt, filenames, box_num=8): 15 | self.PRED_BBOX_DIR = '{0}_bbox'.format(opt.test_img_dir) 16 | self.IMAGE_DIR = opt.test_img_dir 17 | self.IMAGE_ID_LIST = filenames 18 | 19 | self.transforms = transforms.Compose([transforms.Resize((opt.fineSize, opt.fineSize), interpolation=2), 20 | transforms.ToTensor()]) 21 | self.final_size = opt.fineSize 22 | self.box_num = box_num 23 | 24 | def __getitem__(self, index): 25 | pred_info_path = join(self.PRED_BBOX_DIR, self.IMAGE_ID_LIST[index].split('.')[0] + '.npz') 26 | output_image_path = join(self.IMAGE_DIR, self.IMAGE_ID_LIST[index]) 27 | pred_bbox = gen_maskrcnn_bbox_fromPred(pred_info_path, self.box_num) 28 | 29 | img_list = [] 30 | pil_img = read_to_pil(output_image_path) 31 | img_list.append(self.transforms(pil_img)) 32 | 33 | cropped_img_list = [] 34 | index_list = range(len(pred_bbox)) 35 | box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros((4, len(index_list), 6)) 36 | for i in index_list: 37 | startx, starty, endx, endy = pred_bbox[i] 38 | box_info[i] = np.array(get_box_info(pred_bbox[i], pil_img.size, self.final_size)) 39 | box_info_2x[i] = np.array(get_box_info(pred_bbox[i], pil_img.size, self.final_size // 2)) 40 | box_info_4x[i] = np.array(get_box_info(pred_bbox[i], pil_img.size, self.final_size // 4)) 41 | box_info_8x[i] = np.array(get_box_info(pred_bbox[i], pil_img.size, self.final_size // 8)) 42 | cropped_img = self.transforms(pil_img.crop((startx, starty, endx, endy))) 43 | cropped_img_list.append(cropped_img) 44 | output = {} 45 | output['full_img'] = torch.stack(img_list) 46 | output['file_id'] = self.IMAGE_ID_LIST[index].split('.')[0] 47 | if len(pred_bbox) > 0: 48 | output['cropped_img'] = torch.stack(cropped_img_list) 49 | output['box_info'] = torch.from_numpy(box_info).type(torch.long) 50 | output['box_info_2x'] = torch.from_numpy(box_info_2x).type(torch.long) 51 | output['box_info_4x'] = torch.from_numpy(box_info_4x).type(torch.long) 52 | output['box_info_8x'] = torch.from_numpy(box_info_8x).type(torch.long) 53 | output['empty_box'] = False 54 | else: 55 | output['empty_box'] = True 56 | return output 57 | 58 | def 
__len__(self): 59 | return len(self.IMAGE_ID_LIST) 60 | 61 | 62 | class Training_Full_Dataset(Data.Dataset): 63 | ''' 64 | Training on COCOStuff dataset. [train2017.zip] 65 | 66 | Download the training set from https://github.com/nightrome/cocostuff 67 | ''' 68 | def __init__(self, opt): 69 | self.IMAGE_DIR = opt.train_img_dir 70 | self.transforms = transforms.Compose([transforms.Resize((opt.fineSize, opt.fineSize), interpolation=2), 71 | transforms.ToTensor()]) 72 | self.IMAGE_ID_LIST = [f for f in listdir(self.IMAGE_DIR) if isfile(join(self.IMAGE_DIR, f))] 73 | 74 | def __getitem__(self, index): 75 | output_image_path = join(self.IMAGE_DIR, self.IMAGE_ID_LIST[index]) 76 | rgb_img, gray_img = gen_gray_color_pil(output_image_path) 77 | output = {} 78 | output['rgb_img'] = self.transforms(rgb_img) 79 | output['gray_img'] = self.transforms(gray_img) 80 | return output 81 | 82 | def __len__(self): 83 | return len(self.IMAGE_ID_LIST) 84 | 85 | 86 | class Training_Instance_Dataset(Data.Dataset): 87 | ''' 88 | Training on COCOStuff dataset. [train2017.zip] 89 | 90 | Download the training set from https://github.com/nightrome/cocostuff 91 | 92 | Make sure you've predicted all the images' bounding boxes using inference_bbox.py 93 | 94 | It would be better if you can filter out the images which don't have any box. 95 | ''' 96 | def __init__(self, opt): 97 | self.PRED_BBOX_DIR = '{0}_bbox'.format(opt.train_img_dir) 98 | self.IMAGE_DIR = opt.train_img_dir 99 | self.IMAGE_ID_LIST = [f for f in listdir(self.IMAGE_DIR) if isfile(join(self.IMAGE_DIR, f))] 100 | self.transforms = transforms.Compose([ 101 | transforms.Resize((opt.fineSize, opt.fineSize), interpolation=2), 102 | transforms.ToTensor() 103 | ]) 104 | 105 | def __getitem__(self, index): 106 | pred_info_path = join(self.PRED_BBOX_DIR, self.IMAGE_ID_LIST[index].split('.')[0] + '.npz') 107 | output_image_path = join(self.IMAGE_DIR, self.IMAGE_ID_LIST[index]) 108 | pred_bbox = gen_maskrcnn_bbox_fromPred(pred_info_path) 109 | 110 | rgb_img, gray_img = gen_gray_color_pil(output_image_path) 111 | 112 | index_list = range(len(pred_bbox)) 113 | index_list = sample(index_list, 1) 114 | startx, starty, endx, endy = pred_bbox[index_list[0]] 115 | output = {} 116 | output['rgb_img'] = self.transforms(rgb_img.crop((startx, starty, endx, endy))) 117 | output['gray_img'] = self.transforms(gray_img.crop((startx, starty, endx, endy))) 118 | return output 119 | 120 | def __len__(self): 121 | return len(self.IMAGE_ID_LIST) 122 | 123 | 124 | class Training_Fusion_Dataset(Data.Dataset): 125 | ''' 126 | Training on COCOStuff dataset. [train2017.zip] 127 | 128 | Download the training set from https://github.com/nightrome/cocostuff 129 | 130 | Make sure you've predicted all the images' bounding boxes using inference_bbox.py 131 | 132 | It would be better if you can filter out the images which don't have any box. 
133 | ''' 134 | def __init__(self, opt, box_num=8): 135 | self.PRED_BBOX_DIR = '{0}_bbox'.format(opt.train_img_dir) 136 | self.IMAGE_DIR = opt.train_img_dir 137 | self.IMAGE_ID_LIST = [f for f in listdir(self.IMAGE_DIR) if isfile(join(self.IMAGE_DIR, f))] 138 | 139 | self.transforms = transforms.Compose([transforms.Resize((opt.fineSize, opt.fineSize), interpolation=2), 140 | transforms.ToTensor()]) 141 | self.final_size = opt.fineSize 142 | self.box_num = box_num 143 | 144 | def __getitem__(self, index): 145 | pred_info_path = join(self.PRED_BBOX_DIR, self.IMAGE_ID_LIST[index].split('.')[0] + '.npz') 146 | output_image_path = join(self.IMAGE_DIR, self.IMAGE_ID_LIST[index]) 147 | pred_bbox = gen_maskrcnn_bbox_fromPred(pred_info_path, self.box_num) 148 | 149 | full_rgb_list = [] 150 | full_gray_list = [] 151 | rgb_img, gray_image = gen_gray_color_pil(output_image_path) 152 | full_rgb_list.append(self.transforms(rgb_img)) 153 | full_gray_list.append(self.transforms(gray_image)) 154 | 155 | cropped_rgb_list = [] 156 | cropped_gray_list = [] 157 | index_list = range(len(pred_bbox)) 158 | box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros((4, len(index_list), 6)) 159 | for i in range(len(index_list)): 160 | startx, starty, endx, endy = pred_bbox[i] 161 | box_info[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size)) 162 | box_info_2x[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size // 2)) 163 | box_info_4x[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size // 4)) 164 | box_info_8x[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size // 8)) 165 | cropped_rgb_list.append(self.transforms(rgb_img.crop((startx, starty, endx, endy)))) 166 | cropped_gray_list.append(self.transforms(gray_image.crop((startx, starty, endx, endy)))) 167 | output = {} 168 | output['cropped_rgb'] = torch.stack(cropped_rgb_list) 169 | output['cropped_gray'] = torch.stack(cropped_gray_list) 170 | output['full_rgb'] = torch.stack(full_rgb_list) 171 | output['full_gray'] = torch.stack(full_gray_list) 172 | output['box_info'] = torch.from_numpy(box_info).type(torch.long) 173 | output['box_info_2x'] = torch.from_numpy(box_info_2x).type(torch.long) 174 | output['box_info_4x'] = torch.from_numpy(box_info_4x).type(torch.long) 175 | output['box_info_8x'] = torch.from_numpy(box_info_8x).type(torch.long) 176 | output['file_id'] = self.IMAGE_ID_LIST[index] 177 | return output 178 | 179 | def __len__(self): 180 | return len(self.IMAGE_ID_LIST) -------------------------------------------------------------------------------- /models/InstaColor/utils/download.py: -------------------------------------------------------------------------------- 1 | #taken from this StackOverflow answer: https://stackoverflow.com/a/39225039 2 | from importlib_metadata import requires 3 | import requests 4 | from os.path import join, isdir 5 | import os 6 | 7 | def download_file_from_google_drive(id, destination): 8 | URL = "https://docs.google.com/uc?export=download" 9 | 10 | session = requests.Session() 11 | 12 | response = session.get(URL, params = { 'id' : id }, stream = True) 13 | token = get_confirm_token(response) 14 | 15 | if token: 16 | params = { 'id' : id, 'confirm' : token } 17 | response = session.get(URL, params = params, stream = True) 18 | 19 | 20 | def get_confirm_token(response): 21 | for key, value in response.cookies.items(): 22 | if key.startswith('download_warning'): 23 | return value 24 | 25 | return None 26 | 27 | def download_coco_dataset(dataset_dir: 
str): 28 | print('download cocostuff training dataset') 29 | url = "http://images.cocodataset.org/zips/train2017.zip" 30 | response = requests.get(url, stream = True) 31 | if isdir(join(dataset_dir, "cocostuff")) is False: 32 | os.makedirs(join(dataset_dir, "cocostuff")) 33 | save_response_content(response, join(dataset_dir, "cocostuff", "train.zip")) 34 | -------------------------------------------------------------------------------- /models/InstaColor/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from skimage import color 4 | 5 | 6 | def gen_gray_color_pil(color_img_path): 7 | ''' 8 | return: RGB and GRAY pillow image object 9 | ''' 10 | rgb_img = Image.open(color_img_path) 11 | if len(np.asarray(rgb_img).shape) == 2: 12 | rgb_img = np.stack([np.asarray(rgb_img), np.asarray(rgb_img), np.asarray(rgb_img)], 2) 13 | rgb_img = Image.fromarray(rgb_img) 14 | gray_img = np.round(color.rgb2gray(np.asarray(rgb_img)) * 255.0).astype(np.uint8) 15 | gray_img = np.stack([gray_img, gray_img, gray_img], -1) 16 | gray_img = Image.fromarray(gray_img) 17 | return rgb_img, gray_img 18 | 19 | def read_to_pil(img_path): 20 | ''' 21 | return: pillow image object HxWx3 22 | ''' 23 | out_img = Image.open(img_path) 24 | if len(np.asarray(out_img).shape) == 2: 25 | out_img = np.stack([np.asarray(out_img), np.asarray(out_img), np.asarray(out_img)], 2) 26 | out_img = Image.fromarray(out_img) 27 | return out_img 28 | 29 | def gen_maskrcnn_bbox_fromPred(pred_data_path, box_num_upbound=-1): 30 | ''' 31 | ## Arguments: 32 | - pred_data_path: Detectron2 predict results 33 | - box_num_upbound: object bounding boxes number. Default: -1 means use all the instances. 34 | ''' 35 | pred_data = np.load(pred_data_path) 36 | assert 'bbox' in pred_data 37 | assert 'scores' in pred_data 38 | pred_bbox = pred_data['bbox'].astype(np.int32) 39 | if box_num_upbound > 0 and pred_bbox.shape[0] > box_num_upbound: 40 | pred_scores = pred_data['scores'] 41 | index_mask = np.argsort(pred_scores, axis=0)[pred_scores.shape[0] - box_num_upbound: pred_scores.shape[0]] 42 | pred_bbox = pred_bbox[index_mask] 43 | # pred_scores = pred_data['scores'] 44 | # index_mask = pred_scores > 0.9 45 | # pred_bbox = pred_bbox[index_mask].astype(np.int32) 46 | return pred_bbox 47 | 48 | def get_box_info(pred_bbox, original_shape, final_size): 49 | assert len(pred_bbox) == 4 50 | resize_startx = int(pred_bbox[0] / original_shape[0] * final_size) 51 | resize_starty = int(pred_bbox[1] / original_shape[1] * final_size) 52 | resize_endx = int(pred_bbox[2] / original_shape[0] * final_size) 53 | resize_endy = int(pred_bbox[3] / original_shape[1] * final_size) 54 | rh = resize_endx - resize_startx 55 | rw = resize_endy - resize_starty 56 | if rh < 1: 57 | if final_size - resize_endx > 1: 58 | resize_endx += 1 59 | else: 60 | resize_startx -= 1 61 | rh = 1 62 | if rw < 1: 63 | if final_size - resize_endy > 1: 64 | resize_endy += 1 65 | else: 66 | resize_starty -= 1 67 | rw = 1 68 | L_pad = resize_startx 69 | R_pad = final_size - resize_endx 70 | T_pad = resize_starty 71 | B_pad = final_size - resize_endy 72 | return [L_pad, R_pad, T_pad, B_pad, rh, rw] -------------------------------------------------------------------------------- /models/InstaColor/utils/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import sys 4 | sys.path.append("../..") 5 | 6 | 7 | import torch 8 
| import numpy as np 9 | from PIL import Image 10 | import os 11 | from collections import OrderedDict 12 | from IPython import embed 13 | import cv2 14 | from utils.color_format import rgb2lab 15 | 16 | # Converts a Tensor into an image array (numpy) 17 | # |imtype|: the desired type of the converted numpy array 18 | def tensor2im(input_image, imtype=np.uint8): 19 | if isinstance(input_image, torch.Tensor): 20 | image_tensor = input_image.data 21 | else: 22 | return input_image 23 | image_numpy = image_tensor[0].cpu().float().numpy() 24 | if image_numpy.shape[0] == 1: 25 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 26 | image_numpy = np.clip((np.transpose(image_numpy, (1, 2, 0)) ),0, 1) * 255.0 27 | return image_numpy.astype(imtype) 28 | 29 | 30 | def diagnose_network(net, name='network'): 31 | mean = 0.0 32 | count = 0 33 | for param in net.parameters(): 34 | if param.grad is not None: 35 | mean += torch.mean(torch.abs(param.grad.data)) 36 | count += 1 37 | if count > 0: 38 | mean = mean / count 39 | print(name) 40 | print(mean) 41 | 42 | 43 | def save_image(image_numpy, image_path): 44 | image_pil = Image.fromarray(image_numpy) 45 | image_pil.save(image_path) 46 | 47 | 48 | def print_numpy(x, val=True, shp=False): 49 | x = x.astype(np.float64) 50 | if shp: 51 | print('shape,', x.shape) 52 | if val: 53 | x = x.flatten() 54 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 55 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 56 | 57 | 58 | def mkdirs(paths): 59 | if isinstance(paths, list) and not isinstance(paths, str): 60 | for path in paths: 61 | mkdir(path) 62 | else: 63 | mkdir(paths) 64 | 65 | 66 | def mkdir(path): 67 | if not os.path.exists(path): 68 | os.makedirs(path) 69 | 70 | 71 | def get_subset_dict(in_dict,keys): 72 | if(len(keys)): 73 | subset = OrderedDict() 74 | for key in keys: 75 | subset[key] = in_dict[key] 76 | else: 77 | subset = in_dict 78 | return subset 79 | 80 | 81 | 82 | def get_colorization_data(data_raw, opt, ab_thresh=5., p=.125, num_points=None): 83 | data = {} 84 | data_lab = rgb2lab(data_raw[0], opt) 85 | data['A'] = data_lab[:,[0,],:,:] 86 | data['B'] = data_lab[:,1:,:,:] 87 | 88 | if(ab_thresh > 0): # mask out grayscale images 89 | thresh = 1.*ab_thresh/opt.ab_norm 90 | mask = torch.sum(torch.abs(torch.max(torch.max(data['B'],dim=3)[0],dim=2)[0]-torch.min(torch.min(data['B'],dim=3)[0],dim=2)[0]),dim=1) >= thresh 91 | data['A'] = data['A'][mask,:,:,:] 92 | data['B'] = data['B'][mask,:,:,:] 93 | # print('Removed %i points'%torch.sum(mask==0).numpy()) 94 | if(torch.sum(mask)==0): 95 | return None 96 | 97 | return add_color_patches_rand_gt(data, opt, p=p, num_points=num_points) 98 | 99 | def add_color_patches_rand_gt(data,opt,p=.125,num_points=None,use_avg=True,samp='normal'): 100 | # Add random color points sampled from ground truth based on: 101 | # Number of points 102 | # - if num_points is 0, then sample from geometric distribution, drawn from probability p 103 | # - if num_points > 0, then sample that number of points 104 | # Location of points 105 | # - if samp is 'normal', draw from N(0.5, 0.25) of image 106 | # - otherwise, draw from U[0, 1] of image 107 | N,C,H,W = data['B'].shape 108 | 109 | data['hint_B'] = torch.zeros_like(data['B']) 110 | data['mask_B'] = torch.zeros_like(data['A']) 111 | 112 | for nn in range(N): 113 | pp = 0 114 | cont_cond = True 115 | while(cont_cond): 116 | if(num_points is None): # draw from geometric 117 | # embed() 118 | cont_cond = np.random.rand() < (1-p) 119 | else: # add 
certain number of points 120 | cont_cond = pp < num_points 121 | if(not cont_cond): # skip out of loop if condition not met 122 | continue 123 | 124 | P = np.random.choice(opt.sample_Ps) # patch size 125 | 126 | # sample location 127 | if(samp=='normal'): # geometric distribution 128 | h = int(np.clip(np.random.normal( (H-P+1)/2., (H-P+1)/4.), 0, H-P)) 129 | w = int(np.clip(np.random.normal( (W-P+1)/2., (W-P+1)/4.), 0, W-P)) 130 | else: # uniform distribution 131 | h = np.random.randint(H-P+1) 132 | w = np.random.randint(W-P+1) 133 | 134 | # add color point 135 | if(use_avg): 136 | # embed() 137 | data['hint_B'][nn,:,h:h+P,w:w+P] = torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=1,keepdim=True).view(1,C,1,1) 138 | else: 139 | data['hint_B'][nn,:,h:h+P,w:w+P] = data['B'][nn,:,h:h+P,w:w+P] 140 | 141 | data['mask_B'][nn,:,h:h+P,w:w+P] = 1 142 | 143 | # increment counter 144 | pp+=1 145 | 146 | data['mask_B']-=opt.mask_cent 147 | 148 | return data 149 | 150 | def add_color_patch(data,mask,opt,P=1,hw=[128,128],ab=[0,0]): 151 | # Add a color patch at (h,w) with color (a,b) 152 | data[:,0,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1.*ab[0]/opt.ab_norm 153 | data[:,1,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1.*ab[1]/opt.ab_norm 154 | mask[:,:,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1-opt.mask_cent 155 | 156 | return (data,mask) 157 | 158 | def crop_mult(data,mult=16,HWmax=[800,1200]): 159 | # center-crop image so H and W are multiples of mult (capped at HWmax) 160 | H,W = data.shape[2:] 161 | Hnew = int(min(H//mult*mult,HWmax[0])) 162 | Wnew = int(min(W//mult*mult,HWmax[1])) 163 | h = (H-Hnew)//2 164 | w = (W-Wnew)//2 165 | 166 | return data[:,:,h:h+Hnew,w:w+Wnew] 167 | 168 | def encode_ab_ind(data_ab, opt): 169 | # Encode ab value into an index 170 | # INPUTS 171 | # data_ab Nx2xHxW \in [-1,1] 172 | # OUTPUTS 173 | # data_q Nx1xHxW \in [0,Q) 174 | 175 | data_ab_rs = torch.round((data_ab*opt.ab_norm + opt.ab_max)/opt.ab_quant) # normalized bin number 176 | data_q = data_ab_rs[:,[0],:,:]*opt.A + data_ab_rs[:,[1],:,:] 177 | return data_q 178 | 179 | def decode_ind_ab(data_q, opt): 180 | # Decode index into ab value 181 | # INPUTS 182 | # data_q Nx1xHxW \in [0,Q) 183 | # OUTPUTS 184 | # data_ab Nx2xHxW \in [-1,1] 185 | 186 | data_a = data_q // opt.A # floor division: index of the a-bin 187 | data_b = data_q - data_a*opt.A 188 | data_ab = torch.cat((data_a,data_b),dim=1) 189 | 190 | if(data_q.is_cuda): 191 | type_out = torch.cuda.FloatTensor 192 | else: 193 | type_out = torch.FloatTensor 194 | data_ab = ((data_ab.type(type_out)*opt.ab_quant) - opt.ab_max)/opt.ab_norm 195 | 196 | return data_ab 197 | 198 | def decode_max_ab(data_ab_quant, opt): 199 | # Decode probability distribution by using bin with highest probability 200 | # INPUTS 201 | # data_ab_quant NxQxHxW \in [0,1] 202 | # OUTPUTS 203 | # data_ab Nx2xHxW \in [-1,1] 204 | 205 | data_q = torch.argmax(data_ab_quant,dim=1)[:,None,:,:] 206 | return decode_ind_ab(data_q, opt) 207 | 208 | def decode_mean(data_ab_quant, opt): 209 | # Decode probability distribution by taking mean over all bins 210 | # INPUTS 211 | # data_ab_quant NxQxHxW \in [0,1] 212 | # OUTPUTS 213 | # data_ab_inf Nx2xHxW \in [-1,1] 214 | 215 | (N,Q,H,W) = data_ab_quant.shape 216 | a_range = torch.range(-opt.ab_max, opt.ab_max, step=opt.ab_quant).to(data_ab_quant.device)[None,:,None,None] 217 | a_range = a_range.type(data_ab_quant.type()) 218 | 219 | # reshape to AB space 220 | data_ab_quant = data_ab_quant.view((N,int(opt.A),int(opt.A),H,W)) 221 | data_a_total = torch.sum(data_ab_quant,dim=2) 222 | data_b_total = torch.sum(data_ab_quant,dim=1) 
223 | 224 | # matrix multiply 225 | data_a_inf = torch.sum(data_a_total * a_range,dim=1,keepdim=True) 226 | data_b_inf = torch.sum(data_b_total * a_range,dim=1,keepdim=True) 227 | 228 | data_ab_inf = torch.cat((data_a_inf,data_b_inf),dim=1)/opt.ab_norm 229 | 230 | return data_ab_inf -------------------------------------------------------------------------------- /notebooks/DEVC.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Apply Deep Exemplar-based Video Colorization (DEVC)" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "%load_ext autoreload\n", 17 | "%autoreload 2" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": 2, 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import os\n", 27 | "import sys\n", 28 | "sys.path.append(\"../\")\n", 29 | "import numpy as np\n", 30 | "from models.DEVC import DEVC" 31 | ] 32 | }, 33 | { 34 | "cell_type": "markdown", 35 | "metadata": {}, 36 | "source": [ 37 | "### Configure Options" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "from models.DEVC.options.test_options import TestOptions\n", 47 | "sys.argv = [sys.argv[0]]\n", 48 | "opt = TestOptions().parse() " 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "### Setup Models" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "model = DEVC(pretrained=True)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "### Setup Test File" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "from utils import v2i\n", 81 | "\n", 82 | "v2i.convert_video_to_frames(\"../test/input.mp4\", \"../test/frames\")" 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "### Run Model on Test File" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": null, 95 | "metadata": {}, 96 | "outputs": [], 97 | "source": [ 98 | "model.test(\"../test/frames/\", \"../test/output.mp4\", opt)" 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": null, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "from utils.util import apply_metric_to_video\n", 108 | "from utils.metrics import *\n", 109 | "apply_metric_to_video(\"../test/input.mp4\", \"../test/output.mp4\", [PSNR, SSIM, LPIPS, cosine_similarity])" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [] 118 | } 119 | ], 120 | "metadata": { 121 | "interpreter": { 122 | "hash": "3dc88b2683938781427c3c64a4f5f2f03c4267c2f876dcaeb3e56e854a9d6a9e" 123 | }, 124 | "kernelspec": { 125 | "display_name": "Python 3.7.6 ('instacolorization')", 126 | "language": "python", 127 | "name": "python3" 128 | }, 129 | "language_info": { 130 | "codemirror_mode": { 131 | "name": "ipython", 132 | "version": 3 133 | }, 134 | "file_extension": ".py", 135 | "mimetype": "text/x-python", 136 | "name": "python", 137 | "nbconvert_exporter": "python", 138 | "pygments_lexer": "ipython3", 139 | "version": "3.7.6" 140 | }, 
141 | "orig_nbformat": 4 142 | }, 143 | "nbformat": 4, 144 | "nbformat_minor": 2 145 | } 146 | -------------------------------------------------------------------------------- /notebooks/colorizer.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Colorful Image Colorization\n", 8 | "\n", 9 | "instacolor's environment is good for this algorithm" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 2, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import matplotlib.pyplot as plt\n", 19 | "import sys, os\n", 20 | "sys.path.append(\"../\")\n", 21 | "from models.colorizers import siggraph17\n", 22 | "from models.colorizers.util import *" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": 3, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "colorizer_siggraph17 = siggraph17(pretrained=True).eval()\n", 32 | "_ = colorizer_siggraph17.cuda()" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 4, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "data": { 42 | "text/plain": [ 43 | "['frame00000.jpg', 'frame00001.jpg']" 44 | ] 45 | }, 46 | "execution_count": 4, 47 | "metadata": {}, 48 | "output_type": "execute_result" 49 | } 50 | ], 51 | "source": [ 52 | "VIDEO_NAME = \"yosemite-short\"\n", 53 | "INPUT_PATH = f\"../test/{VIDEO_NAME}-original/\"\n", 54 | "OUTPUT_PATH = f\"../test/{VIDEO_NAME}-colorized/\"\n", 55 | "if not os.path.isdir(OUTPUT_PATH):\n", 56 | " os.makedirs(OUTPUT_PATH)\n", 57 | "\n", 58 | "frames = os.listdir(INPUT_PATH)\n", 59 | "frames.sort()\n", 60 | "frames[:2]" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 5, 66 | "metadata": {}, 67 | "outputs": [], 68 | "source": [ 69 | "def colorize(path):\n", 70 | " img = load_img(path)\n", 71 | " (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256,256))\n", 72 | " tens_l_rs = tens_l_rs.cuda()\n", 73 | " out_img_siggraph17 = postprocess_tens(tens_l_orig, colorizer_siggraph17(tens_l_rs).cpu())\n", 74 | " return out_img_siggraph17" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 6, 80 | "metadata": {}, 81 | "outputs": [ 82 | { 83 | "name": "stderr", 84 | "output_type": "stream", 85 | "text": [ 86 | "/home/tonyx/Utils/anaconda3/envs/instacolor/lib/python3.9/site-packages/torch/nn/functional.py:3631: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. 
See the documentation of nn.Upsample for details.\n", 87 | " warnings.warn(\n" 88 | ] 89 | } 90 | ], 91 | "source": [ 92 | "for frame in frames:\n", 93 | " colorized = colorize(os.path.join(INPUT_PATH, frame))\n", 94 | " plt.imsave(os.path.join(OUTPUT_PATH, frame), colorized)" 95 | ] 96 | } 97 | ], 98 | "metadata": { 99 | "interpreter": { 100 | "hash": "d66f20e7d73dbe1af3ef7c28c682f3f96a708e845fb78d9a29b2de937addae0a" 101 | }, 102 | "kernelspec": { 103 | "display_name": "Python 3.9.7 ('instacolor')", 104 | "language": "python", 105 | "name": "python3" 106 | }, 107 | "language_info": { 108 | "codemirror_mode": { 109 | "name": "ipython", 110 | "version": 3 111 | }, 112 | "file_extension": ".py", 113 | "mimetype": "text/x-python", 114 | "name": "python", 115 | "nbconvert_exporter": "python", 116 | "pygments_lexer": "ipython3", 117 | "version": "3.9.7" 118 | }, 119 | "orig_nbformat": 4 120 | }, 121 | "nbformat": 4, 122 | "nbformat_minor": 2 123 | } 124 | -------------------------------------------------------------------------------- /notebooks/video_prepro.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Video Pre-Processing" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 3, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import os, sys\n", 17 | "import cv2\n", 18 | "import numpy as np\n", 19 | "sys.path.append(\"../\")\n", 20 | "from utils.v2i import convert_video_to_frames, convert_frames_to_video\n", 21 | "from PIL import Image, ImageOps" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "metadata": {}, 27 | "source": [ 28 | "## MP4 to Frames" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 6, 34 | "metadata": {}, 35 | "outputs": [], 36 | "source": [ 37 | "VIDEO_NAME = \"yosemite-short\"\n", 38 | "DIR = \"../test\"\n", 39 | "\n", 40 | "convert_video_to_frames(os.path.join(DIR, VIDEO_NAME + \".mp4\"), os.path.join(DIR, VIDEO_NAME + \"-original\"))\n", 41 | "\n", 42 | "frames = sorted(os.listdir(os.path.join(DIR, VIDEO_NAME + \"-original\")))\n", 43 | "\n", 44 | "for frame in frames:\n", 45 | " f = Image.open(os.path.join(DIR, VIDEO_NAME + \"-original\", frame))\n", 46 | " gray_image = ImageOps.grayscale(f)\n", 47 | " gray_image.save(os.path.join(DIR, VIDEO_NAME + \"-original\", frame))\n" 48 | ] 49 | } 50 | ], 51 | "metadata": { 52 | "interpreter": { 53 | "hash": "189fd69765f0fd89b2713b9afa90d91317e65ce7ee286ac845775eb9093f307d" 54 | }, 55 | "kernelspec": { 56 | "display_name": "Python 3.9.7 ('dvp')", 57 | "language": "python", 58 | "name": "python3" 59 | }, 60 | "language_info": { 61 | "codemirror_mode": { 62 | "name": "ipython", 63 | "version": 3 64 | }, 65 | "file_extension": ".py", 66 | "mimetype": "text/x-python", 67 | "name": "python", 68 | "nbconvert_exporter": "python", 69 | "pygments_lexer": "ipython3", 70 | "version": "3.9.7" 71 | }, 72 | "orig_nbformat": 4 73 | }, 74 | "nbformat": 4, 75 | "nbformat_minor": 2 76 | } 77 | -------------------------------------------------------------------------------- /notebooks/webdemo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Colorization Webdemo" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "#### Install streamlit" 15 | ] 16 | }, 17 | { 18 | 
"cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "!pip install streamlit -q" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "metadata": {}, 29 | "source": [ 30 | "#### clone git repository" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "!git clone https://github.com/Vince-Ai/Colorization.git" 40 | ] 41 | }, 42 | { 43 | "cell_type": "code", 44 | "execution_count": null, 45 | "metadata": {}, 46 | "outputs": [], 47 | "source": [ 48 | "%cd Colorization" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "metadata": {}, 54 | "source": [ 55 | "#### Install requirements" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "!pip install -r files/general.txt" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "#### Start server" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": null, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "!streamlit run webdemo.py & npx localtunnel --port 8501" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "#### Click on the link ends with loca.it and enjoy!" 88 | ] 89 | } 90 | ], 91 | "metadata": { 92 | "language_info": { 93 | "name": "python" 94 | }, 95 | "orig_nbformat": 4 96 | }, 97 | "nbformat": 4, 98 | "nbformat_minor": 2 99 | } 100 | -------------------------------------------------------------------------------- /sample.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wensi-ai/Colorizer/ef382f44885f9d083c02a83952db9b4a6b4149c1/sample.mp4 -------------------------------------------------------------------------------- /utils/color_format.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def rgb2xyz(rgb): # rgb from [0,1] 5 | # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423], 6 | # [0.212671, 0.715160, 0.072169], 7 | # [0.019334, 0.119193, 0.950227]]) 8 | 9 | mask = (rgb > .04045).type(torch.FloatTensor) 10 | if(rgb.is_cuda): 11 | mask = mask.cuda() 12 | 13 | rgb = (((rgb+.055)/1.055)**2.4)*mask + rgb/12.92*(1-mask) 14 | 15 | x = .412453*rgb[:,0,:,:]+.357580*rgb[:,1,:,:]+.180423*rgb[:,2,:,:] 16 | y = .212671*rgb[:,0,:,:]+.715160*rgb[:,1,:,:]+.072169*rgb[:,2,:,:] 17 | z = .019334*rgb[:,0,:,:]+.119193*rgb[:,1,:,:]+.950227*rgb[:,2,:,:] 18 | 19 | out = torch.cat((x[:,None,:,:],y[:,None,:,:],z[:,None,:,:]),dim=1) 20 | return out 21 | 22 | 23 | def xyz2rgb(xyz): 24 | # array([[ 3.24048134, -1.53715152, -0.49853633], 25 | # [-0.96925495, 1.87599 , 0.04155593], 26 | # [ 0.05564664, -0.20404134, 1.05731107]]) 27 | 28 | r = 3.24048134*xyz[:,0,:,:]-1.53715152*xyz[:,1,:,:]-0.49853633*xyz[:,2,:,:] 29 | g = -0.96925495*xyz[:,0,:,:]+1.87599*xyz[:,1,:,:]+.04155593*xyz[:,2,:,:] 30 | b = .05564664*xyz[:,0,:,:]-.20404134*xyz[:,1,:,:]+1.05731107*xyz[:,2,:,:] 31 | 32 | rgb = torch.cat((r[:,None,:,:],g[:,None,:,:],b[:,None,:,:]),dim=1) 33 | rgb = torch.max(rgb,torch.zeros_like(rgb)) # sometimes reaches a small negative number, which causes NaNs 34 | 35 | mask = (rgb > .0031308).type(torch.FloatTensor) 36 | if(rgb.is_cuda): 37 | mask = mask.cuda() 38 | 39 | rgb = (1.055*(rgb**(1./2.4)) - 0.055)*mask + 12.92*rgb*(1-mask) 40 | return rgb 41 | 42 | 43 | def 
xyz2lab(xyz): 44 | # 0.95047, 1., 1.08883 # white 45 | sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None] 46 | if(xyz.is_cuda): 47 | sc = sc.cuda() 48 | 49 | xyz_scale = xyz/sc 50 | 51 | mask = (xyz_scale > .008856).type(torch.FloatTensor) 52 | if(xyz_scale.is_cuda): 53 | mask = mask.cuda() 54 | 55 | xyz_int = xyz_scale**(1/3.)*mask + (7.787*xyz_scale + 16./116.)*(1-mask) 56 | 57 | L = 116.*xyz_int[:,1,:,:]-16. 58 | a = 500.*(xyz_int[:,0,:,:]-xyz_int[:,1,:,:]) 59 | b = 200.*(xyz_int[:,1,:,:]-xyz_int[:,2,:,:]) 60 | 61 | out = torch.cat((L[:,None,:,:],a[:,None,:,:],b[:,None,:,:]),dim=1) 62 | return out 63 | 64 | 65 | def lab2xyz(lab): 66 | y_int = (lab[:,0,:,:]+16.)/116. 67 | x_int = (lab[:,1,:,:]/500.) + y_int 68 | z_int = y_int - (lab[:,2,:,:]/200.) 69 | if(z_int.is_cuda): 70 | z_int = torch.max(torch.Tensor((0,)).cuda(), z_int) 71 | else: 72 | z_int = torch.max(torch.Tensor((0,)), z_int) 73 | 74 | out = torch.cat((x_int[:,None,:,:],y_int[:,None,:,:],z_int[:,None,:,:]),dim=1) 75 | mask = (out > .2068966).type(torch.FloatTensor) 76 | if(out.is_cuda): 77 | mask = mask.cuda() 78 | 79 | out = (out**3.)*mask + (out - 16./116.)/7.787*(1-mask) 80 | 81 | sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None] 82 | sc = sc.to(out.device) 83 | 84 | out = out*sc 85 | return out 86 | 87 | 88 | def rgb2lab(rgb, opt): 89 | lab = xyz2lab(rgb2xyz(rgb)) 90 | l_rs = (lab[:,[0],:,:]-opt.l_cent)/opt.l_norm 91 | ab_rs = lab[:,1:,:,:]/opt.ab_norm 92 | out = torch.cat((l_rs,ab_rs),dim=1) 93 | return out 94 | 95 | 96 | def lab2rgb(lab_rs, opt): 97 | l = lab_rs[:,[0],:,:]*opt.l_norm + opt.l_cent 98 | ab = lab_rs[:,1:,:,:]*opt.ab_norm 99 | lab = torch.cat((l,ab),dim=1) 100 | out = xyz2rgb(lab2xyz(lab)) 101 | return out -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mse_loss(input, target=0): 4 | return torch.mean((input - target) ** 2) 5 | 6 | 7 | def l1_loss(input, target=0): 8 | return torch.mean(torch.abs(input - target)) 9 | 10 | 11 | def weighted_mse_loss(input, target, weights): 12 | out = (input - target) ** 2 13 | out = out * weights.expand_as(out) 14 | return out.mean() 15 | 16 | 17 | def weighted_l1_loss(input, target, weights): 18 | out = torch.abs(input - target) 19 | out = out * weights.expand_as(out) 20 | return out.mean() -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from math import exp 6 | import lpips 7 | 8 | def PSNR(img1, img2): 9 | SE_map = (1. * img1 - img2) ** 2 10 | cur_MSE = torch.mean(SE_map) 11 | return float(20 * torch.log10(1. 
/ torch.sqrt(cur_MSE))) 12 | 13 | 14 | def SSIM(img1, img2, window_size=11, size_average=True): 15 | (_, channel, _, _) = img1.size() 16 | 17 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 /float(2 * 1.5 ** 2)) for x in range(window_size)]) 18 | gauss = gauss / gauss.sum() 19 | _1D_window = gauss.unsqueeze(1) 20 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 21 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 22 | 23 | if img1.is_cuda: 24 | window = window.cuda(img1.get_device()) 25 | window = window.type_as(img1) 26 | 27 | mu1 = F.conv2d(img1, window, padding = window_size // 2, groups = channel) 28 | mu2 = F.conv2d(img2, window, padding = window_size // 2, groups = channel) 29 | 30 | mu1_sq = mu1.pow(2) 31 | mu2_sq = mu2.pow(2) 32 | mu1_mu2 = mu1*mu2 33 | 34 | sigma1_sq = F.conv2d(img1 * img1, window, padding = window_size // 2, groups = channel) - mu1_sq 35 | sigma2_sq = F.conv2d(img2 * img2, window, padding = window_size // 2, groups = channel) - mu2_sq 36 | sigma12 = F.conv2d(img1 * img2, window, padding = window_size // 2, groups = channel) - mu1_mu2 37 | 38 | C1 = 0.01 ** 2 39 | C2 = 0.03 ** 2 40 | 41 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 42 | 43 | if size_average: 44 | return float(ssim_map.mean()) 45 | else: 46 | return ssim_map.mean(1).mean(1).mean(1) 47 | 48 | 49 | def LPIPS(img1, img2): 50 | # Normalize to [-1, 1] 51 | img1 = 2 * (img1 / 255.) - 1 52 | img2 = 2 * (img2 / 255.) - 1 53 | loss_fn = lpips.LPIPS(net='alex') 54 | result = float(torch.mean(loss_fn.forward(img1, img2))) 55 | return result 56 | 57 | 58 | def PCVC(img1, img2): 59 | pass 60 | 61 | 62 | def colorfulness(imgs): 63 | """ 64 | according to the paper: Measuring colourfulness in natural images 65 | input is batches of ab tensors in lab space 66 | """ 67 | N, C, H, W = imgs.shape 68 | a = imgs[:, 0:1, :, :] 69 | b = imgs[:, 1:2, :, :] 70 | 71 | a = a.view(N, -1) 72 | b = b.view(N, -1) 73 | 74 | sigma_a = torch.std(a, dim=-1) 75 | sigma_b = torch.std(b, dim=-1) 76 | 77 | mean_a = torch.mean(a, dim=-1) 78 | mean_b = torch.mean(b, dim=-1) 79 | 80 | return torch.sqrt(sigma_a ** 2 + sigma_b ** 2) + 0.37 * torch.sqrt(mean_a ** 2 + mean_b ** 2) 81 | 82 | 83 | def cosine_similarity(img1, img2): 84 | input_norm = torch.norm(img1, 2, 1, keepdim=True) + sys.float_info.epsilon 85 | target_norm = torch.norm(img2, 2, 1, keepdim=True) + sys.float_info.epsilon 86 | normalized_input = torch.div(img1, input_norm) 87 | normalized_target = torch.div(img2, target_norm) 88 | cos_similarity = torch.mul(normalized_input, normalized_target) 89 | return float(torch.mean(cos_similarity)) -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import requests 4 | import torch 5 | from torchvision import io 6 | import torchvideo.transforms as transforms 7 | 8 | def download_zipfile(url: str, destination: str='.', unzip: bool=True): 9 | response = requests.get(url) 10 | CHUNK_SIZE = 32768 11 | with open(destination, "wb") as f: 12 | for chunk in response.iter_content(CHUNK_SIZE): 13 | if chunk: # filter out keep-alive new chunks 14 | f.write(chunk) 15 | if unzip: 16 | with zipfile.ZipFile(destination, 'r') as zip_ref: 17 | zip_ref.extractall(".") 18 | 19 | def mkdir(dir_path): 20 | if not os.path.exists(dir_path): 21 | 
os.makedirs(dir_path) 22 | 23 | 24 | def apply_metric_to_video(video_path1, video_path2, metrics): 25 | batch_size = 50 26 | 27 | video1, _, _ = io.read_video(video_path1) 28 | video2, _, _ = io.read_video(video_path2) 29 | video1 = video1.cpu().detach().numpy() 30 | transform = transforms.Compose([ 31 | transforms.NDArrayToPILVideo(), 32 | transforms.ResizeVideo((video2.shape[1], video2.shape[2])), 33 | transforms.CollectFrames(), 34 | transforms.PILVideoToTensor(rescale=False, ordering='TCHW')] 35 | ) 36 | length = min(video1.shape[0], video2.shape[0]) 37 | video1 = transform(video1)[:length].float() 38 | video2 = video2.permute((0, 3, 1, 2))[:length].float() 39 | results = [] 40 | for metric in metrics: 41 | cur_metric_result = 0 42 | for i in range(0, video1.shape[0], batch_size): 43 | cur_metric_result += min(video1.shape[0] - i, batch_size) * metric(video1[i: min(video1.shape[0], i + batch_size)], video2[i: min(video1.shape[0], i + batch_size)]) 44 | results.append(cur_metric_result / video1.shape[0]) 45 | return results 46 | -------------------------------------------------------------------------------- /utils/v2i.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from typing import List 4 | import skvideo.io 5 | 6 | def convert_video_to_frames(video_path: str, frames_path: str): 7 | """Convert video input to a set of frames in jpg format""" 8 | if os.path.isdir(frames_path) is False: 9 | os.makedirs(frames_path) 10 | vidcap = cv2.VideoCapture(video_path) 11 | count = 0 12 | # read frames one at a time; each successfully-read frame is written out below 13 | success, image = vidcap.read() 14 | 15 | while success: 16 | cv2.imwrite(f"{frames_path}/frame" + str(count).zfill(5) + ".jpg", image) # save frame as JPEG file 17 | if cv2.waitKey(10) == 27: # exit if Escape is hit 18 | break 19 | success, image = vidcap.read() 20 | count += 1 21 | 22 | 23 | def convert_frames_to_video(frames_path: str, output_video_path: str, images: List[str]=None): 24 | """Convert a set of frames into an mp4 video""" 25 | if images is None: 26 | images = sorted(os.listdir(frames_path)) 27 | frame = cv2.imread(os.path.join(frames_path, images[0])) 28 | height, width, _ = frame.shape 29 | fourcc = cv2.VideoWriter_fourcc(*'MP4V') 30 | video = cv2.VideoWriter("temp.mp4", fourcc, 30.0, (width, height)) 31 | 32 | for image in images: 33 | video.write(cv2.imread(os.path.join(frames_path, image))) 34 | 35 | cv2.destroyAllWindows() 36 | video.release() 37 | 38 | os.system(f"ffmpeg -y -i temp.mp4 -vcodec libx264 {output_video_path}") 39 | os.remove("temp.mp4") -------------------------------------------------------------------------------- /webdemo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append(".") 3 | 4 | import os 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | import streamlit as st 8 | from importlib import import_module 9 | from utils import v2i 10 | from utils.util import apply_metric_to_video, mkdir 11 | from utils.metrics import * 12 | import shutil 13 | 14 | mkdir("test") 15 | 16 | @st.cache 17 | def colorize(model, model_name, opt): 18 | sys.argv = [sys.argv[0]] 19 | model.test("test/frames", f"test/output_{model_name}.mp4", opt) 20 | video = open(f"test/output_{model_name}.mp4", 'rb').read() 21 | return video 22 | 23 | @st.cache 24 | def dvp(model, model_name, opt): 25 | sys.argv = [sys.argv[0]] 26 | model.test("test/frames", "test/results", f"test/output_{model_name}.mp4", 
opt) 27 | video = open(f"test/output_{model_name}.mp4", 'rb').read() 28 | return video 29 | 30 | @st.cache 31 | def metric(video_out): 32 | results = {} 33 | ret = apply_metric_to_video("test/temp.mp4", video_out, [PSNR, SSIM, LPIPS, cosine_similarity]) 34 | results["PSNR"] = ret[0] 35 | results["SSIM"] = ret[1] 36 | results["LPIPS"] = ret[2] 37 | results["cosine_similarity"] = ret[3] 38 | return results 39 | 40 | st.title("Video Colorization Web Demo") 41 | 42 | # Add a selectbox to the sidebar: 43 | model_names = st.sidebar.multiselect( 44 | 'Select colorization model(s):', 45 | ('DEVC', 'InstaColor', 'Colorful') 46 | ) 47 | 48 | # Add a slider to the sidebar: 49 | input_method = st.sidebar.selectbox( 50 | 'Select video input:', 51 | ("sample video", "random from dataset", "upload local video") 52 | ) 53 | # get input video 54 | if input_method == "upload local video": 55 | video_orig = st.sidebar.file_uploader("Choose a file") 56 | if video_orig: 57 | video_orig = video_orig.getvalue() 58 | elif input_method == "sample video": 59 | video_orig = open('sample.mp4', 'rb').read() 60 | else: 61 | video_orig = None 62 | 63 | # display video 64 | st.subheader("Original video:") 65 | st.video(video_orig, format="video/mp4", start_time=0) 66 | st.subheader("Colorized video:") 67 | 68 | # setup metric 69 | psnr, ssim, lpips, cs = [], [], [], [] 70 | 71 | if st.sidebar.button('Colorize'): 72 | with open("test/temp.mp4", "wb") as f: 73 | f.write(video_orig) 74 | video_in = v2i.convert_video_to_frames("test/temp.mp4", "./test/frames") 75 | for i, model_name in enumerate(model_names): 76 | # import model 77 | model_module = import_module(f"models.{model_name}") 78 | Model = getattr(model_module, model_name) 79 | # import and parse test options 80 | try: 81 | opt_module = import_module(f"models.{model_name}.options.test_options") 82 | TestOptions = getattr(opt_module, "TestOptions") 83 | opt = TestOptions().parse() 84 | except AttributeError: # no option needed 85 | opt = None 86 | # run test on model 87 | model = Model() 88 | if model_name == "DVP": 89 | video_out = dvp(model, model_name, opt) 90 | else: 91 | video_out = colorize(model, model_name, opt) 92 | st.write(model_name + ": ") 93 | st.video(video_out, format="video/mp4", start_time=0) 94 | print("video displayed") 95 | # get metrics 96 | results = metric(f"test/output_{model_name}.mp4") 97 | psnr.append(results["PSNR"]) 98 | ssim.append(results["SSIM"]) 99 | lpips.append(results["LPIPS"]) 100 | cs.append(results["cosine_similarity"]) 101 | print(f"{model_name} metric complete!") 102 | os.remove("test/temp.mp4") 103 | 104 | shutil.rmtree("test/frames", ignore_errors=True) 105 | shutil.rmtree("test/results", ignore_errors=True) # "test/results" is only created by the DVP pipeline 106 | 107 | # display metrics 108 | st.subheader("Metrics:") 109 | idx = model_names 110 | fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(20, 30)) 111 | # plt.subplots_adjust(wspace=1, hspace=1) 112 | ax1.bar(idx, psnr, color=['grey' if (x < max(psnr)) else 'red' for x in psnr ], width=0.4) 113 | ax2.bar(idx, ssim, color=['grey' if (x < max(ssim)) else 'red' for x in ssim ], width=0.4) 114 | ax3.bar(idx, lpips, color=['grey' if (x > min(lpips)) else 'red' for x in lpips ], width=0.4) 115 | ax4.bar(idx, cs, color=['grey' if (x < max(cs)) else 'red' for x in cs ], width=0.4) 116 | ax1.set_title('PSNR (higher is better)') 117 | ax2.set_title('SSIM (higher is better)') 118 | ax3.set_title('LPIPS (lower is better)') 119 | ax4.set_title('Cosine Similarity (higher is better)') 120 | st.pyplot(fig)
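121 | 122 | # Local usage: `streamlit run webdemo.py` serves this page; notebooks/webdemo.ipynb runs the same command and exposes port 8501 through `npx localtunnel` so the demo can be reached from Colab.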
--------------------------------------------------------------------------------
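
For quick reference, a minimal end-to-end sketch of how the pieces above fit together outside the web demo, mirroring notebooks/DEVC.ipynb: frames are extracted with utils/v2i.py, a model's test() call writes the colorized video, and utils/util.py scores it against the input. It assumes the repository root is the working directory; the test/input.mp4 and test/output.mp4 paths are placeholders.

    import sys
    sys.path.append(".")  # repository root on the import path

    from models.DEVC import DEVC
    from models.DEVC.options.test_options import TestOptions
    from utils import v2i
    from utils.util import apply_metric_to_video
    from utils.metrics import PSNR, SSIM, LPIPS, cosine_similarity

    sys.argv = [sys.argv[0]]             # keep TestOptions' argparse from seeing stray CLI args
    opt = TestOptions().parse()

    # split the (grayscale) input video into frames, as in notebooks/DEVC.ipynb
    v2i.convert_video_to_frames("test/input.mp4", "test/frames")

    # colorize the frame folder and assemble the output video
    model = DEVC(pretrained=True)
    model.test("test/frames/", "test/output.mp4", opt)

    # per-metric averages over all frames, in the order the metrics are passed
    scores = apply_metric_to_video("test/input.mp4", "test/output.mp4",
                                   [PSNR, SSIM, LPIPS, cosine_similarity])
    print(scores)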