├── .gitignore
├── LICENSE
├── README.md
├── files
│   ├── DECV_requirements.txt
│   ├── DVPEnv.yml
│   ├── general.txt
│   ├── instaColorEnv.yml
│   └── instaColorInstallDep.sh
├── models
│   ├── Colorful
│   │   ├── __init__.py
│   │   ├── base_color.py
│   │   ├── eccv16.py
│   │   ├── options
│   │   │   └── test_options.py
│   │   ├── siggraph17.py
│   │   └── util.py
│   ├── DEVC
│   │   ├── __init__.py
│   │   ├── models
│   │   │   ├── ColorVidNet.py
│   │   │   ├── FrameColor.py
│   │   │   ├── GAN_Models.py
│   │   │   ├── NonlocalNet.py
│   │   │   ├── spectral_normalization.py
│   │   │   └── vgg19_gray.py
│   │   ├── options
│   │   │   └── test_options.py
│   │   └── utils
│   │       ├── lib
│   │       │   ├── functional.py
│   │       │   └── test_transforms.py
│   │       ├── loss.py
│   │       ├── util.py
│   │       └── util_distortion.py
│   ├── DVP
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── environment.yml
│   │   ├── main_IRT.py
│   │   ├── models
│   │   │   ├── __pycache__
│   │   │   │   ├── network.cpython-36.pyc
│   │   │   │   └── network.cpython-39.pyc
│   │   │   └── network.py
│   │   ├── options
│   │   │   └── test_options.py
│   │   ├── test.sh
│   │   └── vgg.py
│   └── InstaColor
│       ├── __init__.py
│       ├── data
│       │   ├── __init__.py
│       │   ├── aligned_dataset.py
│       │   ├── base_data_loader.py
│       │   ├── base_dataset.py
│       │   ├── color_dataset.py
│       │   ├── image_folder.py
│       │   └── single_dataset.py
│       ├── models
│       │   ├── __init__.py
│       │   ├── base_model.py
│       │   ├── fusion_model.py
│       │   └── networks.py
│       ├── options
│       │   ├── base_options.py
│       │   └── test_options.py
│       └── utils
│           ├── datasets.py
│           ├── download.py
│           ├── image_utils.py
│           └── util.py
├── notebooks
│   ├── DEVC.ipynb
│   ├── DVP.ipynb
│   ├── InstaColor.ipynb
│   ├── colorizer.ipynb
│   ├── video_prepro.ipynb
│   └── webdemo.ipynb
├── sample.mp4
├── utils
│   ├── color_format.py
│   ├── loss.py
│   ├── metrics.py
│   ├── util.py
│   └── v2i.py
└── webdemo.py
/.gitignore:
--------------------------------------------------------------------------------
1 | test
2 |
3 | notebooks/.ipynb_checkpoints
4 | notebooks/results
5 | notebooks/checkpoints*
6 |
7 | **/__pycache__
8 | *.zip
9 | checkpoints
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Vince Ai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Colorizer
2 | A collection of colorization models
3 |
4 | Models currently included:
5 |
6 | - DEVC
7 |
8 | - DVP
9 |
10 | - InstaColor
11 |
12 | - Colorful
13 |
14 | DVP Colab Demo:
15 | [Open the demo folder on Google Drive](https://drive.google.com/drive/folders/1L6_hY35kvL5EkFCIncQPi7uRoOsNdgvn)
16 |
17 | Interactive Web Demo for other models:
18 | [Open webdemo.ipynb in Colab](https://colab.research.google.com/github/Vince-Ai/Colorization/blob/main/notebooks/webdemo.ipynb)
19 |
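20 | Minimal usage sketch (assumes a CUDA GPU, network access to download the pretrained weights, and placeholder paths):
21 | 
22 | ```python
23 | from models.Colorful import Colorful
24 | 
25 | colorizer = Colorful(pretrained=True)  # fetches the SIGGRAPH'17 checkpoint via torch.utils.model_zoo
26 | colorizer.test("frames_gray/", "colorized.mp4")  # colorizes each frame in the folder, then assembles a video
27 | ```
28 | 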
--------------------------------------------------------------------------------
/files/DECV_requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tqdm>=4.23.0
3 | scipy==1.2
4 | torchsummary>=1.3
5 | matplotlib>=2.2.2
6 | opencv_contrib_python==4.2.0.32
7 | torchvision>=0.3.0
8 | scikit_image>=0.15.0
9 | torchviz>=0.0.1
10 | Pillow>=6.1.0
11 | torch
12 | easydict
13 | prefetch_generator
14 | PyYAML
15 | tensorboard
16 | future
17 | numba
18 | pypng
--------------------------------------------------------------------------------
/files/DVPEnv.yml:
--------------------------------------------------------------------------------
1 | name: dvp
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=4.5=1_gnu
9 | - asttokens=2.0.5=pyhd8ed1ab_0
10 | - backcall=0.2.0=pyh9f0ad1d_0
11 | - backports=1.0=py_2
12 | - backports.functools_lru_cache=1.6.4=pyhd8ed1ab_0
13 | - blas=1.0=mkl
14 | - brotli=1.0.9=he6710b0_2
15 | - bzip2=1.0.8=h7b6447c_0
16 | - ca-certificates=2022.2.1=h06a4308_0
17 | - certifi=2021.10.8=py39h06a4308_2
18 | - cudatoolkit=11.3.1=h2bc3f7f_2
19 | - cycler=0.11.0=pyhd3eb1b0_0
20 | - dbus=1.13.18=hb2f20db_0
21 | - debugpy=1.5.1=py39h295c915_0
22 | - decorator=5.1.1=pyhd8ed1ab_0
23 | - entrypoints=0.4=pyhd8ed1ab_0
24 | - executing=0.8.3=pyhd8ed1ab_0
25 | - expat=2.4.4=h295c915_0
26 | - ffmpeg=4.3=hf484d3e_0
27 | - fontconfig=2.13.1=h6c09931_0
28 | - fonttools=4.25.0=pyhd3eb1b0_0
29 | - freetype=2.11.0=h70c0345_0
30 | - giflib=5.2.1=h7b6447c_0
31 | - glib=2.69.1=h4ff587b_1
32 | - gmp=6.2.1=h2531618_2
33 | - gnutls=3.6.15=he1e5248_0
34 | - gst-plugins-base=1.14.0=h8213a91_2
35 | - gstreamer=1.14.0=h28cd5cc_2
36 | - icu=58.2=he6710b0_3
37 | - imageio=2.16.1=pyhcf75d05_0
38 | - intel-openmp=2021.4.0=h06a4308_3561
39 | - ipykernel=6.9.1=py39hef51801_0
40 | - ipython=8.1.1=py39hf3d152e_0
41 | - jedi=0.18.1=py39hf3d152e_0
42 | - jpeg=9d=h7f8727e_0
43 | - jupyter_client=7.1.2=pyhd8ed1ab_0
44 | - jupyter_core=4.9.2=py39hf3d152e_0
45 | - kiwisolver=1.3.2=py39h295c915_0
46 | - lame=3.100=h7b6447c_0
47 | - lcms2=2.12=h3be6417_0
48 | - ld_impl_linux-64=2.35.1=h7274673_9
49 | - libffi=3.3=he6710b0_2
50 | - libgcc-ng=9.3.0=h5101ec6_17
51 | - libgfortran-ng=7.5.0=ha8ba4b0_17
52 | - libgfortran4=7.5.0=ha8ba4b0_17
53 | - libgomp=9.3.0=h5101ec6_17
54 | - libiconv=1.15=h63c8f33_5
55 | - libidn2=2.3.2=h7f8727e_0
56 | - libpng=1.6.37=hbc83047_0
57 | - libsodium=1.0.18=h36c2ea0_1
58 | - libstdcxx-ng=9.3.0=hd4cf53a_17
59 | - libtasn1=4.16.0=h27cfd23_0
60 | - libtiff=4.2.0=h85742a9_0
61 | - libunistring=0.9.10=h27cfd23_0
62 | - libuuid=1.0.3=h7f8727e_2
63 | - libuv=1.40.0=h7b6447c_0
64 | - libwebp=1.2.2=h55f646e_0
65 | - libwebp-base=1.2.2=h7f8727e_0
66 | - libxcb=1.14=h7b6447c_0
67 | - libxml2=2.9.12=h03d6c58_0
68 | - lz4-c=1.9.3=h295c915_1
69 | - matplotlib=3.5.1=py39h06a4308_0
70 | - matplotlib-base=3.5.1=py39ha18d171_0
71 | - matplotlib-inline=0.1.3=pyhd8ed1ab_0
72 | - mkl=2021.4.0=h06a4308_640
73 | - mkl-service=2.4.0=py39h7f8727e_0
74 | - mkl_fft=1.3.1=py39hd3c417c_0
75 | - mkl_random=1.2.2=py39h51133e4_0
76 | - munkres=1.1.4=py_0
77 | - ncurses=6.3=h7f8727e_2
78 | - nest-asyncio=1.5.4=pyhd8ed1ab_0
79 | - nettle=3.7.3=hbbd107a_1
80 | - numpy=1.21.2=py39h20f2e39_0
81 | - numpy-base=1.21.2=py39h79a1101_0
82 | - openh264=2.1.1=h4ff587b_0
83 | - openssl=1.1.1m=h7f8727e_0
84 | - packaging=21.3=pyhd3eb1b0_0
85 | - parso=0.8.3=pyhd8ed1ab_0
86 | - pcre=8.45=h295c915_0
87 | - pexpect=4.8.0=pyh9f0ad1d_2
88 | - pickleshare=0.7.5=py_1003
89 | - pillow=9.0.1=py39h22f2fdc_0
90 | - pip=21.2.4=py39h06a4308_0
91 | - prompt-toolkit=3.0.27=pyha770c72_0
92 | - ptyprocess=0.7.0=pyhd3deb0d_0
93 | - pure_eval=0.2.2=pyhd8ed1ab_0
94 | - pygments=2.11.2=pyhd8ed1ab_0
95 | - pyparsing=3.0.4=pyhd3eb1b0_0
96 | - pyqt=5.9.2=py39h2531618_6
97 | - python=3.9.7=h12debd9_1
98 | - python-dateutil=2.8.2=pyhd8ed1ab_0
99 | - python_abi=3.9=2_cp39
100 | - pytorch=1.10.2=py3.9_cuda11.3_cudnn8.2.0_0
101 | - pytorch-mutex=1.0=cuda
102 | - pyzmq=19.0.2=py39hb69f2a1_2
103 | - qt=5.9.7=h5867ecd_1
104 | - readline=8.1.2=h7f8727e_1
105 | - scipy=1.7.3=py39hc147768_0
106 | - setuptools=58.0.4=py39h06a4308_0
107 | - sip=4.19.13=py39h295c915_0
108 | - six=1.16.0=pyhd3eb1b0_1
109 | - sqlite=3.37.2=hc218d9a_0
110 | - stack_data=0.2.0=pyhd8ed1ab_0
111 | - tk=8.6.11=h1ccaba5_0
112 | - torchaudio=0.10.2=py39_cu113
113 | - torchvision=0.11.3=py39_cu113
114 | - tornado=6.1=py39h3811e60_1
115 | - traitlets=5.1.1=pyhd8ed1ab_0
116 | - typing_extensions=3.10.0.2=pyh06a4308_0
117 | - tzdata=2021e=hda174b7_0
118 | - wcwidth=0.2.5=pyh9f0ad1d_2
119 | - wheel=0.37.1=pyhd3eb1b0_0
120 | - xz=5.2.5=h7b6447c_0
121 | - zeromq=4.3.4=h9c3ff4c_0
122 | - zlib=1.2.11=h7f8727e_4
123 | - zstd=1.4.9=haebb681_0
124 | - pip:
125 | - opencv-python==4.5.5.62
126 | prefix: /home/tonyx/Utils/anaconda3/envs/dvp
127 |
--------------------------------------------------------------------------------
/files/general.txt:
--------------------------------------------------------------------------------
1 | av
2 | torchvideo
3 | pillow-simd
4 | lpips
--------------------------------------------------------------------------------
/files/instaColorEnv.yml:
--------------------------------------------------------------------------------
1 | name: instacolorization
2 | channels:
3 | - conda-forge
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - backcall=0.1.0=py37_0
8 | - blas=1.0=mkl
9 | - ca-certificates=2019.11.28=hecc5488_0
10 | - certifi=2019.11.28=py37_0
11 | - cloudpickle=1.3.0=py_0
12 | - cudatoolkit=10.1.243=h6bb024c_0
13 | - cycler=0.10.0=py37_0
14 | - cython=0.29.15=py37he6710b0_0
15 | - cytoolz=0.10.1=py37h7b6447c_0
16 | - dask-core=2.10.1=py_0
17 | - dbus=1.13.12=h746ee38_0
18 | - decorator=4.4.1=py_0
19 | - expat=2.2.6=he6710b0_0
20 | - fontconfig=2.13.0=h9420a91_0
21 | - freetype=2.9.1=h8a8886c_1
22 | - glib=2.63.1=h5a9c865_0
23 | - gst-plugins-base=1.14.0=hbbd80ab_1
24 | - gstreamer=1.14.0=hb453b48_1
25 | - icu=58.2=h9c2bf20_1
26 | - imageio=2.6.1=py37_0
27 | - intel-openmp=2020.0=166
28 | - ipython=7.12.0=py37h5ca1d4c_0
29 | - ipython_genutils=0.2.0=py37_0
30 | - jedi=0.16.0=py37_0
31 | - joblib=0.14.1=py_0
32 | - jpeg=9b=h024ee3a_2
33 | - kiwisolver=1.1.0=py37he6710b0_0
34 | - ld_impl_linux-64=2.33.1=h53a641e_7
35 | - libedit=3.1.20181209=hc058e9b_0
36 | - libffi=3.2.1=hd88cf55_4
37 | - libgcc-ng=9.1.0=hdf63c60_0
38 | - libgfortran-ng=7.3.0=hdf63c60_0
39 | - libpng=1.6.37=hbc83047_0
40 | - libstdcxx-ng=9.1.0=hdf63c60_0
41 | - libtiff=4.1.0=h2733197_0
42 | - libuuid=1.0.3=h1bed415_2
43 | - libxcb=1.13=h1bed415_1
44 | - libxml2=2.9.9=hea5a465_1
45 | - matplotlib=3.1.3=py37_0
46 | - matplotlib-base=3.1.3=py37hef1b27d_0
47 | - mkl=2020.0=166
48 | - mkl-service=2.3.0=py37he904b0f_0
49 | - mkl_fft=1.0.15=py37ha843d7b_0
50 | - mkl_random=1.1.0=py37hd6b4f25_0
51 | - ncurses=6.1=he6710b0_1
52 | - networkx=2.4=py_0
53 | - ninja=1.9.0=py37hfd86e86_0
54 | - numpy=1.18.1=py37h4f9e942_0
55 | - numpy-base=1.18.1=py37hde5b4d6_1
56 | - olefile=0.46=py37_0
57 | - openssl=1.1.1d=h516909a_0
58 | - parso=0.6.1=py_0
59 | - pcre=8.43=he6710b0_0
60 | - pexpect=4.8.0=py37_0
61 | - pickleshare=0.7.5=py37_0
62 | - pillow=7.0.0=py37hb39fc2d_0
63 | - pip=20.0.2=py37_1
64 | - prompt_toolkit=3.0.3=py_0
65 | - ptyprocess=0.6.0=py37_0
66 | - pygments=2.5.2=py_0
67 | - pyparsing=2.4.6=py_0
68 | - pyqt=5.9.2=py37h05f1152_2
69 | - python=3.7.6=h0371630_2
70 | - python-dateutil=2.8.1=py_0
71 | - pywavelets=1.1.1=py37h7b6447c_0
72 | - qt=5.9.7=h5867ecd_1
73 | - readline=7.0=h7b6447c_5
74 | - scikit-image=0.16.2=py37h0573a6f_0
75 | - scikit-learn=0.22.1=py37hd81dba3_0
76 | - scipy=1.4.1=py37h0b6359f_0
77 | - setuptools=45.2.0=py37_0
78 | - sip=4.19.8=py37hf484d3e_0
79 | - six=1.14.0=py37_0
80 | - sqlite=3.31.1=h7b6447c_0
81 | - tk=8.6.8=hbc83047_0
82 | - toolz=0.10.0=py_0
83 | - tornado=6.0.3=py37h7b6447c_3
84 | - tqdm=4.43.0=py_0
85 | - traitlets=4.3.3=py37_0
86 | - wcwidth=0.1.8=py_0
87 | - wheel=0.34.2=py37_0
88 | - xz=5.2.4=h14c3975_4
89 | - zlib=1.2.11=h7b6447c_3
90 | - zstd=1.3.7=h0b5b093_0
91 | - pip:
92 | - absl-py==0.9.0
93 | - cachetools==4.1.0
94 | - chardet==3.0.4
95 | - future==0.18.2
96 | - fvcore==0.1.dev200506
97 | - google-auth==1.14.2
98 | - google-auth-oauthlib==0.4.1
99 | - grpcio==1.28.1
100 | - idna==2.9
101 | - importlib-metadata==1.6.0
102 | - jsonpatch==1.25
103 | - jsonpointer==2.0
104 | - markdown==3.2.2
105 | - mock==4.0.2
106 | - oauthlib==3.1.0
107 | - opencv-python==4.2.0.32
108 | - portalocker==1.7.0
109 | - protobuf==3.11.3
110 | - pyasn1==0.4.8
111 | - pyasn1-modules==0.2.8
112 | - pydot==1.4.1
113 | - pyzmq==18.1.1
114 | - requests==2.23.0
115 | - requests-oauthlib==1.3.0
116 | - rsa==4.0
117 | - tabulate==0.8.7
118 | - tensorboard==2.2.1
119 | - tensorboard-plugin-wit==1.6.0.post3
120 | - termcolor==1.1.0
121 | - urllib3==1.25.8
122 | - visdom==0.1.8.9
123 | - websocket-client==0.57.0
124 | - werkzeug==1.0.1
125 | - yacs==0.1.7
126 | - zipp==3.1.0
127 | prefix: /home/user/miniconda3/envs/py37_pt14
128 |
129 |
--------------------------------------------------------------------------------
/files/instaColorInstallDep.sh:
--------------------------------------------------------------------------------
1 | pip install -U torch==1.5 torchvision==0.6 -f https://download.pytorch.org/whl/cu101/torch_stable.html
2 | pip install cython pyyaml==5.1
3 | pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
4 | pip install dominate==2.4.0
5 | pip install detectron2==0.1.2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/index.html
6 |
--------------------------------------------------------------------------------
/models/Colorful/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .base_color import *
3 | from .eccv16 import *
4 | from .siggraph17 import *
5 | from .util import *
6 |
7 | import os
8 | import shutil
9 | import cv2
10 | import torch
11 | from PIL import Image
12 | from tqdm import tqdm
13 | import numpy as np
14 | import torchvision.transforms as transform_lib
15 | import matplotlib.pyplot as plt
16 |
17 | from utils.util import download_zipfile, mkdir
18 | from utils.v2i import convert_frames_to_video
19 |
20 | class OPT():
21 | pass
22 |
23 | class Colorful():
24 | def __init__(self, pretrained=True):
25 |         self.model = siggraph17(pretrained=pretrained).cuda().eval()
26 | self.opt = OPT()
27 | self.opt.output_frame_path = "./test/results"
28 |
29 | def test(self, input_path, output_path, opt=None):
30 |
31 | if not os.path.isdir(self.opt.output_frame_path):
32 | os.makedirs(self.opt.output_frame_path)
33 |
34 | frames = os.listdir(input_path)
35 | frames.sort()
36 | for frame in frames:
37 | colorized = self.colorize(os.path.join(input_path, frame))
38 | plt.imsave(os.path.join(self.opt.output_frame_path, frame), colorized)
39 |
40 | convert_frames_to_video(self.opt.output_frame_path, output_path)
41 |
42 |
43 | def colorize(self, path):
44 | img = load_img(path)
45 | (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256,256))
46 | tens_l_rs = tens_l_rs.cuda()
47 | out_img_siggraph17 = postprocess_tens(tens_l_orig, self.model(tens_l_rs).cpu())
48 | return out_img_siggraph17
--------------------------------------------------------------------------------
/models/Colorful/base_color.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | from torch import nn
4 |
5 | class BaseColor(nn.Module):
6 | def __init__(self):
7 | super(BaseColor, self).__init__()
8 |
9 |         self.l_cent = 50.   # center of the Lab L channel (L lies in [0, 100])
10 |         self.l_norm = 100.  # scale for the L channel
11 |         self.ab_norm = 110. # scale for the ab channels
12 |
13 | def normalize_l(self, in_l):
14 | return (in_l-self.l_cent)/self.l_norm
15 |
16 | def unnormalize_l(self, in_l):
17 | return in_l*self.l_norm + self.l_cent
18 |
19 | def normalize_ab(self, in_ab):
20 | return in_ab/self.ab_norm
21 |
22 | def unnormalize_ab(self, in_ab):
23 | return in_ab*self.ab_norm
24 |
25 |
--------------------------------------------------------------------------------
/models/Colorful/eccv16.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | from IPython import embed
6 |
7 | from .base_color import *
8 |
9 | class ECCVGenerator(BaseColor):
10 | def __init__(self, norm_layer=nn.BatchNorm2d):
11 | super(ECCVGenerator, self).__init__()
12 |
13 | model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
14 | model1+=[nn.ReLU(True),]
15 | model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
16 | model1+=[nn.ReLU(True),]
17 | model1+=[norm_layer(64),]
18 |
19 | model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
20 | model2+=[nn.ReLU(True),]
21 | model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
22 | model2+=[nn.ReLU(True),]
23 | model2+=[norm_layer(128),]
24 |
25 | model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
26 | model3+=[nn.ReLU(True),]
27 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
28 | model3+=[nn.ReLU(True),]
29 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
30 | model3+=[nn.ReLU(True),]
31 | model3+=[norm_layer(256),]
32 |
33 | model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
34 | model4+=[nn.ReLU(True),]
35 | model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
36 | model4+=[nn.ReLU(True),]
37 | model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
38 | model4+=[nn.ReLU(True),]
39 | model4+=[norm_layer(512),]
40 |
41 | model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
42 | model5+=[nn.ReLU(True),]
43 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
44 | model5+=[nn.ReLU(True),]
45 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
46 | model5+=[nn.ReLU(True),]
47 | model5+=[norm_layer(512),]
48 |
49 | model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
50 | model6+=[nn.ReLU(True),]
51 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
52 | model6+=[nn.ReLU(True),]
53 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
54 | model6+=[nn.ReLU(True),]
55 | model6+=[norm_layer(512),]
56 |
57 | model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
58 | model7+=[nn.ReLU(True),]
59 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
60 | model7+=[nn.ReLU(True),]
61 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
62 | model7+=[nn.ReLU(True),]
63 | model7+=[norm_layer(512),]
64 |
65 | model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
66 | model8+=[nn.ReLU(True),]
67 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
68 | model8+=[nn.ReLU(True),]
69 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
70 | model8+=[nn.ReLU(True),]
71 |
72 | model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]
73 |
74 | self.model1 = nn.Sequential(*model1)
75 | self.model2 = nn.Sequential(*model2)
76 | self.model3 = nn.Sequential(*model3)
77 | self.model4 = nn.Sequential(*model4)
78 | self.model5 = nn.Sequential(*model5)
79 | self.model6 = nn.Sequential(*model6)
80 | self.model7 = nn.Sequential(*model7)
81 | self.model8 = nn.Sequential(*model8)
82 |
83 | self.softmax = nn.Softmax(dim=1)
84 | self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
85 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')
86 |
87 | def forward(self, input_l):
88 | conv1_2 = self.model1(self.normalize_l(input_l))
89 | conv2_2 = self.model2(conv1_2)
90 | conv3_3 = self.model3(conv2_2)
91 | conv4_3 = self.model4(conv3_3)
92 | conv5_3 = self.model5(conv4_3)
93 | conv6_3 = self.model6(conv5_3)
94 | conv7_3 = self.model7(conv6_3)
95 | conv8_3 = self.model8(conv7_3)
96 | out_reg = self.model_out(self.softmax(conv8_3))
97 |
98 | return self.unnormalize_ab(self.upsample4(out_reg))
99 |
100 | def eccv16(pretrained=True):
101 | model = ECCVGenerator()
102 | if(pretrained):
103 | import torch.utils.model_zoo as model_zoo
104 | model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True))
105 | return model
106 |
--------------------------------------------------------------------------------
/models/Colorful/options/test_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.backends.cudnn as cudnn
3 | class TestOptions():
4 | def __init__(self):
5 | pass
6 |
7 | def parse(self):
8 | # initialize parser with basic options
9 | opt = {}
10 | return opt
--------------------------------------------------------------------------------
/models/Colorful/siggraph17.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .base_color import *
5 |
6 | class SIGGRAPHGenerator(BaseColor):
7 | def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
8 | super(SIGGRAPHGenerator, self).__init__()
9 |
10 | # Conv1
11 | model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),]
12 | model1+=[nn.ReLU(True),]
13 | model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),]
14 | model1+=[nn.ReLU(True),]
15 | model1+=[norm_layer(64),]
16 | # add a subsampling operation
17 |
18 | # Conv2
19 | model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
20 | model2+=[nn.ReLU(True),]
21 | model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
22 | model2+=[nn.ReLU(True),]
23 | model2+=[norm_layer(128),]
24 | # add a subsampling layer operation
25 |
26 | # Conv3
27 | model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
28 | model3+=[nn.ReLU(True),]
29 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
30 | model3+=[nn.ReLU(True),]
31 | model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
32 | model3+=[nn.ReLU(True),]
33 | model3+=[norm_layer(256),]
34 | # add a subsampling layer operation
35 |
36 | # Conv4
37 | model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
38 | model4+=[nn.ReLU(True),]
39 | model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
40 | model4+=[nn.ReLU(True),]
41 | model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
42 | model4+=[nn.ReLU(True),]
43 | model4+=[norm_layer(512),]
44 |
45 | # Conv5
46 | model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
47 | model5+=[nn.ReLU(True),]
48 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
49 | model5+=[nn.ReLU(True),]
50 | model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
51 | model5+=[nn.ReLU(True),]
52 | model5+=[norm_layer(512),]
53 |
54 | # Conv6
55 | model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
56 | model6+=[nn.ReLU(True),]
57 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
58 | model6+=[nn.ReLU(True),]
59 | model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
60 | model6+=[nn.ReLU(True),]
61 | model6+=[norm_layer(512),]
62 |
63 | # Conv7
64 | model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
65 | model7+=[nn.ReLU(True),]
66 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
67 | model7+=[nn.ReLU(True),]
68 | model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
69 | model7+=[nn.ReLU(True),]
70 | model7+=[norm_layer(512),]
71 |
72 | # Conv7
73 | model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
74 | model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
75 |
76 | model8=[nn.ReLU(True),]
77 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
78 | model8+=[nn.ReLU(True),]
79 | model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
80 | model8+=[nn.ReLU(True),]
81 | model8+=[norm_layer(256),]
82 |
83 | # Conv9
84 | model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),]
85 | model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
86 | # add the two feature maps above
87 |
88 | model9=[nn.ReLU(True),]
89 | model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
90 | model9+=[nn.ReLU(True),]
91 | model9+=[norm_layer(128),]
92 |
93 | # Conv10
94 | model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),]
95 | model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
96 | # add the two feature maps above
97 |
98 | model10=[nn.ReLU(True),]
99 | model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),]
100 | model10+=[nn.LeakyReLU(negative_slope=.2),]
101 |
102 | # classification output
103 | model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
104 |
105 | # regression output
106 | model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
107 | model_out+=[nn.Tanh()]
108 |
109 | self.model1 = nn.Sequential(*model1)
110 | self.model2 = nn.Sequential(*model2)
111 | self.model3 = nn.Sequential(*model3)
112 | self.model4 = nn.Sequential(*model4)
113 | self.model5 = nn.Sequential(*model5)
114 | self.model6 = nn.Sequential(*model6)
115 | self.model7 = nn.Sequential(*model7)
116 | self.model8up = nn.Sequential(*model8up)
117 | self.model8 = nn.Sequential(*model8)
118 | self.model9up = nn.Sequential(*model9up)
119 | self.model9 = nn.Sequential(*model9)
120 | self.model10up = nn.Sequential(*model10up)
121 | self.model10 = nn.Sequential(*model10)
122 | self.model3short8 = nn.Sequential(*model3short8)
123 | self.model2short9 = nn.Sequential(*model2short9)
124 | self.model1short10 = nn.Sequential(*model1short10)
125 |
126 | self.model_class = nn.Sequential(*model_class)
127 | self.model_out = nn.Sequential(*model_out)
128 |
129 | self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),])
130 | self.softmax = nn.Sequential(*[nn.Softmax(dim=1),])
131 |
132 | def forward(self, input_A, input_B=None, mask_B=None):
133 | if(input_B is None):
134 | input_B = torch.cat((input_A*0, input_A*0), dim=1)
135 | if(mask_B is None):
136 | mask_B = input_A*0
137 |
138 | conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1))
139 | conv2_2 = self.model2(conv1_2[:,:,::2,::2])
140 | conv3_3 = self.model3(conv2_2[:,:,::2,::2])
141 | conv4_3 = self.model4(conv3_3[:,:,::2,::2])
142 | conv5_3 = self.model5(conv4_3)
143 | conv6_3 = self.model6(conv5_3)
144 | conv7_3 = self.model7(conv6_3)
145 |
146 | conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
147 | conv8_3 = self.model8(conv8_up)
148 | conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
149 | conv9_3 = self.model9(conv9_up)
150 | conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
151 | conv10_2 = self.model10(conv10_up)
152 | out_reg = self.model_out(conv10_2)
159 |
160 | return self.unnormalize_ab(out_reg)
161 |
162 | def siggraph17(pretrained=True):
163 | model = SIGGRAPHGenerator()
164 | if(pretrained):
165 | import torch.utils.model_zoo as model_zoo
166 | model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True))
167 | return model
168 |
169 |
--------------------------------------------------------------------------------
/models/Colorful/util.py:
--------------------------------------------------------------------------------
1 |
2 | from PIL import Image
3 | import numpy as np
4 | from skimage import color
5 | import torch
6 | import torch.nn.functional as F
7 | from IPython import embed
8 |
9 | def load_img(img_path):
10 | out_np = np.asarray(Image.open(img_path))
11 | if(out_np.ndim==2):
12 | out_np = np.tile(out_np[:,:,None],3)
13 | return out_np
14 |
15 | def resize_img(img, HW=(256,256), resample=3):
16 | return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))
17 |
18 | def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
19 | # return original size L and resized L as torch Tensors
20 | img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
21 |
22 | img_lab_orig = color.rgb2lab(img_rgb_orig)
23 | img_lab_rs = color.rgb2lab(img_rgb_rs)
24 |
25 | img_l_orig = img_lab_orig[:,:,0]
26 | img_l_rs = img_lab_rs[:,:,0]
27 |
28 | tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:]
29 | tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:]
30 |
31 | return (tens_orig_l, tens_rs_l)
32 |
33 | def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
34 | # tens_orig_l 1 x 1 x H_orig x W_orig
35 | # out_ab 1 x 2 x H x W
36 |
37 | HW_orig = tens_orig_l.shape[2:]
38 | HW = out_ab.shape[2:]
39 |
40 | # call resize function if needed
41 | if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
42 | out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear')
43 | else:
44 | out_ab_orig = out_ab
45 |
46 | out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
47 | return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))
48 |
--------------------------------------------------------------------------------
/models/DEVC/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import cv2
4 | import torch
5 | from PIL import Image
6 | from tqdm import tqdm
7 | import numpy as np
8 | import torchvision.transforms as transform_lib
9 |
10 | from utils.util import download_zipfile, mkdir
11 | from utils.v2i import convert_frames_to_video
12 | import models.DEVC.utils.lib.test_transforms as transforms
13 | from models.DEVC.utils.util_distortion import CenterPad, Normalize, RGB2Lab, ToTensor
14 | from models.DEVC.utils.util import batch_lab2rgb_transpose_mc, save_frames, tensor_lab2rgb, uncenter_l
15 | from models.DEVC.models.ColorVidNet import ColorVidNet
16 | from models.DEVC.models.FrameColor import frame_colorization
17 | from models.DEVC.models.NonlocalNet import VGG19_pytorch, WarpNet
18 |
19 | class DEVC():
20 | def __init__(self, pretrained=True):
21 | self.nonlocal_net = WarpNet(1)
22 | self.colornet = ColorVidNet(7)
23 | self.vggnet = VGG19_pytorch()
24 |
25 | if pretrained is True:
26 | download_zipfile("https://facevc.blob.core.windows.net/zhanbo/old_photo/colorization_checkpoint.zip", "DEVC_checkpoints.zip")
27 | self.vggnet.load_state_dict(torch.load("data/vgg19_conv.pth"))
28 | self.nonlocal_net.load_state_dict(torch.load("checkpoints/video_moredata_l1/nonlocal_net_iter_76000.pth"))
29 | self.colornet.load_state_dict(torch.load("checkpoints/video_moredata_l1/colornet_iter_76000.pth"))
30 |
31 | def test(self, input_path, output_path, opt):
32 | mkdir(opt.output_frame_path)
33 | # parameters for wls filter
34 | wls_filter_on = True
35 | lambda_value = 500
36 | sigma_color = 4
37 |
38 | # net
39 | self.nonlocal_net.eval()
40 | self.colornet.eval()
41 | self.vggnet.eval()
42 | self.nonlocal_net.cuda()
43 | self.colornet.cuda()
44 | self.vggnet.cuda()
45 | for param in self.vggnet.parameters():
46 | param.requires_grad = False
47 |
48 | # processing folders
49 | print("processing the folder:", input_path)
50 | _, _, filenames = os.walk(input_path).__next__()
51 |         filenames.sort(key=lambda f: int("".join(filter(str.isdigit, f)) or -1))  # numeric sort; -1 if the name has no digits
52 |
53 | # NOTE: resize frames to 216*384
54 | transform = transforms.Compose(
55 | [CenterPad(opt.image_size), transform_lib.CenterCrop(opt.image_size), RGB2Lab(), ToTensor(), Normalize()]
56 | )
57 |
58 | # if frame propagation: use the first frame as reference
59 | # otherwise, use the specified reference image
60 | ref_name = os.path.join(input_path , filenames[0]) if opt.frame_propagate else opt.ref_path
61 | print("reference name:", ref_name)
62 | frame_ref = Image.open(ref_name)
63 |
64 | I_last_lab_predict = None
65 |
66 | IB_lab_large = transform(frame_ref).unsqueeze(0).cuda()
67 | IB_lab = torch.nn.functional.interpolate(IB_lab_large, scale_factor=0.5, mode="bilinear")
68 | IB_l = IB_lab[:, 0:1, :, :]
69 | IB_ab = IB_lab[:, 1:3, :, :]
70 | with torch.no_grad():
71 | I_reference_lab = IB_lab
72 | I_reference_l = I_reference_lab[:, 0:1, :, :]
73 | I_reference_ab = I_reference_lab[:, 1:3, :, :]
74 | I_reference_rgb = tensor_lab2rgb(torch.cat((uncenter_l(I_reference_l), I_reference_ab), dim=1))
75 | features_B = self.vggnet(I_reference_rgb, ["r12", "r22", "r32", "r42", "r52"], preprocess=True)
76 |
77 | for index, frame_name in enumerate(tqdm(filenames)):
78 | frame1 = Image.open(os.path.join(input_path, frame_name))
79 | IA_lab_large = transform(frame1).unsqueeze(0).cuda()
80 | IA_lab = torch.nn.functional.interpolate(IA_lab_large, scale_factor=0.5, mode="bilinear")
81 |
82 | IA_l = IA_lab[:, 0:1, :, :]
83 | IA_ab = IA_lab[:, 1:3, :, :]
84 |
85 | if I_last_lab_predict is None:
86 | if opt.frame_propagate:
87 | I_last_lab_predict = IB_lab
88 | else:
89 | I_last_lab_predict = torch.zeros_like(IA_lab).cuda()
90 |
91 | # start the frame colorization
92 | with torch.no_grad():
93 | I_current_lab = IA_lab
94 | I_current_ab_predict, I_current_nonlocal_lab_predict, features_current_gray = frame_colorization(
95 | I_current_lab,
96 | I_reference_lab,
97 | I_last_lab_predict,
98 | features_B,
99 | self.vggnet,
100 | self.nonlocal_net,
101 | self.colornet,
102 | feature_noise=0,
103 | temperature=1e-10,
104 | )
105 | I_last_lab_predict = torch.cat((IA_l, I_current_ab_predict), dim=1)
106 |
107 | # upsampling
108 | curr_bs_l = IA_lab_large[:, 0:1, :, :]
109 | curr_predict = (
110 | torch.nn.functional.interpolate(I_current_ab_predict.data.cpu(), scale_factor=2, mode="bilinear") * 1.25
111 | )
112 |
113 | # filtering
114 | if wls_filter_on:
115 | guide_image = uncenter_l(curr_bs_l) * 255 / 100
116 | wls_filter = cv2.ximgproc.createFastGlobalSmootherFilter(
117 | guide_image[0, 0, :, :].cpu().numpy().astype(np.uint8), lambda_value, sigma_color
118 | )
119 | curr_predict_a = wls_filter.filter(curr_predict[0, 0, :, :].cpu().numpy())
120 | curr_predict_b = wls_filter.filter(curr_predict[0, 1, :, :].cpu().numpy())
121 | curr_predict_a = torch.from_numpy(curr_predict_a).unsqueeze(0).unsqueeze(0)
122 | curr_predict_b = torch.from_numpy(curr_predict_b).unsqueeze(0).unsqueeze(0)
123 | curr_predict_filter = torch.cat((curr_predict_a, curr_predict_b), dim=1)
124 | IA_predict_rgb = batch_lab2rgb_transpose_mc(curr_bs_l[:32], curr_predict_filter[:32, ...])
125 | else:
126 | IA_predict_rgb = batch_lab2rgb_transpose_mc(curr_bs_l[:32], curr_predict[:32, ...])
127 |
128 | # save the frames
129 | save_frames(IA_predict_rgb, opt.output_frame_path, index)
130 |
131 | # output video
132 | convert_frames_to_video(opt.output_frame_path, output_path)
133 |
134 | shutil.rmtree("data")
135 | shutil.rmtree("checkpoints")
136 | shutil.rmtree(opt.output_frame_path)
137 |
138 | print("Task Complete!")
--------------------------------------------------------------------------------
/models/DEVC/models/ColorVidNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.parallel
4 |
5 |
6 | class ColorVidNet(nn.Module):
7 | def __init__(self, ic):
8 | super(ColorVidNet, self).__init__()
9 | self.conv1_1 = nn.Sequential(nn.Conv2d(ic, 32, 3, 1, 1), nn.ReLU(), nn.Conv2d(32, 64, 3, 1, 1))
10 | self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1)
11 | self.conv1_2norm = nn.BatchNorm2d(64, affine=False)
12 | self.conv1_2norm_ss = nn.Conv2d(64, 64, 1, 2, bias=False, groups=64)
13 | self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1)
14 | self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1)
15 | self.conv2_2norm = nn.BatchNorm2d(128, affine=False)
16 | self.conv2_2norm_ss = nn.Conv2d(128, 128, 1, 2, bias=False, groups=128)
17 | self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
18 | self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
19 | self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1)
20 | self.conv3_3norm = nn.BatchNorm2d(256, affine=False)
21 | self.conv3_3norm_ss = nn.Conv2d(256, 256, 1, 2, bias=False, groups=256)
22 | self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
23 | self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
24 | self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1)
25 | self.conv4_3norm = nn.BatchNorm2d(512, affine=False)
26 | self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
27 | self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
28 | self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
29 | self.conv5_3norm = nn.BatchNorm2d(512, affine=False)
30 | self.conv6_1 = nn.Conv2d(512, 512, 3, 1, 2, 2)
31 | self.conv6_2 = nn.Conv2d(512, 512, 3, 1, 2, 2)
32 | self.conv6_3 = nn.Conv2d(512, 512, 3, 1, 2, 2)
33 | self.conv6_3norm = nn.BatchNorm2d(512, affine=False)
34 | self.conv7_1 = nn.Conv2d(512, 512, 3, 1, 1)
35 | self.conv7_2 = nn.Conv2d(512, 512, 3, 1, 1)
36 | self.conv7_3 = nn.Conv2d(512, 512, 3, 1, 1)
37 | self.conv7_3norm = nn.BatchNorm2d(512, affine=False)
38 | self.conv8_1 = nn.ConvTranspose2d(512, 256, 4, 2, 1)
39 | self.conv3_3_short = nn.Conv2d(256, 256, 3, 1, 1)
40 | self.conv8_2 = nn.Conv2d(256, 256, 3, 1, 1)
41 | self.conv8_3 = nn.Conv2d(256, 256, 3, 1, 1)
42 | self.conv8_3norm = nn.BatchNorm2d(256, affine=False)
43 | self.conv9_1 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
44 | self.conv2_2_short = nn.Conv2d(128, 128, 3, 1, 1)
45 | self.conv9_2 = nn.Conv2d(128, 128, 3, 1, 1)
46 | self.conv9_2norm = nn.BatchNorm2d(128, affine=False)
47 | self.conv10_1 = nn.ConvTranspose2d(128, 128, 4, 2, 1)
48 | self.conv1_2_short = nn.Conv2d(64, 128, 3, 1, 1)
49 | self.conv10_2 = nn.Conv2d(128, 128, 3, 1, 1)
50 | self.conv10_ab = nn.Conv2d(128, 2, 1, 1)
51 |
52 | # add self.relux_x
53 | self.relu1_1 = nn.ReLU()
54 | self.relu1_2 = nn.ReLU()
55 | self.relu2_1 = nn.ReLU()
56 | self.relu2_2 = nn.ReLU()
57 | self.relu3_1 = nn.ReLU()
58 | self.relu3_2 = nn.ReLU()
59 | self.relu3_3 = nn.ReLU()
60 | self.relu4_1 = nn.ReLU()
61 | self.relu4_2 = nn.ReLU()
62 | self.relu4_3 = nn.ReLU()
63 | self.relu5_1 = nn.ReLU()
64 | self.relu5_2 = nn.ReLU()
65 | self.relu5_3 = nn.ReLU()
66 | self.relu6_1 = nn.ReLU()
67 | self.relu6_2 = nn.ReLU()
68 | self.relu6_3 = nn.ReLU()
69 | self.relu7_1 = nn.ReLU()
70 | self.relu7_2 = nn.ReLU()
71 | self.relu7_3 = nn.ReLU()
72 | self.relu8_1_comb = nn.ReLU()
73 | self.relu8_2 = nn.ReLU()
74 | self.relu8_3 = nn.ReLU()
75 | self.relu9_1_comb = nn.ReLU()
76 | self.relu9_2 = nn.ReLU()
77 | self.relu10_1_comb = nn.ReLU()
78 | self.relu10_2 = nn.LeakyReLU(0.2, True)
79 |
80 | print("replace all deconv with [nearest + conv]")
81 | self.conv8_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(512, 256, 3, 1, 1))
82 | self.conv9_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(256, 128, 3, 1, 1))
83 | self.conv10_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode="nearest"), nn.Conv2d(128, 128, 3, 1, 1))
84 |
85 | print("replace all batchnorm with instancenorm")
86 | self.conv1_2norm = nn.InstanceNorm2d(64)
87 | self.conv2_2norm = nn.InstanceNorm2d(128)
88 | self.conv3_3norm = nn.InstanceNorm2d(256)
89 | self.conv4_3norm = nn.InstanceNorm2d(512)
90 | self.conv5_3norm = nn.InstanceNorm2d(512)
91 | self.conv6_3norm = nn.InstanceNorm2d(512)
92 | self.conv7_3norm = nn.InstanceNorm2d(512)
93 | self.conv8_3norm = nn.InstanceNorm2d(256)
94 | self.conv9_2norm = nn.InstanceNorm2d(128)
95 |
96 | def forward(self, x):
97 | """ x: gray image (1 channel), ab(2 channel), ab_err, ba_err"""
98 | conv1_1 = self.relu1_1(self.conv1_1(x))
99 | conv1_2 = self.relu1_2(self.conv1_2(conv1_1))
100 | conv1_2norm = self.conv1_2norm(conv1_2)
101 | conv1_2norm_ss = self.conv1_2norm_ss(conv1_2norm)
102 | conv2_1 = self.relu2_1(self.conv2_1(conv1_2norm_ss))
103 | conv2_2 = self.relu2_2(self.conv2_2(conv2_1))
104 | conv2_2norm = self.conv2_2norm(conv2_2)
105 | conv2_2norm_ss = self.conv2_2norm_ss(conv2_2norm)
106 | conv3_1 = self.relu3_1(self.conv3_1(conv2_2norm_ss))
107 | conv3_2 = self.relu3_2(self.conv3_2(conv3_1))
108 | conv3_3 = self.relu3_3(self.conv3_3(conv3_2))
109 | conv3_3norm = self.conv3_3norm(conv3_3)
110 | conv3_3norm_ss = self.conv3_3norm_ss(conv3_3norm)
111 | conv4_1 = self.relu4_1(self.conv4_1(conv3_3norm_ss))
112 | conv4_2 = self.relu4_2(self.conv4_2(conv4_1))
113 | conv4_3 = self.relu4_3(self.conv4_3(conv4_2))
114 | conv4_3norm = self.conv4_3norm(conv4_3)
115 | conv5_1 = self.relu5_1(self.conv5_1(conv4_3norm))
116 | conv5_2 = self.relu5_2(self.conv5_2(conv5_1))
117 | conv5_3 = self.relu5_3(self.conv5_3(conv5_2))
118 | conv5_3norm = self.conv5_3norm(conv5_3)
119 | conv6_1 = self.relu6_1(self.conv6_1(conv5_3norm))
120 | conv6_2 = self.relu6_2(self.conv6_2(conv6_1))
121 | conv6_3 = self.relu6_3(self.conv6_3(conv6_2))
122 | conv6_3norm = self.conv6_3norm(conv6_3)
123 | conv7_1 = self.relu7_1(self.conv7_1(conv6_3norm))
124 | conv7_2 = self.relu7_2(self.conv7_2(conv7_1))
125 | conv7_3 = self.relu7_3(self.conv7_3(conv7_2))
126 | conv7_3norm = self.conv7_3norm(conv7_3)
127 | conv8_1 = self.conv8_1(conv7_3norm)
128 | conv3_3_short = self.conv3_3_short(conv3_3norm)
129 | conv8_1_comb = self.relu8_1_comb(conv8_1 + conv3_3_short)
130 | conv8_2 = self.relu8_2(self.conv8_2(conv8_1_comb))
131 | conv8_3 = self.relu8_3(self.conv8_3(conv8_2))
132 | conv8_3norm = self.conv8_3norm(conv8_3)
133 | conv9_1 = self.conv9_1(conv8_3norm)
134 | conv2_2_short = self.conv2_2_short(conv2_2norm)
135 | conv9_1_comb = self.relu9_1_comb(conv9_1 + conv2_2_short)
136 | conv9_2 = self.relu9_2(self.conv9_2(conv9_1_comb))
137 | conv9_2norm = self.conv9_2norm(conv9_2)
138 | conv10_1 = self.conv10_1(conv9_2norm)
139 | conv1_2_short = self.conv1_2_short(conv1_2norm)
140 | conv10_1_comb = self.relu10_1_comb(conv10_1 + conv1_2_short)
141 | conv10_2 = self.relu10_2(self.conv10_2(conv10_1_comb))
142 | conv10_ab = self.conv10_ab(conv10_2)
143 |
144 | return torch.tanh(conv10_ab) * 128
145 |
--------------------------------------------------------------------------------
/models/DEVC/models/FrameColor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from models.DEVC.utils.util import *
3 |
4 |
5 | def warp_color(IA_l, IB_lab, features_B, vggnet, nonlocal_net, colornet, feature_noise=0, temperature=0.01):
6 | IA_rgb_from_gray = gray2rgb_batch(IA_l)
7 | with torch.no_grad():
8 | A_relu1_1, A_relu2_1, A_relu3_1, A_relu4_1, A_relu5_1 = vggnet(
9 | IA_rgb_from_gray, ["r12", "r22", "r32", "r42", "r52"], preprocess=True
10 | )
11 | B_relu1_1, B_relu2_1, B_relu3_1, B_relu4_1, B_relu5_1 = features_B
12 |
13 | # NOTE: output the feature before normalization
14 | features_A = [A_relu1_1, A_relu2_1, A_relu3_1, A_relu4_1, A_relu5_1]
15 |
16 | A_relu2_1 = feature_normalize(A_relu2_1)
17 | A_relu3_1 = feature_normalize(A_relu3_1)
18 | A_relu4_1 = feature_normalize(A_relu4_1)
19 | A_relu5_1 = feature_normalize(A_relu5_1)
20 | B_relu2_1 = feature_normalize(B_relu2_1)
21 | B_relu3_1 = feature_normalize(B_relu3_1)
22 | B_relu4_1 = feature_normalize(B_relu4_1)
23 | B_relu5_1 = feature_normalize(B_relu5_1)
24 |
25 | nonlocal_BA_lab, similarity_map = nonlocal_net(
26 | IB_lab,
27 | A_relu2_1,
28 | A_relu3_1,
29 | A_relu4_1,
30 | A_relu5_1,
31 | B_relu2_1,
32 | B_relu3_1,
33 | B_relu4_1,
34 | B_relu5_1,
35 | temperature=temperature,
36 | )
37 |
38 | return nonlocal_BA_lab, similarity_map, features_A
39 |
40 |
41 | def frame_colorization(
42 | IA_lab,
43 | IB_lab,
44 | IA_last_lab,
45 | features_B,
46 | vggnet,
47 | nonlocal_net,
48 | colornet,
49 | joint_training=True,
50 | feature_noise=0,
51 | luminance_noise=0,
52 | temperature=0.01,
53 | ):
54 |
55 | IA_l = IA_lab[:, 0:1, :, :]
56 | if luminance_noise:
57 | IA_l = IA_l + torch.randn_like(IA_l, requires_grad=False) * luminance_noise
58 |
59 | with torch.autograd.set_grad_enabled(joint_training):
60 | nonlocal_BA_lab, similarity_map, features_A_gray = warp_color(
61 | IA_l, IB_lab, features_B, vggnet, nonlocal_net, colornet, feature_noise, temperature=temperature
62 | )
63 | nonlocal_BA_ab = nonlocal_BA_lab[:, 1:3, :, :]
64 | color_input = torch.cat((IA_l, nonlocal_BA_ab, similarity_map, IA_last_lab), dim=1)
65 | IA_ab_predict = colornet(color_input)
66 |
67 | return IA_ab_predict, nonlocal_BA_lab, features_A_gray
68 |
--------------------------------------------------------------------------------
/models/DEVC/models/GAN_Models.py:
--------------------------------------------------------------------------------
1 | # DCGAN-like generator and discriminator
2 | import torch
3 | import torch.nn.functional as F
4 | from torch import nn
5 |
6 | from models.DEVC.models.spectral_normalization import SpectralNorm
7 |
8 |
9 | class Generator(nn.Module):
10 | def __init__(self, z_dim):
11 | super(Generator, self).__init__()
12 | self.z_dim = z_dim
13 |
14 | self.model = nn.Sequential(
15 | nn.ConvTranspose2d(z_dim, 512, 4, stride=1),
16 | nn.InstanceNorm2d(512),
17 | nn.ReLU(),
18 | nn.ConvTranspose2d(512, 256, 4, stride=2, padding=(1, 1)),
19 | nn.InstanceNorm2d(256),
20 | nn.ReLU(),
21 | nn.ConvTranspose2d(256, 128, 4, stride=2, padding=(1, 1)),
22 | nn.InstanceNorm2d(128),
23 | nn.ReLU(),
24 | nn.ConvTranspose2d(128, 64, 4, stride=2, padding=(1, 1)),
25 | nn.InstanceNorm2d(64),
26 | nn.ReLU(),
27 | nn.ConvTranspose2d(64, channels, 3, stride=1, padding=(1, 1)),
28 | nn.Tanh(),
29 | )
30 |
31 | def forward(self, z):
32 | return self.model(z.view(-1, self.z_dim, 1, 1))
33 |
34 |
35 | channels = 3
36 | leak = 0.1
37 | w_g = 4
38 |
39 |
40 | class Discriminator(nn.Module):
41 | def __init__(self):
42 | super(Discriminator, self).__init__()
43 |
44 | self.conv1 = SpectralNorm(nn.Conv2d(channels, 64, 3, stride=1, padding=(1, 1)))
45 | self.conv2 = SpectralNorm(nn.Conv2d(64, 64, 4, stride=2, padding=(1, 1)))
46 | self.conv3 = SpectralNorm(nn.Conv2d(64, 128, 3, stride=1, padding=(1, 1)))
47 | self.conv4 = SpectralNorm(nn.Conv2d(128, 128, 4, stride=2, padding=(1, 1)))
48 | self.conv5 = SpectralNorm(nn.Conv2d(128, 256, 3, stride=1, padding=(1, 1)))
49 | self.conv6 = SpectralNorm(nn.Conv2d(256, 256, 4, stride=2, padding=(1, 1)))
50 | self.conv7 = SpectralNorm(nn.Conv2d(256, 256, 3, stride=1, padding=(1, 1)))
51 | self.conv8 = SpectralNorm(nn.Conv2d(256, 512, 4, stride=2, padding=(1, 1)))
52 | self.fc = SpectralNorm(nn.Linear(w_g * w_g * 512, 1))
53 |
54 | def forward(self, x):
55 | m = x
56 | m = nn.LeakyReLU(leak)(self.conv1(m))
57 | m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(64)(self.conv2(m)))
58 | m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv3(m)))
59 | m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(128)(self.conv4(m)))
60 | m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv5(m)))
61 | m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv6(m)))
62 | m = nn.LeakyReLU(leak)(nn.InstanceNorm2d(256)(self.conv7(m)))
63 | m = nn.LeakyReLU(leak)(self.conv8(m))
64 |
65 | return self.fc(m.view(-1, w_g * w_g * 512))
66 |
67 |
68 | class Self_Attention(nn.Module):
69 | """ Self attention Layer"""
70 |
71 | def __init__(self, in_dim):
72 | super(Self_Attention, self).__init__()
73 | self.chanel_in = in_dim
74 |
75 | self.query_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1))
76 | self.key_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 1, kernel_size=1))
77 | self.value_conv = SpectralNorm(nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1))
78 | self.gamma = nn.Parameter(torch.zeros(1))
79 |
80 | self.softmax = nn.Softmax(dim=-1) #
81 |
82 | def forward(self, x):
83 | """
84 | inputs :
85 | x : input feature maps( B X C X W X H)
86 | returns :
87 | out : self attention value + input feature
88 | attention: B X N X N (N is Width*Height)
89 | """
90 | m_batchsize, C, width, height = x.size()
91 | proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X CX(N)
92 | proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H)
93 | energy = torch.bmm(proj_query, proj_key) # transpose check
94 | attention = self.softmax(energy) # BX (N) X (N)
95 | proj_value = self.value_conv(x).view(m_batchsize, -1, width * height) # B X C X N
96 |
97 | out = torch.bmm(proj_value, attention.permute(0, 2, 1))
98 | out = out.view(m_batchsize, C, width, height)
99 |
100 | out = self.gamma * out + x
101 | return out
102 |
103 |
104 | class Discriminator_x64(nn.Module):
105 | """
106 | Discriminative Network
107 | """
108 |
109 | def __init__(self, in_size=6, ndf=64):
110 | super(Discriminator_x64, self).__init__()
111 | self.in_size = in_size
112 | self.ndf = ndf
113 |
114 | self.layer1 = nn.Sequential(
115 | SpectralNorm(nn.Conv2d(self.in_size, self.ndf, 4, 2, 1)), nn.LeakyReLU(0.2, inplace=True)
116 | )
117 | self.layer2 = nn.Sequential(
118 | SpectralNorm(nn.Conv2d(self.ndf, self.ndf, 4, 2, 1)),
119 | nn.InstanceNorm2d(self.ndf),
120 | nn.LeakyReLU(0.2, inplace=True),
121 | )
122 | self.attention = Self_Attention(self.ndf)
123 | self.layer3 = nn.Sequential(
124 | SpectralNorm(nn.Conv2d(self.ndf, self.ndf * 2, 4, 2, 1)),
125 | nn.InstanceNorm2d(self.ndf * 2),
126 | nn.LeakyReLU(0.2, inplace=True),
127 | )
128 | self.layer4 = nn.Sequential(
129 | SpectralNorm(nn.Conv2d(self.ndf * 2, self.ndf * 4, 4, 2, 1)),
130 | nn.InstanceNorm2d(self.ndf * 4),
131 | nn.LeakyReLU(0.2, inplace=True),
132 | )
133 | self.layer5 = nn.Sequential(
134 | SpectralNorm(nn.Conv2d(self.ndf * 4, self.ndf * 8, 4, 2, 1)),
135 | nn.InstanceNorm2d(self.ndf * 8),
136 | nn.LeakyReLU(0.2, inplace=True),
137 | )
138 | self.layer6 = nn.Sequential(
139 | SpectralNorm(nn.Conv2d(self.ndf * 8, self.ndf * 16, 4, 2, 1)),
140 | nn.InstanceNorm2d(self.ndf * 16),
141 | nn.LeakyReLU(0.2, inplace=True),
142 | )
143 |
144 | self.last = SpectralNorm(nn.Conv2d(self.ndf * 16, 1, [3, 6], 1, 0))
145 |
146 | def forward(self, input):
147 | feature1 = self.layer1(input)
148 | feature2 = self.layer2(feature1)
149 | feature_attention = self.attention(feature2)
150 | feature3 = self.layer3(feature_attention)
151 | feature4 = self.layer4(feature3)
152 | feature5 = self.layer5(feature4)
153 | feature6 = self.layer6(feature5)
154 | output = self.last(feature6)
155 | output = F.avg_pool2d(output, output.size()[2:]).view(output.size()[0], -1)
156 |
157 | return output, feature4
158 |
--------------------------------------------------------------------------------
/models/DEVC/models/spectral_normalization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Parameter
4 |
5 |
6 | def l2normalize(v, eps=1e-12):
7 | return v / (v.norm() + eps)
8 |
9 |
10 | class SpectralNorm(nn.Module):
11 | def __init__(self, module, name="weight", power_iterations=1):
12 | super(SpectralNorm, self).__init__()
13 | self.module = module
14 | self.name = name
15 | self.power_iterations = power_iterations
16 | if not self._made_params():
17 | self._make_params()
18 |
19 | def _update_u_v(self):
20 | u = getattr(self.module, self.name + "_u")
21 | v = getattr(self.module, self.name + "_v")
22 | w = getattr(self.module, self.name + "_bar")
23 |
24 | height = w.data.shape[0]
25 | for _ in range(self.power_iterations):
26 |             v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))  # v <- normalize(W^T u)
27 |             u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))  # u <- normalize(W v)
28 |
29 |         sigma = u.dot(w.view(height, -1).mv(v))  # power-iteration estimate of the largest singular value of W
30 |         setattr(self.module, self.name, w / sigma.expand_as(w))  # rescale the weight by its spectral norm
31 |
32 | def _made_params(self):
33 | try:
34 | u = getattr(self.module, self.name + "_u")
35 | v = getattr(self.module, self.name + "_v")
36 | w = getattr(self.module, self.name + "_bar")
37 | return True
38 | except AttributeError:
39 | return False
40 |
41 | def _make_params(self):
42 | w = getattr(self.module, self.name)
43 |
44 | height = w.data.shape[0]
45 | width = w.view(height, -1).data.shape[1]
46 |
47 | u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
48 | v = Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
49 | u.data = l2normalize(u.data)
50 | v.data = l2normalize(v.data)
51 | w_bar = Parameter(w.data)
52 |
53 | del self.module._parameters[self.name]
54 |
55 | self.module.register_parameter(self.name + "_u", u)
56 | self.module.register_parameter(self.name + "_v", v)
57 | self.module.register_parameter(self.name + "_bar", w_bar)
58 |
59 | def forward(self, *args):
60 | self._update_u_v()
61 | return self.module.forward(*args)
62 |
--------------------------------------------------------------------------------
/models/DEVC/models/vgg19_gray.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 |
3 | import torch
4 | import torch.nn as nn
5 |
6 |
7 | class LambdaBase(nn.Sequential):
8 | def __init__(self, fn, *args):
9 | super(LambdaBase, self).__init__(*args)
10 | self.lambda_func = fn
11 |
12 | def forward_prepare(self, input):
13 | output = []
14 | for module in self._modules.values():
15 | output.append(module(input))
16 |
17 | return output if output else input
18 |
19 |
20 | class Lambda(LambdaBase):
21 | def forward(self, input):
22 | return self.lambda_func(self.forward_prepare(input))
23 |
24 |
25 | class LambdaMap(LambdaBase):
26 | def forward(self, input):
27 | return list(map(self.lambda_func, self.forward_prepare(input)))
28 |
29 |
30 | class LambdaReduce(LambdaBase):
31 | def forward(self, input):
32 | return reduce(self.lambda_func, self.forward_prepare(input))
33 |
34 | layer_names = [
35 | "conv1_1",
36 | "relu1_1",
37 | "conv1_2",
38 | "relu1_2",
39 | "pool1",
40 | "conv2_1",
41 | "relu2_1",
42 | "conv2_2",
43 | "relu2_2",
44 | "pool2",
45 | "conv3_1",
46 | "relu3_1",
47 | "conv3_2",
48 | "relu3_2",
49 | "conv3_3",
50 | "relu3_3",
51 | "conv3_4",
52 | "relu3_4",
53 | "pool3",
54 | "conv4_1",
55 | "relu4_1",
56 | "conv4_2",
57 | "relu4_2",
58 | "conv4_3",
59 | "relu4_3",
60 | "conv4_4",
61 | "relu4_4",
62 | "pool4",
63 | "conv5_1",
64 | "relu5_1",
65 | "conv5_2",
66 | "relu5_2",
67 | "conv5_3",
68 | "relu5_3",
69 | "conv5_4",
70 | "relu5_4",
71 | "pool5",
72 | "view1",
73 | "fc6",
74 | "fc6_relu",
75 | "fc7",
76 | "fc7_relu",
77 | "fc8",
78 | ]
79 |
80 | def get_pretrained_vgg():
81 | model = nn.Sequential( # Sequential,
82 | nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1)),
83 | nn.ReLU(),
84 | nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1)),
85 | nn.ReLU(),
86 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
87 | nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1)),
88 | nn.ReLU(),
89 | nn.Conv2d(128, 128, (3, 3), (1, 1), (1, 1)),
90 | nn.ReLU(),
91 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
92 | nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1)),
93 | nn.ReLU(),
94 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)),
95 | nn.ReLU(),
96 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)),
97 | nn.ReLU(),
98 | nn.Conv2d(256, 256, (3, 3), (1, 1), (1, 1)),
99 | nn.ReLU(),
100 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
101 | nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1)),
102 | nn.ReLU(),
103 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),
104 | nn.ReLU(),
105 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),
106 | nn.ReLU(),
107 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),
108 | nn.ReLU(),
109 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
110 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),
111 | nn.ReLU(),
112 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),
113 | nn.ReLU(),
114 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),
115 | nn.ReLU(),
116 | nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1)),
117 | nn.ReLU(),
118 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
119 | Lambda(lambda x: x.view(x.size(0), -1)), # View,
120 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(25088, 4096)), # Linear,
121 | nn.ReLU(),
122 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(4096, 4096)), # Linear,
123 | nn.ReLU(),
124 | nn.Sequential(Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x), nn.Linear(4096, 1000)), # Linear,
125 | )
126 |
127 |
128 | model.load_state_dict(torch.load("data/vgg19_gray.pth"))
129 | vgg19_gray_net = torch.nn.Sequential()
130 | for (name, layer) in model._modules.items():
131 | vgg19_gray_net.add_module(layer_names[int(name)], model[int(name)])
132 |
133 | for param in vgg19_gray_net.parameters():
134 | param.requires_grad = False
135 | vgg19_gray_net.eval()
136 | return vgg19_gray_net
137 |
138 |
139 | class vgg19_gray(torch.nn.Module):
140 | def __init__(self, requires_grad=False):
141 | super(vgg19_gray, self).__init__()
142 | vgg_pretrained_features = get_pretrained_vgg()
143 | self.slice1 = torch.nn.Sequential()
144 | self.slice2 = torch.nn.Sequential()
145 | self.slice3 = torch.nn.Sequential()
146 | for x in range(12):
147 | self.slice1.add_module(layer_names[x], vgg_pretrained_features[x])
148 | for x in range(12, 21):
149 | self.slice2.add_module(layer_names[x], vgg_pretrained_features[x])
150 | for x in range(21, 30):
151 | self.slice3.add_module(layer_names[x], vgg_pretrained_features[x])
152 | if not requires_grad:
153 | for param in self.parameters():
154 | param.requires_grad = False
155 |
156 | def forward(self, X):
157 | h = self.slice1(X)
158 | h_relu3_1 = h
159 | h = self.slice2(h)
160 | h_relu4_1 = h
161 | h = self.slice3(h)
162 | h_relu5_1 = h
163 | return h_relu3_1, h_relu4_1, h_relu5_1
164 |
165 |
166 | class vgg19_gray_new(torch.nn.Module):
167 | def __init__(self, requires_grad=False):
168 | super(vgg19_gray_new, self).__init__()
169 | vgg_pretrained_features = get_pretrained_vgg()
170 | self.slice0 = torch.nn.Sequential()
171 | self.slice1 = torch.nn.Sequential()
172 | self.slice2 = torch.nn.Sequential()
173 | self.slice3 = torch.nn.Sequential()
174 | for x in range(7):
175 | self.slice0.add_module(layer_names[x], vgg_pretrained_features[x])
176 | for x in range(7, 12):
177 | self.slice1.add_module(layer_names[x], vgg_pretrained_features[x])
178 | for x in range(12, 21):
179 | self.slice2.add_module(layer_names[x], vgg_pretrained_features[x])
180 | for x in range(21, 30):
181 | self.slice3.add_module(layer_names[x], vgg_pretrained_features[x])
182 | if not requires_grad:
183 | for param in self.parameters():
184 | param.requires_grad = False
185 |
186 | def forward(self, X):
187 | h = self.slice0(X)
188 | h_relu2_1 = h
189 | h = self.slice1(h)
190 | h_relu3_1 = h
191 | h = self.slice2(h)
192 | h_relu4_1 = h
193 | h = self.slice3(h)
194 | h_relu5_1 = h
195 | return h_relu2_1, h_relu3_1, h_relu4_1, h_relu5_1
196 |
--------------------------------------------------------------------------------
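
The module above builds a frozen, grayscale-trained VGG-19 and exposes relu3_1/relu4_1/relu5_1 (plus relu2_1 in the `_new` variant) as similarity features. A minimal usage sketch, assuming the hard-coded checkpoint `data/vgg19_gray.pth` exists relative to the working directory and the repo root is importable:

```python
# Usage sketch only; assumes data/vgg19_gray.pth exists and the repo is on PYTHONPATH.
import torch
from models.DEVC.models.vgg19_gray import vgg19_gray

extractor = vgg19_gray(requires_grad=False)      # frozen multi-scale feature extractor
gray_rgb = torch.rand(1, 3, 256, 256)            # luminance replicated to 3 channels
with torch.no_grad():
    relu3_1, relu4_1, relu5_1 = extractor(gray_rgb)
print(relu3_1.shape, relu4_1.shape, relu5_1.shape)  # progressively coarser feature maps
```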
/models/DEVC/options/test_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.backends.cudnn as cudnn
3 | class TestOptions():
4 | def __init__(self):
5 | self.initialized = False
6 |
7 | def initialize(self, parser):
8 | parser.add_argument(
 9 |             "--frame_propagate", default=True, type=bool, help="propagation mode, please check the paper"
10 |         )
11 |         parser.add_argument("--image_size", type=int, default=[216 * 2, 384 * 2], help="the image size, e.g. [216,384]")
12 | parser.add_argument("--cuda", action="store_false")
13 | parser.add_argument("--gpu_ids", type=str, default="0", help="separate by comma")
14 | parser.add_argument("--clip_path", type=str, default="../test/input.mp4", help="path of input clips")
15 |         parser.add_argument("--ref_path", type=str, default="../test/frame00000.jpg", help="path of reference images")
16 | parser.add_argument("--output_frame_path", type=str, default="../test/results", help="path of output colorized frames")
17 |
18 | self.initialized = True
19 | return parser
20 |
21 | def parse(self):
22 | # initialize parser with basic options
23 | if not self.initialized:
24 | parser = argparse.ArgumentParser(
25 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
26 | parser = self.initialize(parser)
27 | opt = parser.parse_args()
28 |
29 | opt.gpu_ids = [int(x) for x in opt.gpu_ids.split(",")]
30 | cudnn.benchmark = True
31 | print("running on GPU", opt.gpu_ids)
32 | return opt
--------------------------------------------------------------------------------
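
A hedged usage sketch for the parser above; the driver script name and paths are placeholders. Note that `argparse` with `type=bool` converts any non-empty string to `True`, so passing `--frame_propagate False` on the command line would still evaluate to `True`; in practice the flag is effectively controlled through its default.

```python
# Hypothetical invocation; the script name and paths below are placeholders.
from models.DEVC.options.test_options import TestOptions

# Equivalent CLI: python run_devc.py --clip_path ../test/input.mp4 \
#                 --ref_path ../test/frame00000.jpg --gpu_ids 0
opt = TestOptions().parse()          # parses sys.argv and enables cudnn.benchmark
print(opt.image_size, opt.gpu_ids, opt.output_frame_path)
```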
/models/DEVC/utils/lib/test_transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import collections.abc
4 | import numbers
5 | import random
6 |
7 | import torch
8 | from PIL import Image
9 | from skimage import color
10 |
11 | import models.DEVC.utils.lib.functional as F
12 |
13 | __all__ = [
14 | "Compose",
15 | "Concatenate",
16 | "ToTensor",
17 | "Normalize",
18 | "Resize",
19 | "Scale",
20 | "CenterCrop",
21 | "Pad",
22 | "RandomCrop",
23 | "RandomHorizontalFlip",
24 | "RandomVerticalFlip",
25 | "RandomResizedCrop",
26 | "RandomSizedCrop",
27 | "FiveCrop",
28 | "TenCrop",
29 | "RGB2Lab",
30 | ]
31 |
32 |
33 | def CustomFunc(inputs, func, *args, **kwargs):
34 | im_l = func(inputs[0], *args, **kwargs)
35 | im_ab = func(inputs[1], *args, **kwargs)
36 | warp_ba = func(inputs[2], *args, **kwargs)
37 | warp_aba = func(inputs[3], *args, **kwargs)
38 | # im_gbl_ab = func(inputs[4], *args, **kwargs)
39 | # bgr_mc_im = func(inputs[5], *args, **kwargs)
40 | layer_data = [im_l, im_ab, warp_ba, warp_aba]
41 | # layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im]
42 | for l in range(5):
43 | layer = inputs[4 + l]
44 | err_ba = func(layer[0], *args, **kwargs)
45 | err_ab = func(layer[1], *args, **kwargs)
46 |
47 | layer_data.append([err_ba, err_ab])
48 |
49 | return layer_data
50 |
51 |
52 | class Compose(object):
53 | """Composes several transforms together.
54 |
55 | Args:
56 | transforms (list of ``Transform`` objects): list of transforms to compose.
57 |
58 | Example:
59 | >>> transforms.Compose([
60 | >>> transforms.CenterCrop(10),
61 | >>> transforms.ToTensor(),
62 | >>> ])
63 | """
64 |
65 | def __init__(self, transforms):
66 | self.transforms = transforms
67 |
68 | def __call__(self, inputs):
69 | for t in self.transforms:
70 | inputs = t(inputs)
71 | return inputs
72 |
73 |
74 | class Concatenate(object):
75 | """
76 | Input: [im_l, im_ab, inputs]
77 | inputs = [warp_ba_l, warp_ba_ab, warp_aba, err_pm, err_aba]
78 |
79 | Output:[im_l, err_pm, warp_ba, warp_aba, im_ab, err_aba]
80 | """
81 |
82 | def __call__(self, inputs):
83 | im_l = inputs[0]
84 | im_ab = inputs[1]
85 | warp_ba = inputs[2]
86 | warp_aba = inputs[3]
87 | # im_glb_ab = inputs[4]
88 | # bgr_mc_im = inputs[5]
89 | # bgr_mc_im = bgr_mc_im[[2, 1, 0], ...]
90 |
91 | err_ba = []
92 | err_ab = []
93 |
94 | for l in range(5):
95 | layer = inputs[4 + l]
96 | err_ba.append(layer[0])
97 | err_ab.append(layer[1])
98 |
99 | cerr_ba = torch.cat(err_ba, 0)
100 | cerr_ab = torch.cat(err_ab, 0)
101 |
102 | return (im_l, cerr_ba, warp_ba, warp_aba, im_ab, cerr_ab)
103 | # return (im_l, cerr_ba, warp_ba, warp_aba, im_glb_ab, bgr_mc_im, im_ab, cerr_ab)
104 |
105 |
106 | class ToTensor(object):
107 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
108 |
109 | Converts a PIL Image or numpy.ndarray (H x W x C) in the range
110 | [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
111 | """
112 |
113 | def __call__(self, inputs):
114 | """
115 | Args:
116 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
117 |
118 | Returns:
119 | Tensor: Converted image.
120 | """
121 | inputs = CustomFunc(inputs, F.to_mytensor)
122 | return inputs
123 |
124 |
125 | class Normalize(object):
126 |     """Normalize a tensor image with mean and standard deviation.
127 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform
128 | will normalize each channel of the input ``torch.*Tensor`` i.e.
129 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
130 |
131 | Args:
132 | mean (sequence): Sequence of means for each channel.
133 | std (sequence): Sequence of standard deviations for each channel.
134 | """
135 |
136 | def __call__(self, inputs):
137 | """
138 | Args:
139 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
140 |
141 | Returns:
142 | Tensor: Normalized Tensor image.
143 | """
144 |
145 | im_l = F.normalize(inputs[0], 50, 1) # [0, 100]
146 | im_ab = F.normalize(inputs[1], (0, 0), (1, 1)) # [-100, 100]
147 |
148 | inputs[2][0:1, :, :] = F.normalize(inputs[2][0:1, :, :], 50, 1)
149 | inputs[2][1:3, :, :] = F.normalize(inputs[2][1:3, :, :], (0, 0), (1, 1))
150 | warp_ba = inputs[2]
151 |
152 | inputs[3][0:1, :, :] = F.normalize(inputs[3][0:1, :, :], 50, 1)
153 | inputs[3][1:3, :, :] = F.normalize(inputs[3][1:3, :, :], (0, 0), (1, 1))
154 | warp_aba = inputs[3]
155 |
156 | # im_gbl_ab = F.normalize(inputs[4], (0, 0), (1, 1)) # [-100, 100]
157 | #
158 | # bgr_mc_im = F.normalize(inputs[5], (123.68, 116.78, 103.938), (1, 1, 1))
159 |
160 | # layer_data = [im_l, im_ab, warp_ba, warp_aba, im_gbl_ab, bgr_mc_im]
161 | layer_data = [im_l, im_ab, warp_ba, warp_aba]
162 |
163 | for l in range(5):
164 | layer = inputs[4 + l]
165 | err_ba = F.normalize(layer[0], 127, 2) # [0, 255]
166 | err_ab = F.normalize(layer[1], 127, 2) # [0, 255]
167 | layer_data.append([err_ba, err_ab])
168 |
169 | return layer_data
170 |
171 |
172 | class Resize(object):
173 | """Resize the input PIL Image to the given size.
174 |
175 | Args:
176 | size (sequence or int): Desired output size. If size is a sequence like
177 | (h, w), output size will be matched to this. If size is an int,
178 | smaller edge of the image will be matched to this number.
179 | i.e, if height > width, then image will be rescaled to
180 | (size * height / width, size)
181 | interpolation (int, optional): Desired interpolation. Default is
182 | ``PIL.Image.BILINEAR``
183 | """
184 |
185 | def __init__(self, size, interpolation=Image.BILINEAR):
186 |         assert isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size) == 2)
187 | self.size = size
188 | self.interpolation = interpolation
189 |
190 | def __call__(self, inputs):
191 | """
192 | Args:
193 | img (PIL Image): Image to be scaled.
194 |
195 | Returns:
196 | PIL Image: Rescaled image.
197 | """
198 | return CustomFunc(inputs, F.resize, self.size, self.interpolation)
199 |
200 |
201 | class RandomCrop(object):
202 | """Crop the given PIL Image at a random location.
203 |
204 | Args:
205 | size (sequence or int): Desired output size of the crop. If size is an
206 | int instead of sequence like (h, w), a square crop (size, size) is
207 | made.
208 | padding (int or sequence, optional): Optional padding on each border
209 | of the image. Default is 0, i.e no padding. If a sequence of length
210 | 4 is provided, it is used to pad left, top, right, bottom borders
211 | respectively.
212 | """
213 |
214 | def __init__(self, size, padding=0):
215 | if isinstance(size, numbers.Number):
216 | self.size = (int(size), int(size))
217 | else:
218 | self.size = size
219 | self.padding = padding
220 |
221 | @staticmethod
222 | def get_params(img, output_size):
223 | """Get parameters for ``crop`` for a random crop.
224 |
225 | Args:
226 | img (PIL Image): Image to be cropped.
227 | output_size (tuple): Expected output size of the crop.
228 |
229 | Returns:
230 | tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
231 | """
232 | w, h = img.size
233 | th, tw = output_size
234 | if w == tw and h == th:
235 | return 0, 0, h, w
236 |
237 | i = random.randint(0, h - th)
238 | j = random.randint(0, w - tw)
239 | return i, j, th, tw
240 |
241 | def __call__(self, inputs):
242 | """
243 | Args:
244 | img (PIL Image): Image to be cropped.
245 |
246 | Returns:
247 | PIL Image: Cropped image.
248 | """
249 | if self.padding > 0:
250 | inputs = CustomFunc(inputs, F.pad, self.padding)
251 |
252 | i, j, h, w = self.get_params(inputs[0], self.size)
253 | return CustomFunc(inputs, F.crop, i, j, h, w)
254 |
255 |
256 | class CenterCrop(object):
257 |     """Crop the given PIL Image at the center.
258 |
259 | Args:
260 | size (sequence or int): Desired output size of the crop. If size is an
261 | int instead of sequence like (h, w), a square crop (size, size) is
262 | made.
263 | padding (int or sequence, optional): Optional padding on each border
264 | of the image. Default is 0, i.e no padding. If a sequence of length
265 | 4 is provided, it is used to pad left, top, right, bottom borders
266 | respectively.
267 | """
268 |
269 | def __init__(self, size, padding=0):
270 | if isinstance(size, numbers.Number):
271 | self.size = (int(size), int(size))
272 | else:
273 | self.size = size
274 | self.padding = padding
275 |
276 | @staticmethod
277 | def get_params(img, output_size):
278 |         """Get parameters for ``crop`` for a center crop.
279 |
280 | Args:
281 | img (PIL Image): Image to be cropped.
282 | output_size (tuple): Expected output size of the crop.
283 |
284 | Returns:
285 |             tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
286 | """
287 | w, h = img.size
288 | th, tw = output_size
289 | if w == tw and h == th:
290 | return 0, 0, h, w
291 |
292 | i = (h - th) // 2
293 | j = (w - tw) // 2
294 | return i, j, th, tw
295 |
296 | def __call__(self, inputs):
297 | """
298 | Args:
299 | img (PIL Image): Image to be cropped.
300 |
301 | Returns:
302 | PIL Image: Cropped image.
303 | """
304 | if self.padding > 0:
305 | inputs = CustomFunc(inputs, F.pad, self.padding)
306 |
307 | if type(inputs) is list:
308 | i, j, h, w = self.get_params(inputs[0], self.size)
309 | else:
310 | i, j, h, w = self.get_params(inputs, self.size)
311 | return CustomFunc(inputs, F.crop, i, j, h, w)
312 |
313 |
314 | class RandomHorizontalFlip(object):
315 | """Horizontally flip the given PIL Image randomly with a probability of 0.5."""
316 |
317 | def __call__(self, inputs):
318 | """
319 | Args:
320 | img (PIL Image): Image to be flipped.
321 |
322 | Returns:
323 | PIL Image: Randomly flipped image.
324 | """
325 |
326 | if random.random() < 0.5:
327 | return CustomFunc(inputs, F.hflip)
328 | return inputs
329 |
330 |
331 | class RGB2Lab(object):
332 |     def __call__(self, inputs):
333 |         """
334 |         Args:
335 |             inputs (list): [im_l, im_ab, warp_ba, warp_aba, err layers], RGB images/arrays.
336 | 
337 |         Returns:
338 |             list: the same structure with the image entries converted to Lab channels.
339 |         """
340 | 
341 |         # convert RGB inputs to Lab and split into L / ab channels
342 | image_lab = color.rgb2lab(inputs[0])
343 | warp_ba_lab = color.rgb2lab(inputs[2])
344 | warp_aba_lab = color.rgb2lab(inputs[3])
345 | # im_gbl_lab = color.rgb2lab(inputs[4])
346 |
347 | inputs[0] = image_lab[:, :, :1] # l channel
348 | inputs[1] = image_lab[:, :, 1:] # ab channel
349 | inputs[2] = warp_ba_lab # lab channel
350 | inputs[3] = warp_aba_lab # lab channel
351 | # inputs[4] = im_gbl_lab[:, :, 1:] # ab channel
352 |
353 | return inputs
354 |
--------------------------------------------------------------------------------
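
For reference, a plain-arithmetic sketch of the value ranges that `Normalize` above maps onto (it does not call the project's `functional` module): the L channel in [0, 100] is centered at 50, the ab channels are left on their native scale, and the five error layers in [0, 255] are centered at 127 and divided by 2. The tensors below are stand-ins, not real data.

```python
# Stand-alone sketch of the normalization convention used by Normalize above.
import torch

im_l = torch.full((1, 4, 4), 75.0)      # L channel, nominal range [0, 100]
err  = torch.full((1, 4, 4), 255.0)     # similarity/error map, nominal range [0, 255]

im_l_centered = (im_l - 50.0) / 1.0     # roughly [-50, 50] after centering
err_centered  = (err - 127.0) / 2.0     # roughly [-63.5, 64] after centering
print(im_l_centered[0, 0, 0].item(), err_centered[0, 0, 0].item())  # 25.0 64.0
```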
/models/DEVC/utils/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision
6 | from utils.util import feature_normalize
7 |
8 | postpa = torchvision.transforms.Compose(
9 | [
10 | torchvision.transforms.Lambda(lambda x: x.mul_(1.0 / 255)),
11 | torchvision.transforms.Normalize(
12 | mean=[-0.40760392, -0.45795686, -0.48501961], std=[1, 1, 1] # add imagenet mean
13 | ),
14 | torchvision.transforms.Lambda(lambda x: x[torch.LongTensor([2, 1, 0])]), # turn to RGB
15 | ]
16 | )
17 | postpb = torchvision.transforms.Compose([torchvision.transforms.ToPILImage()])
18 |
19 |
20 | def post_processing(tensor):
21 | t = postpa(tensor) # denormalize the image since the optimized tensor is the normalized one
22 | t[t > 1] = 1
23 | t[t < 0] = 0
24 | img = postpb(t)
25 | img = np.array(img)
26 | return img
27 |
28 |
29 | class ContextualLoss(nn.Module):
30 | """
31 | input is Al, Bl, channel = 1, range ~ [0, 255]
32 | """
33 |
34 | def __init__(self):
35 | super(ContextualLoss, self).__init__()
36 | return None
37 |
38 | def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
39 | """
40 |         X_features & Y_features are feature vectors or 2-d feature arrays
41 | h: bandwidth
42 | return the per-sample loss
43 | """
44 | batch_size = X_features.shape[0]
45 | feature_depth = X_features.shape[1]
46 | feature_size = X_features.shape[2]
47 |
48 | # to normalized feature vectors
49 | if feature_centering:
50 | X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(
51 | dim=-1
52 | ).unsqueeze(dim=-1)
53 | Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(
54 | dim=-1
55 | ).unsqueeze(dim=-1)
56 | X_features = feature_normalize(X_features).view(
57 | batch_size, feature_depth, -1
58 | ) # batch_size * feature_depth * feature_size^2
59 | Y_features = feature_normalize(Y_features).view(
60 | batch_size, feature_depth, -1
61 | ) # batch_size * feature_depth * feature_size^2
62 |
63 |         # cosine distance = 1 - similarity
64 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
65 | d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
66 |
67 | # normalized distance: dij_bar
68 | d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2
69 |
70 | # pairwise affinity
71 | w = torch.exp((1 - d_norm) / h)
72 | A_ij = w / torch.sum(w, dim=-1, keepdim=True)
73 |
74 | # contextual loss per sample
75 | CX = torch.mean(torch.max(A_ij, dim=1)[0], dim=-1)
76 | return -torch.log(CX)
77 |
78 |
79 | class ContextualLoss_forward(nn.Module):
80 | """
81 | input is Al, Bl, channel = 1, range ~ [0, 255]
82 | """
83 |
84 | def __init__(self):
85 | super(ContextualLoss_forward, self).__init__()
86 | return None
87 |
88 | def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
89 | """
90 |         X_features & Y_features are feature vectors or 2-d feature arrays
91 | h: bandwidth
92 | return the per-sample loss
93 | """
94 | batch_size = X_features.shape[0]
95 | feature_depth = X_features.shape[1]
96 | feature_size = X_features.shape[2]
97 |
98 | # to normalized feature vectors
99 | if feature_centering:
100 | X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(
101 | dim=-1
102 | ).unsqueeze(dim=-1)
103 | Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(
104 | dim=-1
105 | ).unsqueeze(dim=-1)
106 | X_features = feature_normalize(X_features).view(
107 | batch_size, feature_depth, -1
108 | ) # batch_size * feature_depth * feature_size^2
109 | Y_features = feature_normalize(Y_features).view(
110 | batch_size, feature_depth, -1
111 | ) # batch_size * feature_depth * feature_size^2
112 |
113 |         # cosine distance = 1 - similarity
114 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
115 | d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
116 |
117 | # normalized distance: dij_bar
118 | d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2
119 |
120 | # pairwise affinity
121 | w = torch.exp((1 - d_norm) / h)
122 | A_ij = w / torch.sum(w, dim=-1, keepdim=True)
123 |
124 | # contextual loss per sample
125 | CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1)
126 | return -torch.log(CX)
127 |
128 |
129 | class ContextualLoss_complex(nn.Module):
130 | """
131 | input is Al, Bl, channel = 1, range ~ [0, 255]
132 | """
133 |
134 | def __init__(self):
135 | super(ContextualLoss_complex, self).__init__()
136 | return None
137 |
138 | def forward(self, X_features, Y_features, h=0.1, patch_size=1, direction="forward"):
139 | """
140 |         X_features & Y_features are feature vectors or 2-d feature arrays
141 | h: bandwidth
142 | return the per-sample loss
143 | """
144 | batch_size = X_features.shape[0]
145 | feature_depth = X_features.shape[1]
146 | feature_size = X_features.shape[2]
147 |
148 | # to normalized feature vectors
149 | X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(
150 | dim=-1
151 | ).unsqueeze(dim=-1)
152 | Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(
153 | dim=-1
154 | ).unsqueeze(dim=-1)
155 | X_features = feature_normalize(X_features) # batch_size * feature_depth * feature_size^2
156 | Y_features = feature_normalize(Y_features) # batch_size * feature_depth * feature_size^2
157 |
158 | # to normalized feature vectors
159 | X_features = F.unfold(
160 | X_features, kernel_size=(patch_size, patch_size), stride=(1, 1), padding=(patch_size // 2, patch_size // 2)
161 | ) # batch_size * feature_depth_new * feature_size^2
162 | Y_features = F.unfold(
163 | Y_features, kernel_size=(patch_size, patch_size), stride=(1, 1), padding=(patch_size // 2, patch_size // 2)
164 | ) # batch_size * feature_depth_new * feature_size^2
165 |
166 |         # cosine distance = 1 - similarity
167 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
168 | d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
169 |
170 | # normalized distance: dij_bar
171 | d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-5) # batch_size * feature_size^2 * feature_size^2
172 |
173 | # pairwise affinity
174 | w = torch.exp((1 - d_norm) / h)
175 | A_ij = w / torch.sum(w, dim=-1, keepdim=True)
176 |
177 | # contextual loss per sample
178 | if direction == "forward":
179 | CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1)
180 | else:
181 | CX = torch.mean(torch.max(A_ij, dim=1)[0], dim=-1)
182 |
183 | return -torch.log(CX)
184 |
185 |
186 | class ChamferDistance_patch_loss(nn.Module):
187 | """
188 | input is Al, Bl, channel = 1, range ~ [0, 255]
189 | """
190 |
191 | def __init__(self):
192 | super(ChamferDistance_patch_loss, self).__init__()
193 | return None
194 |
195 | def forward(self, X_features, Y_features, patch_size=3, image_x=None, image_y=None, h=0.1, Y_features_in=None):
196 | """
197 |         X_features & Y_features are feature vectors or 2-d feature arrays
198 | h: bandwidth
199 | return the per-sample loss
200 | """
201 | batch_size = X_features.shape[0]
202 | feature_depth = X_features.shape[1]
203 | feature_size = X_features.shape[2]
204 |
205 | # to normalized feature vectors
206 | X_features = F.unfold(
207 | X_features, kernel_size=(patch_size, patch_size), stride=(1, 1), padding=(patch_size // 2, patch_size // 2)
208 | ) # batch_size, feature_depth_new * feature_size^2
209 | Y_features = F.unfold(
210 | Y_features, kernel_size=(patch_size, patch_size), stride=(1, 1), padding=(patch_size // 2, patch_size // 2)
211 | ) # batch_size, feature_depth_new * feature_size^2
212 |
213 | if image_x is not None and image_y is not None:
214 | image_x = torch.nn.functional.interpolate(image_x, size=(feature_size, feature_size), mode="bilinear").view(
215 | batch_size, 3, -1
216 | )
217 | image_y = torch.nn.functional.interpolate(image_y, size=(feature_size, feature_size), mode="bilinear").view(
218 | batch_size, 3, -1
219 | )
220 |
221 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
222 | similarity_matrix = torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
223 | NN_index = similarity_matrix.max(dim=-1, keepdim=True)[1].squeeze()
224 |
225 | if Y_features_in is not None:
226 | loss = torch.mean((X_features - Y_features_in.detach()) ** 2)
227 | Y_features_in = Y_features_in.detach()
228 | else:
229 | loss = torch.mean((X_features - Y_features[:, :, NN_index].detach()) ** 2)
230 | Y_features_in = Y_features[:, :, NN_index].detach()
231 |
232 | # re-arrange image
233 | if image_x is not None and image_y is not None:
234 | image_y_rearrange = image_y[:, :, NN_index]
235 | image_y_rearrange = image_y_rearrange.view(batch_size, 3, feature_size, feature_size)
236 | image_x = image_x.view(batch_size, 3, feature_size, feature_size)
237 | image_y = image_y.view(batch_size, 3, feature_size, feature_size)
238 |
239 | return loss
240 |
241 |
242 | class ChamferDistance_loss(nn.Module):
243 | """
244 | input is Al, Bl, channel = 1, range ~ [0, 255]
245 | """
246 |
247 | def __init__(self):
248 | super(ChamferDistance_loss, self).__init__()
249 | return None
250 |
251 | def forward(self, X_features, Y_features, image_x, image_y, h=0.1, Y_features_in=None):
252 | """
253 |         X_features & Y_features are feature vectors or 2-d feature arrays
254 | h: bandwidth
255 | return the per-sample loss
256 | """
257 | batch_size = X_features.shape[0]
258 | feature_depth = X_features.shape[1]
259 | feature_size = X_features.shape[2]
260 |
261 | # to normalized feature vectors
262 | X_features = feature_normalize(X_features).view(
263 | batch_size, feature_depth, -1
264 | ) # batch_size * feature_depth * feature_size^2
265 | Y_features = feature_normalize(Y_features).view(
266 | batch_size, feature_depth, -1
267 | ) # batch_size * feature_depth * feature_size^2
268 | image_x = torch.nn.functional.interpolate(image_x, size=(feature_size, feature_size), mode="bilinear").view(
269 | batch_size, 3, -1
270 | )
271 | image_y = torch.nn.functional.interpolate(image_y, size=(feature_size, feature_size), mode="bilinear").view(
272 | batch_size, 3, -1
273 | )
274 |
275 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
276 | similarity_matrix = torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
277 | NN_index = similarity_matrix.max(dim=-1, keepdim=True)[1].squeeze()
278 | if Y_features_in is not None:
279 | loss = torch.mean((X_features - Y_features_in.detach()) ** 2)
280 | Y_features_in = Y_features_in.detach()
281 | else:
282 | loss = torch.mean((X_features - Y_features[:, :, NN_index].detach()) ** 2)
283 | Y_features_in = Y_features[:, :, NN_index].detach()
284 |
285 | # re-arrange image
286 | image_y_rearrange = image_y[:, :, NN_index]
287 | image_y_rearrange = image_y_rearrange.view(batch_size, 3, feature_size, feature_size)
288 | image_x = image_x.view(batch_size, 3, feature_size, feature_size)
289 | image_y = image_y.view(batch_size, 3, feature_size, feature_size)
290 |
291 | return loss, Y_features_in, X_features
292 |
293 |
294 | if __name__ == "__main__":
295 | contextual_loss = ContextualLoss()
296 | batch_size = 32
297 | feature_depth = 8
298 | feature_size = 16
299 | X_features = torch.zeros(batch_size, feature_depth, feature_size, feature_size)
300 | Y_features = torch.zeros(batch_size, feature_depth, feature_size, feature_size)
301 |
302 | cx_loss = contextual_loss(X_features, Y_features, 1)
303 | print(cx_loss)
304 |
--------------------------------------------------------------------------------
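
To make the contextual-loss math above easier to follow, here is a self-contained sketch of the same computation (feature centering omitted for brevity): features are cosine-normalized, pairwise cosine distances are min-normalized per row, turned into affinities with bandwidth `h`, and the loss is the negative log of the mean best affinity. Identical inputs should score close to zero.

```python
# Self-contained sketch of the contextual-loss computation (centering step omitted).
import torch

def contextual_loss_sketch(X, Y, h=0.1, eps=1e-5):
    b, c = X.shape[:2]
    X = X.reshape(b, c, -1)
    Y = Y.reshape(b, c, -1)
    X = X / (X.norm(dim=1, keepdim=True) + eps)          # cosine-normalize feature vectors
    Y = Y / (Y.norm(dim=1, keepdim=True) + eps)
    d = 1 - torch.bmm(X.transpose(1, 2), Y)              # pairwise cosine distance
    d_norm = d / (d.min(dim=-1, keepdim=True)[0] + eps)  # dij_bar
    w = torch.exp((1 - d_norm) / h)                      # pairwise affinity
    A = w / w.sum(dim=-1, keepdim=True)
    CX = A.max(dim=1)[0].mean(dim=-1)                    # per-sample contextual similarity
    return -torch.log(CX)

X = torch.rand(2, 8, 6, 6)
print(contextual_loss_sketch(X, X.clone()).mean().item())              # ~0 for identical inputs
print(contextual_loss_sketch(X, torch.rand(2, 8, 6, 6)).mean().item()) # larger for unrelated inputs
```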
/models/DEVC/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import sys
4 |
5 | import cv2
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import torch
9 | import torchvision.utils as vutils
10 | from skimage import color, io
11 | from torch.autograd import Variable
12 |
13 |
14 | from utils.loss import mse_loss
15 |
16 | cv2.setNumThreads(0)
17 |
18 | # l: [-50,50]
19 | # ab: [-128, 128]
20 | l_norm, ab_norm = 1.0, 1.0
21 | l_mean, ab_mean = 50.0, 0
22 |
23 |
24 | ###### utility ######
25 | def to_np(x):
26 | return x.data.cpu().numpy()
27 |
28 |
29 | def utf8_str(in_str):
30 | try:
31 | in_str = in_str.decode("UTF-8")
32 | except Exception:
33 | in_str = in_str.encode("UTF-8").decode("UTF-8")
34 | return in_str
35 |
36 |
37 | class MovingAvg(object):
38 | def __init__(self, pool_size=100):
39 | from queue import Queue
40 |
41 | self.pool = Queue(maxsize=pool_size)
42 | self.sum = 0
43 | self.curr_pool_size = 0
44 |
45 | def set_curr_val(self, val):
46 | if not self.pool.full():
47 | self.curr_pool_size += 1
48 | self.pool.put_nowait(val)
49 | else:
50 | last_first_val = self.pool.get_nowait()
51 | self.pool.put_nowait(val)
52 | self.sum -= last_first_val
53 |
54 | self.sum += val
55 | return self.sum / self.curr_pool_size
56 |
57 |
58 | ###### image normalization ######
59 | def center_l(l):
60 | # normalization for l
61 | l_mc = (l - l_mean) / l_norm
62 | return l_mc
63 |
64 |
65 | # denormalization for l
66 | def uncenter_l(l):
67 | return l * l_norm + l_mean
68 |
69 |
70 | # normalization for ab
71 | def center_ab(ab):
72 | return (ab - ab_mean) / ab_norm
73 |
74 |
75 | # normalization for lab image
76 | def center_lab_img(img_lab):
77 | return (
78 | img_lab / np.array((l_norm, ab_norm, ab_norm))[:, np.newaxis, np.newaxis]
79 | - np.array((l_mean / l_norm, ab_mean / ab_norm, ab_mean / ab_norm))[:, np.newaxis, np.newaxis]
80 | )
81 |
82 |
83 | ###### color space transformation ######
84 | def rgb2lab_transpose(img_rgb):
85 | return color.rgb2lab(img_rgb).transpose((2, 0, 1))
86 |
87 |
88 | def lab2rgb(img_l, img_ab):
89 | """INPUTS
90 | img_l XxXx1 [0,100]
91 | img_ab XxXx2 [-100,100]
92 | OUTPUTS
93 | returned value is XxXx3"""
94 | pred_lab = np.concatenate((img_l, img_ab), axis=2).astype("float64")
95 | pred_rgb = color.lab2rgb(pred_lab)
96 | pred_rgb = (np.clip(pred_rgb, 0, 1) * 255).astype("uint8")
97 | return pred_rgb
98 |
99 |
100 | def gray2rgb_batch(l):
101 | # gray image tensor to rgb image tensor
102 | l_uncenter = uncenter_l(l)
103 | l_uncenter = l_uncenter / (2 * l_mean)
104 | return torch.cat((l_uncenter, l_uncenter, l_uncenter), dim=1)
105 |
106 |
107 | def lab2rgb_transpose(img_l, img_ab):
108 | """INPUTS
109 | img_l 1xXxX [0,100]
110 | img_ab 2xXxX [-100,100]
111 | OUTPUTS
112 | returned value is XxXx3"""
113 | pred_lab = np.concatenate((img_l, img_ab), axis=0).transpose((1, 2, 0))
114 | return (np.clip(color.lab2rgb(pred_lab), 0, 1) * 255).astype("uint8")
115 |
116 |
117 | def lab2rgb_transpose_mc(img_l_mc, img_ab_mc):
118 | if isinstance(img_l_mc, Variable):
119 | img_l_mc = img_l_mc.data.cpu()
120 | if isinstance(img_ab_mc, Variable):
121 | img_ab_mc = img_ab_mc.data.cpu()
122 |
123 | if img_l_mc.is_cuda:
124 | img_l_mc = img_l_mc.cpu()
125 | if img_ab_mc.is_cuda:
126 | img_ab_mc = img_ab_mc.cpu()
127 |
128 |     assert img_l_mc.dim() == 3 and img_ab_mc.dim() == 3, "only for single-image (C x H x W) input"
129 |
130 | img_l = img_l_mc * l_norm + l_mean
131 | img_ab = img_ab_mc * ab_norm + ab_mean
132 | pred_lab = torch.cat((img_l, img_ab), dim=0)
133 | grid_lab = pred_lab.numpy().astype("float64")
134 | return (np.clip(color.lab2rgb(grid_lab.transpose((1, 2, 0))), 0, 1) * 255).astype("uint8")
135 |
136 |
137 | def batch_lab2rgb_transpose_mc(img_l_mc, img_ab_mc, nrow=8):
138 | if isinstance(img_l_mc, Variable):
139 | img_l_mc = img_l_mc.data.cpu()
140 | if isinstance(img_ab_mc, Variable):
141 | img_ab_mc = img_ab_mc.data.cpu()
142 |
143 | if img_l_mc.is_cuda:
144 | img_l_mc = img_l_mc.cpu()
145 | if img_ab_mc.is_cuda:
146 | img_ab_mc = img_ab_mc.cpu()
147 |
148 | assert img_l_mc.dim() == 4 and img_ab_mc.dim() == 4, "only for batch input"
149 |
150 | img_l = img_l_mc * l_norm + l_mean
151 | img_ab = img_ab_mc * ab_norm + ab_mean
152 | pred_lab = torch.cat((img_l, img_ab), dim=1)
153 | grid_lab = vutils.make_grid(pred_lab, nrow=nrow).numpy().astype("float64")
154 | return (np.clip(color.lab2rgb(grid_lab.transpose((1, 2, 0))), 0, 1) * 255).astype("uint8")
155 |
156 |
157 | ###### loss functions ######
158 | def feature_normalize(feature_in):
159 | feature_in_norm = torch.norm(feature_in, 2, 1, keepdim=True) + sys.float_info.epsilon
160 | feature_in_norm = torch.div(feature_in, feature_in_norm)
161 | return feature_in_norm
162 |
163 |
164 | def statistics_matching(feature1, feature2):
165 | N, C, H, W = feature1.shape
166 | feature1 = feature1.view(N, C, -1)
167 | feature2 = feature2.view(N, C, -1)
168 |
169 | mean1 = feature1.mean(dim=-1)
170 | mean2 = feature2.mean(dim=-1)
171 | std1 = feature1.var(dim=-1).sqrt()
172 | std2 = feature2.var(dim=-1).sqrt()
173 |
174 | return mse_loss(mean1, mean2) + mse_loss(std1, std2)
175 |
176 |
177 | def calc_ab_gradient(input_ab):
178 | x_grad = input_ab[:, :, :, 1:] - input_ab[:, :, :, :-1]
179 | y_grad = input_ab[:, :, 1:, :] - input_ab[:, :, :-1, :]
180 | return x_grad, y_grad
181 |
182 |
183 | def calc_tv_loss(input):
184 | x_grad = input[:, :, :, 1:] - input[:, :, :, :-1]
185 | y_grad = input[:, :, 1:, :] - input[:, :, :-1, :]
186 | return torch.sum(x_grad ** 2) / x_grad.nelement() + torch.sum(y_grad ** 2) / y_grad.nelement()
187 |
188 |
189 | def calc_cosine_dist_loss(input, target):
190 | input_norm = torch.norm(input, 2, 1, keepdim=True) + sys.float_info.epsilon
191 | target_norm = torch.norm(target, 2, 1, keepdim=True) + sys.float_info.epsilon
192 | normalized_input = torch.div(input, input_norm)
193 | normalized_target = torch.div(target, target_norm)
194 | cos_dist = torch.mul(normalized_input, normalized_target)
195 | return torch.mean(1 - torch.sum(cos_dist, dim=1))
196 |
197 |
198 | ###### video related #######
199 | def save_frames(image, image_folder, index=None):
200 | if image is not None:
201 | image = np.clip(image, 0, 255).astype(np.uint8)
202 | io.imsave(os.path.join(image_folder, "frame" + str(index).zfill(5) + ".jpg"), image)
203 |
204 |
205 |
206 | ###### file system ######
207 | def get_size(start_path="."):
208 | total_size = 0
209 | for dirpath, dirnames, filenames in os.walk(start_path):
210 | for f in filenames:
211 | fp = os.path.join(dirpath, f)
212 | total_size += os.path.getsize(fp)
213 | return total_size
214 |
215 |
216 | def parse(parser, save=True):
217 | opt = parser.parse_args(args=[])
218 | args = vars(opt)
219 |
220 | from time import gmtime, strftime
221 |
222 | print("------------ Options -------------")
223 | for k, v in sorted(args.items()):
224 | print("%s: %s" % (str(k), str(v)))
225 | print("-------------- End ----------------")
226 |
227 | # save to the disk
228 | if save:
229 | file_name = os.path.join("opt.txt")
230 | with open(file_name, "wt") as opt_file:
231 | opt_file.write(os.path.basename(sys.argv[0]) + " " + strftime("%Y-%m-%d %H:%M:%S", gmtime()) + "\n")
232 | opt_file.write("------------ Options -------------\n")
233 | for k, v in sorted(args.items()):
234 | opt_file.write("%s: %s\n" % (str(k), str(v)))
235 | opt_file.write("-------------- End ----------------\n")
236 | return opt
237 |
238 |
239 | ###### interactive ######
240 | def clean_tensorboard(directory):
241 | folder_list = os.walk(directory).__next__()[1]
242 | for folder in folder_list:
243 | folder = directory + folder
244 | if get_size(folder) < 10000000:
245 | print("delete the folder of " + folder)
246 | shutil.rmtree(folder)
247 |
248 |
249 | def imshow(input_image, title=None, type_conversion=False):
250 | inp = input_image
251 | if type_conversion or type(input_image) is torch.Tensor:
252 | inp = input_image.numpy()
253 | else:
254 | inp = input_image
255 | fig = plt.figure()
256 | if inp.ndim == 2:
257 | fig = plt.imshow(inp, cmap="gray", clim=[0, 255])
258 | else:
259 | fig = plt.imshow(np.transpose(inp, [1, 2, 0]).astype(np.uint8))
260 | plt.axis("off")
261 | fig.axes.get_xaxis().set_visible(False)
262 | fig.axes.get_yaxis().set_visible(False)
263 | plt.title(title)
264 |
265 |
266 | def imshow_lab(input_lab):
267 | plt.imshow((batch_lab2rgb_transpose_mc(input_lab[:32, 0:1, :, :], input_lab[:32, 1:3, :, :])).astype(np.uint8))
268 |
269 |
270 | ###### vgg preprocessing ######
271 | def vgg_preprocess(tensor):
272 | # input is RGB tensor which ranges in [0,1]
273 | # output is BGR tensor which ranges in [0,255]
274 | tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1)
275 | tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(1, 3, 1, 1)
276 | return tensor_bgr_ml * 255
277 |
278 |
279 | def torch_vgg_preprocess(tensor):
280 | # pytorch version normalization
281 | # note that both input and output are RGB tensors;
282 | # input and output ranges in [0,1]
283 | # normalize the tensor with mean and variance
284 | tensor_mc = tensor - torch.Tensor([0.485, 0.456, 0.406]).type_as(tensor).view(1, 3, 1, 1)
285 | return tensor_mc / torch.Tensor([0.229, 0.224, 0.225]).type_as(tensor_mc).view(1, 3, 1, 1)
286 |
287 |
288 | def network_gradient(net, gradient_on=True):
289 | for param in net.parameters():
290 | param.requires_grad = bool(gradient_on)
291 | return net
292 |
293 |
294 | ##### color space
295 | xyz_from_rgb = np.array(
296 | [[0.412453, 0.357580, 0.180423], [0.212671, 0.715160, 0.072169], [0.019334, 0.119193, 0.950227]]
297 | )
298 | rgb_from_xyz = np.array(
299 | [[3.24048134, -0.96925495, 0.05564664], [-1.53715152, 1.87599, -0.20404134], [-0.49853633, 0.04155593, 1.05731107]]
300 | )
301 |
302 |
303 | def tensor_lab2rgb(input):
304 | """
305 | n * 3* h *w
306 | """
307 | input_trans = input.transpose(1, 2).transpose(2, 3) # n * h * w * 3
308 | L, a, b = input_trans[:, :, :, 0:1], input_trans[:, :, :, 1:2], input_trans[:, :, :, 2:]
309 | y = (L + 16.0) / 116.0
310 | x = (a / 500.0) + y
311 | z = y - (b / 200.0)
312 |
313 | neg_mask = z.data < 0
314 | z[neg_mask] = 0
315 | xyz = torch.cat((x, y, z), dim=3)
316 |
317 | mask = xyz.data > 0.2068966
318 | mask_xyz = xyz.clone()
319 | mask_xyz[mask] = torch.pow(xyz[mask], 3.0)
320 | mask_xyz[~mask] = (xyz[~mask] - 16.0 / 116.0) / 7.787
321 | mask_xyz[:, :, :, 0] = mask_xyz[:, :, :, 0] * 0.95047
322 | mask_xyz[:, :, :, 2] = mask_xyz[:, :, :, 2] * 1.08883
323 |
324 | rgb_trans = torch.mm(mask_xyz.view(-1, 3), torch.from_numpy(rgb_from_xyz).type_as(xyz)).view(
325 | input.size(0), input.size(2), input.size(3), 3
326 | )
327 | rgb = rgb_trans.transpose(2, 3).transpose(1, 2)
328 |
329 | mask = rgb > 0.0031308
330 | mask_rgb = rgb.clone()
331 | mask_rgb[mask] = 1.055 * torch.pow(rgb[mask], 1 / 2.4) - 0.055
332 | mask_rgb[~mask] = rgb[~mask] * 12.92
333 |
334 | neg_mask = mask_rgb.data < 0
335 | large_mask = mask_rgb.data > 1
336 | mask_rgb[neg_mask] = 0
337 | mask_rgb[large_mask] = 1
338 | return mask_rgb
339 |
--------------------------------------------------------------------------------
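
A round-trip sketch for the Lab helpers above, assuming the repo root is on `PYTHONPATH` so the module (and its `utils` imports) resolve: an RGB array is converted to a channel-first Lab array with `rgb2lab_transpose`, split into L and ab planes, and mapped back with `lab2rgb_transpose`. Exact equality is not expected because the output is clipped and quantized to uint8.

```python
# Round-trip sketch for the Lab helpers; assumes the repo's packages are importable.
import numpy as np
from models.DEVC.utils.util import rgb2lab_transpose, lab2rgb_transpose

rgb = np.random.rand(32, 32, 3).astype(np.float64)   # RGB in [0, 1]
lab = rgb2lab_transpose(rgb)                         # 3 x H x W (L, a, b)
recon = lab2rgb_transpose(lab[:1], lab[1:])          # H x W x 3, uint8
print(np.abs(recon / 255.0 - rgb).mean())            # small round-trip error
```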
/models/DVP/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-deep-video-prior (DVP)
2 | Official PyTorch implementation for NeurIPS 2020 paper: Blind Video Temporal Consistency via Deep Video Prior
3 |
4 | [TensorFlow implementation](https://github.com/ChenyangLEI/deep-video-prior)
5 | | [paper](https://arxiv.org/abs/2010.11838)
6 | | [project website](https://chenyanglei.github.io/DVP/index.html)
7 |
8 |
9 | ## Introduction
10 | Our method is a general framework for improving the temporal consistency of videos processed by per-frame image algorithms.
11 |
12 | For example, by combining a single-image colorization or single-image dehazing algorithm with our framework, we can achieve video colorization or video dehazing.
13 |
14 |
15 |
16 |
17 |
18 |
19 | ## Dependency
20 |
21 | ### Environment
22 | This code is based on PyTorch. It has been tested on Ubuntu 18.04 LTS.
23 |
24 | Anaconda is recommended: [Ubuntu 18.04](https://www.digitalocean.com/community/tutorials/how-to-install-the-anaconda-python-distribution-on-ubuntu-18-04)
25 | | [Ubuntu 16.04](https://www.digitalocean.com/community/tutorials/how-to-install-the-anaconda-python-distribution-on-ubuntu-16-04)
26 |
27 | After installing Anaconda, you can setup the environment simply by
28 |
29 | ```
30 | conda env create -f environment.yml
31 | ```
32 |
33 |
34 | ## Inference
35 |
36 | ### Demo
37 | ```
38 | bash test.sh
39 | ```
40 | The results will be saved in `./result`.
41 |
42 | ### Use your own data
43 | For the video with unimodal inconsistency:
44 |
45 | ```
46 | python main_IRT.py --max_epoch 25 --input PATH_TO_YOUR_INPUT_FOLDER --processed PATH_TO_YOUR_PROCESSED_FOLDER --model NAME_OF_YOUR_MODEL --with_IRT 0 --IRT_initialization 0 --output ./result/OWN_DATA
47 | ```
48 |
49 | For the video with multimodal inconsistency:
50 |
51 | ```
52 | python main_IRT.py --max_epoch 25 --input PATH_TO_YOUR_INPUT_FOLDER --processed PATH_TO_YOUR_PROCESSED_FOLDER --model NAME_OF_YOUR_MODEL --with_IRT 1 --IRT_initialization 1 --output ./result/OWN_DATA
53 | ```
54 |
55 |
56 | ## Citation
57 | If you find this work useful for your research, please cite:
58 | ```
59 | @inproceedings{lei2020dvp,
60 | title={Blind Video Temporal Consistency via Deep Video Prior},
61 | author={Lei, Chenyang and Xing, Yazhou and Chen, Qifeng},
62 | booktitle={Advances in Neural Information Processing Systems},
63 | year={2020}
64 | }
65 | ```
66 |
67 |
68 | ## Contact
69 | Feel free to contact me if you have any questions. (Yazhou Xing, yzxing87@gmail.com)
70 |
--------------------------------------------------------------------------------
/models/DVP/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import cv2
4 | import torch
5 | from PIL import Image
6 | from tqdm import tqdm
7 | import numpy as np
8 | import torchvision.transforms as transform_lib
9 | import matplotlib.pyplot as plt
10 |
11 | from utils.util import download_zipfile, mkdir
12 | from utils.v2i import convert_frames_to_video
13 |
14 | class OPT():
15 | pass
16 |
17 | class DVP():
18 | def __init__(self):
19 | self.small = (320, 180)
20 | self.in_size = (0, 0)
21 |
22 | def test(self, black_white_path, colorized_path, output_path, opt=None):
23 | assert os.path.exists(black_white_path) and os.path.exists(colorized_path)
24 |
25 | self.downscale(black_white_path)
26 | self.downscale(colorized_path)
27 |
28 | os.system(f'python3 ./models/DVP/main_IRT.py --save_freq {opt.sf} --max_epoch {opt.me} --input {black_white_path} --processed {colorized_path} --model temp --with_IRT 1 --IRT_initialization 1 --output {opt.op}')
29 |
30 | frames_path = f"{opt.op}/temp_IRT1_initial1/{os.path.basename(black_white_path)}/00{opt.me}"
31 |
32 | self.upscale(frames_path)
33 |
34 |         length = len([f for f in os.listdir(frames_path) if f.startswith("out_main_")])
35 |         frames = [f"out_main_{str(i).zfill(5)}.jpg" for i in range(length)]
36 | convert_frames_to_video(frames_path, output_path, frames)
37 |
38 | def downscale(self, path):
39 | frames = os.listdir(path)
40 |
41 | frame = Image.open(os.path.join(path, frames[0]))
42 | self.in_size = frame.size
43 |
44 | for each in frames:
45 | img = Image.open(os.path.join(path, each))
46 | img = img.resize(self.small, Image.ANTIALIAS)
47 | img.save(os.path.join(path, each))
48 |
49 | def upscale(self, path):
50 | frames = os.listdir(path)
51 | for each in frames:
52 | img = Image.open(os.path.join(path, each))
53 | img = img.resize(self.in_size, Image.ANTIALIAS)
54 | img.save(os.path.join(path, each))
--------------------------------------------------------------------------------
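
A hedged usage sketch for the `DVP` wrapper above, pairing it with the minimal options object from `models/DVP/options/test_options.py`. All paths are placeholders: the two input directories must contain matching frame sequences, and note that `downscale` resizes those frames in place before optimization.

```python
# Hypothetical invocation of the DVP wrapper; every path below is a placeholder.
from models.DVP import DVP
from models.DVP.options.test_options import TestOptions

opt = TestOptions().parse()   # sf (save freq), me (max epoch), op (output dir)
dvp = DVP()
dvp.test(
    black_white_path="test/frames_gray",          # original grayscale frames
    colorized_path="test/frames_colorized",       # per-frame (flickering) colorized frames
    output_path="test/output.mp4",                # temporally consistent result
    opt=opt,
)
```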
/models/DVP/environment.yml:
--------------------------------------------------------------------------------
1 | name: pytorch-DVP
2 | channels:
3 | - defaults
4 | dependencies:
5 | - _libgcc_mutex=0.1=main
6 | - blas=1.0=mkl
7 | - ca-certificates=2020.10.14=0
8 | - certifi=2020.11.8=py36h06a4308_0
9 | - cffi=1.14.4=py36h261ae71_0
10 | - cudatoolkit=10.0.130=0
11 | - cudnn=7.6.5=cuda10.0_0
12 | - freetype=2.10.4=h5ab3b9f_0
13 | - intel-openmp=2020.2=254
14 | - jpeg=9b=h024ee3a_2
15 | - lcms2=2.11=h396b838_0
16 | - ld_impl_linux-64=2.33.1=h53a641e_7
17 | - libedit=3.1.20191231=h14c3975_1
18 | - libffi=3.3=he6710b0_2
19 | - libgcc-ng=9.1.0=hdf63c60_0
20 | - libpng=1.6.37=hbc83047_0
21 | - libstdcxx-ng=9.1.0=hdf63c60_0
22 | - libtiff=4.1.0=h2733197_1
23 | - lz4-c=1.9.2=heb0550a_3
24 | - mkl=2020.2=256
25 | - mkl-service=2.3.0=py36he904b0f_0
26 | - mkl_fft=1.2.0=py36h23d657b_0
27 | - mkl_random=1.1.1=py36h0573a6f_0
28 | - ncurses=6.2=he6710b0_1
29 | - ninja=1.10.2=py36hff7bd54_0
30 | - numpy=1.19.2=py36h54aff64_0
31 | - numpy-base=1.19.2=py36hfa32c7d_0
32 | - olefile=0.46=py36_0
33 | - openssl=1.1.1h=h7b6447c_0
34 | - pillow=8.0.1=py36he98fc37_0
35 | - pip=20.3=py36h06a4308_0
36 | - pycparser=2.20=py_2
37 | - python=3.6.12=hcff3b4d_2
38 | - pytorch=1.1.0=cuda100py36he554f03_0
39 | - readline=8.0=h7b6447c_0
40 | - setuptools=50.3.2=py36h06a4308_2
41 | - six=1.15.0=py36h06a4308_0
42 | - sqlite=3.33.0=h62c20be_0
43 | - tk=8.6.10=hbc83047_0
44 | - torchvision=0.3.0=cuda100py36h72fc40a_0
45 | - wheel=0.36.0=pyhd3eb1b0_0
46 | - xz=5.2.5=h7b6447c_0
47 | - zlib=1.2.11=h7b6447c_3
48 | - zstd=1.4.5=h9ceee32_0
49 | - pip:
50 | - scipy==1.2.0
51 | prefix: /home/yazhou/anaconda3/envs/pytorch-DVP
52 |
53 |
--------------------------------------------------------------------------------
/models/DVP/main_IRT.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 | import os
5 | import time
6 | import scipy.io
7 | import torch
8 | import torch.nn as nn
9 | import numpy as np
10 | from glob import glob
11 | import scipy.misc as sic
12 | import subprocess
13 | import models.network as net
14 | import argparse
15 | import random
16 | import imageio
17 | from vgg import VGG19
18 |
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--model", default='Test', type=str, help="Name of model")
21 | parser.add_argument("--save_freq", default=5, type=int, help="save frequency of epochs")
22 | parser.add_argument("--use_gpu", default=1, type=int, help="use gpu or not")
23 | parser.add_argument("--with_IRT", default=0, type=int, help="use IRT or not")
24 | parser.add_argument("--IRT_initialization", default=0, type=int, help="use initialization for IRT or not")
25 | parser.add_argument("--max_epoch", default=25, type=int, help="The max number of epochs for training")
26 | parser.add_argument("--input", default='./demo/colorization/goat_input', type=str, help="dir of input video")
27 | parser.add_argument("--processed", default='./demo/colorization/goat_processed', type=str, help="dir of processed video")
28 | parser.add_argument("--output", default='None', type=str, help="dir of output video")
29 |
30 | # set random seed
31 | seed = 2020
32 | np.random.seed(seed)
33 | torch.manual_seed(seed)
34 | random.seed(seed)
35 |
36 | # process arguments
37 | ARGS = parser.parse_args()
38 | print(ARGS)
39 | save_freq = ARGS.save_freq
40 | input_folder = ARGS.input
41 | processed_folder = ARGS.processed
42 | with_IRT = ARGS.with_IRT
43 | maxepoch = ARGS.max_epoch + 1
44 | model= ARGS.model
45 | task = "/{}_IRT{}_initial{}".format(model, with_IRT, ARGS.IRT_initialization) #Colorization, HDR, StyleTransfer, Dehazing
46 |
47 | # set gpu
48 | if ARGS.use_gpu:
49 | os.environ["CUDA_VISIBLE_DEVICES"]=str(np.argmax([int(x.split()[2])
50 | for x in subprocess.Popen("nvidia-smi -q -d Memory | grep -A4 GPU | grep Free", shell=True, stdout=subprocess.PIPE).stdout.readlines()]))
51 | else:
52 | os.environ["CUDA_VISIBLE_DEVICES"] = ''
53 |
54 | device = torch.device("cuda:0" if ARGS.use_gpu else "cpu")
55 |
56 | # clear cuda cache
57 | torch.cuda.empty_cache()
58 |
59 | # define loss function
60 | def compute_error(real,fake):
61 | # return tf.reduce_mean(tf.abs(fake-real))
62 | return torch.mean(torch.abs(fake-real))
63 |
64 | def Lp_loss(x, y):
65 | vgg_real = VGG_19(normalize_batch(x))
66 | vgg_fake = VGG_19(normalize_batch(y))
67 | p0 = compute_error(normalize_batch(x), normalize_batch(y))
68 |
69 | content_loss_list = []
70 | content_loss_list.append(p0)
71 | feat_layers = {'conv1_2' : 1.0/2.6, 'conv2_2' : 1.0/4.8, 'conv3_2': 1.0/3.7, 'conv4_2':1.0/5.6, 'conv5_2':10.0/1.5}
72 |
73 | for layer, w in feat_layers.items():
74 | pi = compute_error(vgg_real[layer], vgg_fake[layer])
75 | content_loss_list.append(w * pi)
76 |
77 | content_loss = torch.sum(torch.stack(content_loss_list))
78 |
79 | return content_loss
80 |
81 | loss_L2 = torch.nn.MSELoss()
82 | loss_L1 = torch.nn.L1Loss()
83 |
84 |
85 | # Define model .
86 | out_channels = 6 if with_IRT else 3
87 | net = net.UNet(in_channels=3, out_channels=out_channels, init_features=32)
88 | net.to(device)
89 | optimizer = torch.optim.Adam(net.parameters(), lr=0.0001)
90 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3000,8000], gamma=0.5)
91 |
92 | VGG_19 = VGG19(requires_grad=False).to(device)
93 |
94 | # prepare data
95 | input_folders = [input_folder]
96 | processed_folders = [processed_folder]
97 |
98 |
99 | def prepare_paired_input(task, id, input_names, processed_names, is_train=0):
100 | net_in = np.float32(imageio.imread(input_names[id]))/255.0
101 | if len(net_in.shape) == 2:
102 | net_in = np.tile(net_in[:,:,np.newaxis], [1,1,3])
103 | net_gt = np.float32(imageio.imread(processed_names[id]))/255.0
104 | org_h,org_w = net_in.shape[:2]
105 | h = org_h // 32 * 32
106 | w = org_w // 32 * 32
107 | print(net_in.shape, net_gt.shape)
108 | return net_in[np.newaxis, :h, :w, :], net_gt[np.newaxis, :h, :w, :]
109 |
110 | # some functions
111 | def initialize_weights(model):
112 | for module in model.modules():
113 | if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
114 | # nn.init.kaiming_normal_(module.weight)
115 | nn.init.xavier_normal_(module.weight)
116 | if module.bias is not None:
117 | module.bias.data.zero_()
118 | elif isinstance(module, nn.BatchNorm2d):
119 | module.weight.data.fill_(1)
120 | module.bias.data.zero_()
121 |
122 | def normalize_batch(batch):
123 | # Normalize batch using ImageNet mean and std
124 | mean = batch.new_tensor([0.485, 0.456, 0.406]).view(1, -1, 1, 1)
125 | std = batch.new_tensor([0.229, 0.224, 0.225]).view(1, -1, 1, 1)
126 | return (batch - mean) / std
127 |
128 |
129 |
130 | # start to train
131 | for folder_idx, input_folder in enumerate(input_folders):
132 | # -----------load data-------------
133 | input_names = sorted(glob(input_folders[folder_idx] + "/*"))
134 | processed_names = sorted(glob(processed_folders[folder_idx] + "/*"))
135 | if ARGS.output == "None":
136 | output_folder = "./result/{}".format(task + '/' + input_folder.split("/")[-2] + '/' + input_folder.split("/")[-1])
137 | else:
138 | output_folder = ARGS.output + "/" + task + '/' + input_folder.split("/")[-1]
139 | print(output_folder, input_folders[folder_idx], processed_folders[folder_idx] )
140 |
141 | num_of_sample = min(len(input_names), len(processed_names))
142 | data_in_memory = [None] * num_of_sample #Speedup
143 | for id in range(min(len(input_names), len(processed_names))): #Speedup
144 | net_in,net_gt = prepare_paired_input(task, id, input_names, processed_names) #Speedup
145 | net_in = torch.from_numpy(net_in).permute(0,3,1,2).float().to(device)
146 | net_gt = torch.from_numpy(net_gt).permute(0,3,1,2).float().to(device)
147 | data_in_memory[id] = [net_in,net_gt] #Speedup
148 |
149 | # model re-initialization
150 | initialize_weights(net)
151 |
152 | step = 0
153 | for epoch in range(1,maxepoch):
154 | # -----------start to train-------------
155 | print("Processing epoch {}".format(epoch))
156 | frame_id = 0
157 | if os.path.isdir("{}/{:04d}".format(output_folder, epoch)):
158 | continue
159 | else:
160 | os.makedirs("{}/{:04d}".format(output_folder, epoch))
161 | if not os.path.isdir("{}/training".format(output_folder)):
162 | os.makedirs("{}/training".format(output_folder))
163 |
164 | print(len(input_names), len(processed_names))
165 | for id in range(num_of_sample):
166 | if with_IRT:
167 | if epoch < 6 and ARGS.IRT_initialization:
168 | net_in,net_gt = data_in_memory[0] #Option:
169 | prediction = net(net_in)
170 |
171 | crt_loss = loss_L1(prediction[:,:3,:,:], net_gt) + 0.9*loss_L1(prediction[:,3:,:,:], net_gt)
172 |
173 | else:
174 | net_in,net_gt = data_in_memory[id]
175 | prediction = net(net_in)
176 |
177 | prediction_main = prediction[:,:3,:,:]
178 | prediction_minor = prediction[:,3:,:,:]
179 | diff_map_main,_ = torch.max(torch.abs(prediction_main - net_gt) / (net_in+1e-1), dim=1, keepdim=True)
180 | diff_map_minor,_ = torch.max(torch.abs(prediction_minor - net_gt) / (net_in+1e-1), dim=1, keepdim=True)
181 | confidence_map = torch.lt(diff_map_main, diff_map_minor).repeat(1,3,1,1).float()
182 | crt_loss = loss_L1(prediction_main*confidence_map, net_gt*confidence_map) \
183 | + loss_L1(prediction_minor*(1-confidence_map), net_gt*(1-confidence_map))
184 | else:
185 | net_in,net_gt = data_in_memory[id]
186 | prediction = net(net_in)
187 | crt_loss = Lp_loss(prediction, net_gt)
188 |
189 | optimizer.zero_grad()
190 | crt_loss.backward()
191 | optimizer.step()
192 |
193 | frame_id+=1
194 | step+=1
195 | if step % 10 == 0:
196 | print("Image iter: {} {} {} || Loss: {:.4f} ".format(epoch, frame_id, step, crt_loss))
197 | if step % 100 == 0 :
198 | net_in = net_in.permute(0,2,3,1).cpu().numpy()
199 | net_gt = net_gt.permute(0,2,3,1).cpu().numpy()
200 | prediction = prediction.detach().permute(0,2,3,1).cpu().numpy()
201 | if with_IRT:
202 | prediction = prediction[...,:3]
203 | imageio.imsave("{}/training/step{:06d}_{:06d}.jpg".format(output_folder, step, id),
204 | np.uint8(np.concatenate([net_in[0], prediction[0], net_gt[0]], axis=1).clip(0,1) * 255.0))
205 |
206 |         # -----------save intermediate results-------------
207 | if epoch % save_freq == 0:
208 | for id in range(num_of_sample):
209 | st=time.time()
210 | net_in,net_gt = data_in_memory[id]
211 | print("Test: {}-{} \r".format(id, num_of_sample))
212 |
213 | with torch.no_grad():
214 | prediction = net(net_in)
215 | net_in = net_in.permute(0,2,3,1).cpu().numpy()
216 | net_gt = net_gt.permute(0,2,3,1).cpu().numpy()
217 | prediction = prediction.detach().permute(0,2,3,1).cpu().numpy()
218 |
219 | if with_IRT:
220 | prediction_main = prediction[...,:3]
221 | prediction_minor = prediction[...,3:]
222 | diff_map_main = np.amax(np.absolute(prediction_main - net_gt) / (net_in+1e-1), axis=3, keepdims=True)
223 | diff_map_minor = np.amax(np.absolute(prediction_minor - net_gt) / (net_in+1e-1), axis=3, keepdims=True)
224 | confidence_map = np.tile(np.less(diff_map_main, diff_map_minor), (1,1,1,3)).astype('float32')
225 |
226 | imageio.imsave("{}/{:04d}/predictions_{:05d}.jpg".format(output_folder, epoch, id),
227 | np.uint8(np.concatenate([net_in[0,:,:,:3],prediction_main[0], prediction_minor[0],net_gt[0], confidence_map[0]], axis=1).clip(0,1) * 255.0))
228 | imageio.imsave("{}/{:04d}/out_main_{:05d}.jpg".format(output_folder, epoch, id),np.uint8(prediction_main[0].clip(0,1) * 255.0))
229 | imageio.imsave("{}/{:04d}/out_minor_{:05d}.jpg".format(output_folder, epoch, id),np.uint8(prediction_minor[0].clip(0,1) * 255.0))
230 |
231 | else:
232 |
233 | imageio.imsave("{}/{:04d}/predictions_{:05d}.jpg".format(output_folder, epoch, id),
234 | np.uint8(np.concatenate([net_in[0,:,:,:3], prediction[0], net_gt[0]],axis=1).clip(0,1) * 255.0))
235 | imageio.imsave("{}/{:04d}/out_main_{:05d}.jpg".format(output_folder, epoch, id),
236 | np.uint8(prediction[0].clip(0,1) * 255.0))
237 |
238 |
--------------------------------------------------------------------------------
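
The IRT branch (`--with_IRT 1`) inside the loop above is the densest part of the script, so here is a standalone sketch of just the confidence-map loss it computes: the six-channel prediction is split into a main and a minor branch, a per-pixel mask marks where the main branch is closer to the processed frame, and each branch is supervised only where it wins.

```python
# Standalone sketch of the IRT confidence-map loss from the training loop above.
import torch
import torch.nn as nn

l1 = nn.L1Loss()

def irt_loss(prediction, net_in, net_gt, eps=1e-1):
    """prediction: N x 6 x H x W (main + minor); net_in, net_gt: N x 3 x H x W."""
    pred_main, pred_minor = prediction[:, :3], prediction[:, 3:]
    diff_main, _ = torch.max(torch.abs(pred_main - net_gt) / (net_in + eps), dim=1, keepdim=True)
    diff_minor, _ = torch.max(torch.abs(pred_minor - net_gt) / (net_in + eps), dim=1, keepdim=True)
    conf = torch.lt(diff_main, diff_minor).repeat(1, 3, 1, 1).float()  # 1 where main wins
    return l1(pred_main * conf, net_gt * conf) + l1(pred_minor * (1 - conf), net_gt * (1 - conf))

pred = torch.rand(1, 6, 64, 64)
frame_in, frame_gt = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
print(irt_loss(pred, frame_in, frame_gt).item())
```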
/models/DVP/models/__pycache__/network.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wensi-ai/Colorizer/ef382f44885f9d083c02a83952db9b4a6b4149c1/models/DVP/models/__pycache__/network.cpython-36.pyc
--------------------------------------------------------------------------------
/models/DVP/models/__pycache__/network.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wensi-ai/Colorizer/ef382f44885f9d083c02a83952db9b4a6b4149c1/models/DVP/models/__pycache__/network.cpython-39.pyc
--------------------------------------------------------------------------------
/models/DVP/models/network.py:
--------------------------------------------------------------------------------
1 |
2 | from collections import OrderedDict
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class UNet(nn.Module):
9 |
10 | def __init__(self, in_channels=3, out_channels=1, init_features=32):
11 | super(UNet, self).__init__()
12 |
13 | features = init_features
14 | self.encoder1 = UNet._block(in_channels, features, name="enc1")
15 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
16 | self.encoder2 = UNet._block(features, features * 2, name="enc2")
17 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
18 | self.encoder3 = UNet._block(features * 2, features * 4, name="enc3")
19 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
20 | self.encoder4 = UNet._block(features * 4, features * 8, name="enc4")
21 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
22 |
23 | self.bottleneck = UNet._block(features * 8, features * 16, name="bottleneck")
24 |
25 | self.upconv4 = nn.Sequential(
26 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
27 | nn.Conv2d(features * 16, features * 8, kernel_size=3, padding=1)
28 | )
29 | self.decoder4 = UNet._block((features * 8) * 2, features * 8, name="dec4")
30 | self.upconv3 = nn.Sequential(
31 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
32 | nn.Conv2d(features * 8, features * 4, kernel_size=3, padding=1)
33 | )
34 |
35 | self.decoder3 = UNet._block((features * 4) * 2, features * 4, name="dec3")
36 | self.upconv2 = nn.Sequential(
37 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
38 | nn.Conv2d(features * 4, features * 2, kernel_size=3, padding=1)
39 | )
40 |
41 | self.decoder2 = UNet._block((features * 2) * 2, features * 2, name="dec2")
42 | self.upconv1 = nn.Sequential(
43 | nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
44 | nn.Conv2d(features * 2, features, kernel_size=3, padding=1)
45 | )
46 |
47 | self.decoder1 = UNet._block(features * 2, features, name="dec1")
48 |
49 | self.conv = nn.Conv2d(
50 | in_channels=features, out_channels=out_channels, kernel_size=1
51 | )
52 |
53 | def forward(self, x):
54 | enc1 = self.encoder1(x)
55 | enc2 = self.encoder2(self.pool1(enc1))
56 | enc3 = self.encoder3(self.pool2(enc2))
57 | enc4 = self.encoder4(self.pool3(enc3))
58 |
59 | bottleneck = self.bottleneck(self.pool4(enc4))
60 |
61 | dec4 = self.upconv4(bottleneck)
62 | dec4 = torch.cat((dec4, enc4), dim=1)
63 | dec4 = self.decoder4(dec4)
64 | dec3 = self.upconv3(dec4)
65 | dec3 = torch.cat((dec3, enc3), dim=1)
66 | dec3 = self.decoder3(dec3)
67 | dec2 = self.upconv2(dec3)
68 | dec2 = torch.cat((dec2, enc2), dim=1)
69 | dec2 = self.decoder2(dec2)
70 | dec1 = self.upconv1(dec2)
71 | dec1 = torch.cat((dec1, enc1), dim=1)
72 | dec1 = self.decoder1(dec1)
73 | out = self.conv(dec1)
74 | return out
75 |
76 | @staticmethod
77 | def _block(in_channels, features, name):
78 | return nn.Sequential(
79 | OrderedDict(
80 | [
81 | (
82 | name + "conv1",
83 | nn.Conv2d(
84 | in_channels=in_channels,
85 | out_channels=features,
86 | kernel_size=3,
87 | padding=1,
88 | bias=False,
89 | ),
90 | ),
91 | # (name + "norm1", nn.BatchNorm2d(num_features=features)),
92 | (name + "relu1", nn.ReLU(inplace=True)),
93 | (
94 | name + "conv2",
95 | nn.Conv2d(
96 | in_channels=features,
97 | out_channels=features,
98 | kernel_size=3,
99 | padding=1,
100 | bias=False,
101 | ),
102 | ),
103 | # (name + "norm2", nn.BatchNorm2d(num_features=features)),
104 | (name + "relu2", nn.ReLU(inplace=True)),
105 | ]
106 | )
107 | )
108 |
109 |
110 |
111 |
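# Illustrative usage sketch (not part of the original file): with four stride-2 poolings,
# input height and width should be divisible by 16 so every skip connection lines up.
if __name__ == "__main__":
    net = UNet(in_channels=3, out_channels=3, init_features=32)
    dummy = torch.randn(1, 3, 256, 448)      # N, C, H, W with H and W multiples of 16
    with torch.no_grad():
        out = net(dummy)
    print(out.shape)                          # torch.Size([1, 3, 256, 448])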
--------------------------------------------------------------------------------
/models/DVP/options/test_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch.backends.cudnn as cudnn
3 | class TestOptions():
4 | def __init__(self):
5 | self.sf = 10 # save frequency
6 | self.op = "test/results"
7 | self.me = 30 # max epoch
8 |
9 | def parse(self):
10 | return self
--------------------------------------------------------------------------------
/models/DVP/test.sh:
--------------------------------------------------------------------------------
1 | python main_IRT.py --max_epoch 25 --input demo/colorization/goat_input --processed demo/colorization/goat_processed --model colorization --with_IRT 1 --IRT_initialization 1 --output ./result/colorization
--------------------------------------------------------------------------------
/models/DVP/vgg.py:
--------------------------------------------------------------------------------
1 | from torchvision import models
2 | import torch.nn as nn
3 | import torch
4 |
5 |
6 | class VGG19(torch.nn.Module):
7 | """docstring for Vgg19"""
8 | def __init__(self, requires_grad=False):
9 | super(VGG19, self).__init__()
10 | vgg_pretrained_features = models.vgg19(pretrained=True).features
11 | self.slice1 = torch.nn.Sequential()
12 | self.slice2 = torch.nn.Sequential()
13 | self.slice3 = torch.nn.Sequential()
14 | self.slice4 = torch.nn.Sequential()
15 | self.slice5 = torch.nn.Sequential()
16 | for x in range(4): # conv1_2
17 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
18 | for x in range(4, 9): # conv2_2
19 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
20 | for x in range(9, 14): # conv3_2
21 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
22 | for x in range(14, 23): # conv4_2
23 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
24 | for x in range(23, 32): # conv5_2
25 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
26 |
27 | if not requires_grad:
28 | for param in self.parameters():
29 | param.requires_grad = False
30 |
31 | def forward(self, X):
32 | h = self.slice1(X)
33 | h_relu1_2 = h
34 | h = self.slice2(h)
35 | h_relu2_2 = h
36 | h = self.slice3(h)
37 | h_relu3_2 = h
38 | h = self.slice4(h)
39 | h_relu4_2 = h
40 | h = self.slice5(h)
41 | h_relu5_2 = h
42 |
43 | out = {}
44 | out['conv1_2'] = h_relu1_2
45 | out['conv2_2'] = h_relu2_2
46 | out['conv3_2'] = h_relu3_2
47 | out['conv4_2'] = h_relu4_2
48 | out['conv5_2'] = h_relu5_2
49 |
50 | return out
51 |
52 |
53 | '''
54 | Sequential(
55 | (0): Sequential(
56 | (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
57 | (1): ReLU(inplace)
58 | (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
59 | (3): ReLU(inplace) ##conv1_2
60 | (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
61 | (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
62 | (6): ReLU(inplace)
63 | (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
64 | (8): ReLU(inplace) ##conv2_2
65 | (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
66 | (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
67 | (11): ReLU(inplace)
68 | (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
69 | (13): ReLU(inplace) ##conv3_2
70 | (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
71 | (15): ReLU(inplace)
72 | (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
73 | (17): ReLU(inplace)
74 | (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
75 | (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
76 | (20): ReLU(inplace)
77 | (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
78 | (22): ReLU(inplace) ##conv4_2
79 | (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
80 | (24): ReLU(inplace)
81 | (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
82 | (26): ReLU(inplace)
83 | (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
84 | (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
85 | (29): ReLU(inplace)
86 | (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
87 | (31): ReLU(inplace) ##conv5_2
88 | (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
89 | (33): ReLU(inplace)
90 | (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
91 | (35): ReLU(inplace)
92 | (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
93 | )
94 | (1): AdaptiveAvgPool2d(output_size=(7, 7))
95 | (2): Sequential(
96 | (0): Linear(in_features=25088, out_features=4096, bias=True)
97 | (1): ReLU(inplace)
98 | (2): Dropout(p=0.5)
99 | (3): Linear(in_features=4096, out_features=4096, bias=True)
100 | (4): ReLU(inplace)
101 | (5): Dropout(p=0.5)
102 | (6): Linear(in_features=4096, out_features=1000, bias=True)
103 | )
104 | )
105 | '''
106 |
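# Illustrative usage sketch (not part of the original file): VGG19 above as a frozen
# feature extractor for a simple perceptual (feature-matching) loss. Downloading the
# torchvision VGG19 weights is assumed; the input tensors here are dummies.
if __name__ == "__main__":
    vgg = VGG19(requires_grad=False).eval()
    a, b = torch.randn(1, 3, 224, 224), torch.randn(1, 3, 224, 224)
    feats_a, feats_b = vgg(a), vgg(b)
    perceptual = sum(nn.functional.l1_loss(feats_a[k], feats_b[k]) for k in feats_a)
    print(float(perceptual))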
--------------------------------------------------------------------------------
/models/InstaColor/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | import importlib
4 | import numpy as np
5 | import torch
6 | import cv2
7 | from typing import List
8 | from tqdm import tqdm
9 | from utils.util import download_zipfile, mkdir
10 | from utils.v2i import convert_frames_to_video
11 | from PIL import Image
12 | from models.InstaColor.models.base_model import BaseModel
13 | from models.InstaColor.utils import util
14 | from models.InstaColor.utils.datasets import *
15 | from utils.util import download_zipfile
16 |
17 | from detectron2 import model_zoo
18 | from detectron2.engine import DefaultPredictor
19 | from detectron2.config import get_cfg
20 |
21 | class InstaColor:
22 | def __init__(self, pretrained=True):
23 | # Bounding box predictor
24 | cfg = get_cfg()
25 | cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
26 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
27 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
28 | self.predictor = DefaultPredictor(cfg)
29 | self.model = None
30 |
31 | # download pretrained model
32 | if pretrained is True:
33 | download_zipfile("https://docs.google.com/uc?export=download&id=1Xb-DKAA9ibCVLqm8teKd1MWk6imjwTBh&confirm=t", "InstaColor_checkpoints.zip")
34 |
35 | def find_model_using_name(self, model_name):
36 | # Given the option --model [modelname],
37 | # the file "models/modelname_model.py"
38 | # will be imported.
39 | model_filename = "models.InstaColor.models." + model_name + "_model"
40 | modellib = importlib.import_module(model_filename)
41 |
42 | # In the file, the class called ModelNameModel() will
43 | # be instantiated. It has to be a subclass of BaseModel,
44 | # and it is case-insensitive.
45 | model = None
46 | target_model_name = model_name.replace('_', '') + 'model'
47 | for name, cls in modellib.__dict__.items():
48 | if name.lower() == target_model_name.lower() \
49 | and issubclass(cls, BaseModel):
50 | model = cls
51 |
52 | if model is None:
53 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
54 | exit(0)
55 |
56 | return model
57 |
58 |
59 | def create_model(self, opt):
60 | model = self.find_model_using_name(opt.model)
61 | self.model = model()
62 | self.model.initialize(opt)
63 | print("model [%s] was created" % (self.model.name()))
64 |
65 | def test_images(self, input_dir: str, file_names: List[str], opt):
66 | """
67 | Colorize the images in file_names: detect instance bounding boxes first, then run the fusion colorization model.
68 | """
69 | # get bounding box for each image
70 | print("Getting bounding boxes...")
71 | with torch.no_grad():
72 | for image_name in tqdm(file_names):
73 | img = cv2.imread(f"{input_dir}/{image_name}")
74 | lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
75 | l_channel, a_channel, b_channel = cv2.split(lab_image)
76 | l_stack = np.stack([l_channel, l_channel, l_channel], axis=2)
77 | outputs = self.predictor(l_stack)
78 | pred_bbox = outputs["instances"].pred_boxes.to(torch.device('cpu')).tensor.numpy()
79 | pred_scores = outputs["instances"].scores.cpu().data.numpy()
80 | np.savez(f"{opt.output_npz_dir}/{image_name.split('.')[0]}", bbox = pred_bbox, scores = pred_scores)
81 |
82 | # setup dataset loader
83 | opt.batch_size = 1
84 | opt.test_img_dir = input_dir
85 | dataset = Fusion_Testing_Dataset(opt, file_names, -1)
86 | dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size)
87 |
88 | # setup model to test
89 | self.create_model(opt)
90 | self.model.setup_to_test('coco_finetuned_mask_256_ffs')
91 |
92 | print("Colorizing...")
93 | # colorize image
94 | with torch.no_grad():
95 | for data_raw in tqdm(dataset_loader):
96 | data_raw['full_img'][0] = data_raw['full_img'][0].cuda()
97 | if data_raw['empty_box'][0] == 0:
98 | data_raw['cropped_img'][0] = data_raw['cropped_img'][0].cuda()
99 | box_info = data_raw['box_info'][0]
100 | box_info_2x = data_raw['box_info_2x'][0]
101 | box_info_4x = data_raw['box_info_4x'][0]
102 | box_info_8x = data_raw['box_info_8x'][0]
103 | cropped_data = util.get_colorization_data(data_raw['cropped_img'], opt, ab_thresh=0, p=opt.sample_p)
104 | full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p)
105 | self.model.set_input(cropped_data)
106 | self.model.set_fusion_input(full_img_data, [box_info, box_info_2x, box_info_4x, box_info_8x])
107 | self.model.forward()
108 | else:
109 | full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p)
110 | self.model.set_forward_without_box(full_img_data)
111 | self.model.save_current_imgs(os.path.join(opt.results_img_dir, data_raw['file_id'][0] + '.png'))
112 |
113 | def test(self, input_dir: str, output_path: str, opt):
114 | # get file names
115 | input_files = os.listdir(input_dir)
116 | #create colorized output save folder
117 | mkdir(opt.results_img_dir)
118 | #create bounding box save folder
119 | output_npz_dir = "{0}_bbox".format(input_dir)
120 | mkdir(output_npz_dir)
121 | opt.output_npz_dir = output_npz_dir
122 |
123 | for i in range(0, len(input_files), 5):
124 | self.test_images(input_dir, input_files[i : min(len(input_files), i + 5)], opt)
125 |
126 | output_frames = sorted(os.listdir(opt.results_img_dir))
127 | # resize back
128 | for output in output_frames:
129 | img = Image.open(os.path.join(opt.results_img_dir, output))
130 | img = img.resize((320, 180), Image.ANTIALIAS)
131 | img.save(os.path.join(opt.results_img_dir, output))
132 |
133 | convert_frames_to_video(opt.results_img_dir, output_path)
134 | # remove bounding box dir
135 | shutil.rmtree(output_npz_dir)
136 | # shutil.rmtree(opt.results_img_dir)
137 | print("Task Complete")
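# Hedged usage sketch (not part of the original file): one plausible way to drive this
# class, based on the test() signature above and the TestOptions defined in
# models/InstaColor/options/test_options.py. Paths are placeholders; a CUDA device,
# detectron2, and the downloaded checkpoints are assumed.
if __name__ == "__main__":
    from models.InstaColor.options.test_options import TestOptions

    opt = TestOptions().parse()                  # CLI defaults: fusion model, etc.
    colorizer = InstaColor(pretrained=True)      # also fetches the pretrained checkpoints
    colorizer.test("test/frames", "test/colorized.mp4", opt)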
--------------------------------------------------------------------------------
/models/InstaColor/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.utils.data
3 | from models.InstaColor.data.base_data_loader import BaseDataLoader
4 | from models.InstaColor.data.base_dataset import BaseDataset
5 |
6 |
7 | def find_dataset_using_name(dataset_name):
8 | # Given the option --dataset_mode [datasetname],
9 | # the file "data/datasetname_dataset.py"
10 | # will be imported.
11 | dataset_filename = "models.InstaColor.data." + dataset_name + "_dataset"
12 | datasetlib = importlib.import_module(dataset_filename)
13 |
14 | # In the file, the class called DatasetNameDataset() will
15 | # be instantiated. It has to be a subclass of BaseDataset,
16 | # and it is case-insensitive.
17 | dataset = None
18 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
19 | for name, cls in datasetlib.__dict__.items():
20 | if name.lower() == target_dataset_name.lower() \
21 | and issubclass(cls, BaseDataset):
22 | dataset = cls
23 |
24 | if dataset is None:
25 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
26 | exit(0)
27 |
28 | return dataset
29 |
30 |
31 | def get_option_setter(dataset_name):
32 | dataset_class = find_dataset_using_name(dataset_name)
33 | return dataset_class.modify_commandline_options
34 |
35 |
36 | def create_dataset(opt):
37 | dataset = find_dataset_using_name(opt.dataset_mode)
38 | instance = dataset()
39 | instance.initialize(opt)
40 | print("dataset [%s] was created" % (instance.name()))
41 | return instance
42 |
43 |
44 | def CreateDataLoader(opt):
45 | data_loader = CustomDatasetDataLoader()
46 | data_loader.initialize(opt)
47 | return data_loader
48 |
49 |
50 | # Wrapper class of Dataset class that performs
51 | # multi-threaded data loading
52 | class CustomDatasetDataLoader(BaseDataLoader):
53 | def name(self):
54 | return 'CustomDatasetDataLoader'
55 |
56 | def initialize(self, opt):
57 | BaseDataLoader.initialize(self, opt)
58 | self.dataset = create_dataset(opt)
59 | self.dataloader = torch.utils.data.DataLoader(
60 | self.dataset,
61 | batch_size=opt.batch_size,
62 | shuffle=not opt.serial_batches,
63 | num_workers=int(opt.num_threads))
64 |
65 | def load_data(self):
66 | return self
67 |
68 | def __len__(self):
69 | return min(len(self.dataset), self.opt.max_dataset_size)
70 |
71 | def __iter__(self):
72 | for i, data in enumerate(self.dataloader):
73 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
74 | break
75 | yield data
76 |
--------------------------------------------------------------------------------
/models/InstaColor/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | import torchvision.transforms as transforms
4 | import torch
5 | from models.InstaColor.data.base_dataset import BaseDataset
6 | from models.InstaColor.data.image_folder import make_dataset
7 | from PIL import Image
8 |
9 |
10 | class AlignedDataset(BaseDataset):
11 | @staticmethod
12 | def modify_commandline_options(parser, is_train):
13 | return parser
14 |
15 | def initialize(self, opt):
16 | self.opt = opt
17 | self.root = opt.dataroot
18 | self.dir_AB = os.path.join(opt.dataroot, opt.phase)
19 | self.AB_paths = sorted(make_dataset(self.dir_AB))
20 | assert(opt.resize_or_crop == 'resize_and_crop')
21 |
22 | def __getitem__(self, index):
23 | AB_path = self.AB_paths[index]
24 | AB = Image.open(AB_path).convert('RGB')
25 | w, h = AB.size
26 | w2 = int(w / 2)
27 | A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
28 | B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
29 | A = transforms.ToTensor()(A)
30 | B = transforms.ToTensor()(B)
31 | w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
32 | h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
33 |
34 | A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
35 | B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
36 |
37 | A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
38 | B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)
39 |
40 | if self.opt.which_direction == 'BtoA':
41 | input_nc = self.opt.output_nc
42 | output_nc = self.opt.input_nc
43 | else:
44 | input_nc = self.opt.input_nc
45 | output_nc = self.opt.output_nc
46 |
47 | if (not self.opt.no_flip) and random.random() < 0.5:
48 | idx = [i for i in range(A.size(2) - 1, -1, -1)]
49 | idx = torch.LongTensor(idx)
50 | A = A.index_select(2, idx)
51 | B = B.index_select(2, idx)
52 |
53 | if input_nc == 1: # RGB to gray
54 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
55 | A = tmp.unsqueeze(0)
56 |
57 | if output_nc == 1: # RGB to gray
58 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
59 | B = tmp.unsqueeze(0)
60 |
61 | return {'A': A, 'B': B,
62 | 'A_paths': AB_path, 'B_paths': AB_path}
63 |
64 | def __len__(self):
65 | return len(self.AB_paths)
66 |
67 | def name(self):
68 | return 'AlignedDataset'
69 |
--------------------------------------------------------------------------------
/models/InstaColor/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 | class BaseDataLoader():
2 | def __init__(self):
3 | pass
4 |
5 | def initialize(self, opt):
6 | self.opt = opt
7 | pass
8 |
9 | def load_data(self):
10 | return None
11 |
--------------------------------------------------------------------------------
/models/InstaColor/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 |
5 |
6 | class BaseDataset(data.Dataset):
7 | def __init__(self):
8 | super(BaseDataset, self).__init__()
9 |
10 | def name(self):
11 | return 'BaseDataset'
12 |
13 | @staticmethod
14 | def modify_commandline_options(parser, is_train):
15 | return parser
16 |
17 | def initialize(self, opt):
18 | pass
19 |
20 | def __len__(self):
21 | return 0
22 |
23 |
24 | def get_transform(opt):
25 | transform_list = []
26 | if opt.resize_or_crop == 'resize_and_crop':
27 | osize = [opt.loadSize, opt.loadSize]
28 | transform_list.append(transforms.Resize(osize, Image.BICUBIC))
29 | transform_list.append(transforms.RandomCrop(opt.fineSize))
30 | elif opt.resize_or_crop == 'crop':
31 | transform_list.append(transforms.RandomCrop(opt.fineSize))
32 | elif opt.resize_or_crop == 'scale_width':
33 | transform_list.append(transforms.Lambda(
34 | lambda img: __scale_width(img, opt.fineSize)))
35 | elif opt.resize_or_crop == 'scale_width_and_crop':
36 | transform_list.append(transforms.Lambda(
37 | lambda img: __scale_width(img, opt.loadSize)))
38 | transform_list.append(transforms.RandomCrop(opt.fineSize))
39 | elif opt.resize_or_crop == 'none':
40 | transform_list.append(transforms.Lambda(
41 | lambda img: __adjust(img)))
42 | else:
43 | raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)
44 |
45 | if opt.isTrain and not opt.no_flip:
46 | transform_list.append(transforms.RandomHorizontalFlip())
47 |
48 | transform_list += [transforms.ToTensor(),
49 | transforms.Normalize((0.5, 0.5, 0.5),
50 | (0.5, 0.5, 0.5))]
51 | return transforms.Compose(transform_list)
52 |
53 | # just modify the width and height to be a multiple of 4
54 |
55 |
56 | def __adjust(img):
57 | ow, oh = img.size
58 |
59 | # the size needs to be a multiple of this number,
60 | # because going through generator network may change img size
61 | # and eventually cause size mismatch error
62 | mult = 4
63 | if ow % mult == 0 and oh % mult == 0:
64 | return img
65 | w = (ow - 1) // mult
66 | w = (w + 1) * mult
67 | h = (oh - 1) // mult
68 | h = (h + 1) * mult
69 |
70 | if ow != w or oh != h:
71 | __print_size_warning(ow, oh, w, h)
72 |
73 | return img.resize((w, h), Image.BICUBIC)
74 |
75 |
76 | def __scale_width(img, target_width):
77 | ow, oh = img.size
78 |
79 | # the size needs to be a multiple of this number,
80 | # because going through generator network may change img size
81 | # and eventually cause size mismatch error
82 | mult = 4
83 | assert target_width % mult == 0, "the target width needs to be a multiple of %d." % mult
84 | if (ow == target_width and oh % mult == 0):
85 | return img
86 | w = target_width
87 | target_height = int(target_width * oh / ow)
88 | m = (target_height - 1) // mult
89 | h = (m + 1) * mult
90 |
91 | if target_height != h:
92 | __print_size_warning(target_width, target_height, w, h)
93 |
94 | return img.resize((w, h), Image.BICUBIC)
95 |
96 |
97 | def __print_size_warning(ow, oh, w, h):
98 | if not hasattr(__print_size_warning, 'has_printed'):
99 | print("The image size needs to be a multiple of 4. "
100 | "The loaded image size was (%d, %d), so it was adjusted to "
101 | "(%d, %d). This adjustment will be done to all images "
102 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
103 | __print_size_warning.has_printed = True
104 |
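# Illustrative sketch (not part of the original file): the rounding used by __adjust and
# __scale_width, which bumps a dimension up to the next multiple of 4 so the generator's
# down/upsampling does not change the image size. The helper below is only for illustration.
def _round_up_to_multiple(x, mult=4):
    return ((x - 1) // mult + 1) * mult

assert _round_up_to_multiple(255) == 256   # 255 -> 256
assert _round_up_to_multiple(256) == 256   # already a multiple, unchanged
assert _round_up_to_multiple(181) == 184   # e.g. a 320x181 image becomes 320x184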
--------------------------------------------------------------------------------
/models/InstaColor/data/color_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from models.InstaColor.data.base_dataset import BaseDataset, get_transform
3 | from models.InstaColor.data.image_folder import make_dataset
4 | from PIL import Image
5 |
6 |
7 | class ColorDataset(BaseDataset):
8 | @staticmethod
9 | def modify_commandline_options(parser, is_train):
10 | return parser
11 |
12 | def initialize(self, opt):
13 | self.opt = opt
14 | self.root = opt.dataroot
15 | self.dir_A = os.path.join(opt.dataroot)
16 |
17 | self.A_paths = make_dataset(self.dir_A)
18 |
19 | self.A_paths = sorted(self.A_paths)
20 |
21 | self.transform = get_transform(opt)
22 |
23 | def __getitem__(self, index):
24 | A_path = self.A_paths[index]
25 | A_img = Image.open(A_path).convert('RGB')
26 | A = self.transform(A_img)
27 | if self.opt.which_direction == 'BtoA':
28 | input_nc = self.opt.output_nc
29 | else:
30 | input_nc = self.opt.input_nc
31 |
32 | # convert to Lab
33 | # rgb2lab(A_img)
34 |
35 | if input_nc == 1: # RGB to gray
36 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
37 | A = tmp.unsqueeze(0)
38 |
39 | return {'A': A, 'A_paths': A_path}
40 |
41 | def __len__(self):
42 | return len(self.A_paths)
43 |
44 | def name(self):
45 | return 'ColorImageDataset'
46 |
--------------------------------------------------------------------------------
/models/InstaColor/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ###############################################################################
7 |
8 | import torch.utils.data as data
9 |
10 | from PIL import Image
11 | import os
12 | import os.path
13 |
14 | IMG_EXTENSIONS = [
15 | '.jpg', '.JPG', '.jpeg', '.JPEG',
16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir):
25 | images = []
26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
27 |
28 | for root, _, fnames in sorted(os.walk(dir)):
29 | for fname in fnames:
30 | if is_image_file(fname):
31 | path = os.path.join(root, fname)
32 | images.append(path)
33 |
34 | return images
35 |
36 |
37 | def default_loader(path):
38 | return Image.open(path).convert('RGB')
39 |
40 |
41 | class ImageFolder(data.Dataset):
42 |
43 | def __init__(self, root, transform=None, return_paths=False,
44 | loader=default_loader):
45 | imgs = make_dataset(root)
46 | if len(imgs) == 0:
47 | raise(RuntimeError("Found 0 images in: " + root + "\n"
48 | "Supported image extensions are: " +
49 | ",".join(IMG_EXTENSIONS)))
50 |
51 | self.root = root
52 | self.imgs = imgs
53 | self.transform = transform
54 | self.return_paths = return_paths
55 | self.loader = loader
56 |
57 | def __getitem__(self, index):
58 | path = self.imgs[index]
59 | img = self.loader(path)
60 | if self.transform is not None:
61 | img = self.transform(img)
62 | if self.return_paths:
63 | return img, path
64 | else:
65 | return img
66 |
67 | def __len__(self):
68 | return len(self.imgs)
69 |
--------------------------------------------------------------------------------
/models/InstaColor/data/single_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from models.InstaColor.data.base_dataset import BaseDataset, get_transform
3 | from models.InstaColor.data.image_folder import make_dataset
4 | from PIL import Image
5 |
6 |
7 | class SingleDataset(BaseDataset):
8 | @staticmethod
9 | def modify_commandline_options(parser, is_train):
10 | return parser
11 |
12 | def initialize(self, opt):
13 | self.opt = opt
14 | self.root = opt.dataroot
15 | self.dir_A = os.path.join(opt.dataroot)
16 |
17 | self.A_paths = make_dataset(self.dir_A)
18 |
19 | self.A_paths = sorted(self.A_paths)
20 |
21 | self.transform = get_transform(opt)
22 |
23 | def __getitem__(self, index):
24 | A_path = self.A_paths[index]
25 | A_img = Image.open(A_path).convert('RGB')
26 | A = self.transform(A_img)
27 | if self.opt.which_direction == 'BtoA':
28 | input_nc = self.opt.output_nc
29 | else:
30 | input_nc = self.opt.input_nc
31 |
32 | if input_nc == 1: # RGB to gray
33 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
34 | A = tmp.unsqueeze(0)
35 |
36 | return {'A': A, 'A_paths': A_path}
37 |
38 | def __len__(self):
39 | return len(self.A_paths)
40 |
41 | def name(self):
42 | return 'SingleImageDataset'
43 |
--------------------------------------------------------------------------------
/models/InstaColor/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from models.InstaColor.models.base_model import BaseModel
3 |
4 | def find_model_using_name(model_name):
5 | # Given the option --model [modelname],
6 | # the file "models/modelname_model.py"
7 | # will be imported.
8 | model_filename = "models.InstaColor.models." + model_name + "_model"
9 | modellib = importlib.import_module(model_filename)
10 |
11 | # In the file, the class called ModelNameModel() will
12 | # be instantiated. It has to be a subclass of BaseModel,
13 | # and it is case-insensitive.
14 | model = None
15 | target_model_name = model_name.replace('_', '') + 'model'
16 | for name, cls in modellib.__dict__.items():
17 | if name.lower() == target_model_name.lower() \
18 | and issubclass(cls, BaseModel):
19 | model = cls
20 |
21 | if model is None:
22 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
23 | exit(0)
24 |
25 | return model
26 |
27 | def get_option_setter(model_name):
28 | model_class = find_model_using_name(model_name)
29 | return model_class.modify_commandline_options
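# Illustrative sketch (not part of the original file): the same import-then-scan lookup
# as find_model_using_name above, demonstrated against the standard library so it runs
# outside this package. "json" / "jsondecoder" stand in for "<name>_model" / "<Name>Model".
def _find_class_by_name(module_name, class_name, base=object):
    module = importlib.import_module(module_name)
    for name, cls in vars(module).items():
        if name.lower() == class_name.lower() and isinstance(cls, type) and issubclass(cls, base):
            return cls
    return None

if __name__ == "__main__":
    print(_find_class_by_name("json", "jsondecoder"))   # <class 'json.decoder.JSONDecoder'>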
--------------------------------------------------------------------------------
/models/InstaColor/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from . import networks
5 |
6 |
7 | class BaseModel():
8 |
9 | # modify parser to add command line options,
10 | # and also change the default values if needed
11 | @staticmethod
12 | def modify_commandline_options(parser, is_train):
13 | return parser
14 |
15 | def name(self):
16 | return 'BaseModel'
17 |
18 | def initialize(self, opt):
19 | self.opt = opt
20 | self.gpu_ids = opt.gpu_ids
21 | self.isTrain = opt.isTrain
22 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
23 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
24 | if opt.resize_or_crop != 'scale_width':
25 | torch.backends.cudnn.benchmark = True
26 | self.loss_names = []
27 | self.model_names = []
28 | self.visual_names = []
29 | self.image_paths = []
30 |
31 | def set_input(self, input):
32 | self.input = input
33 |
34 | def forward(self):
35 | pass
36 |
37 | # load and print networks; create schedulers
38 | def setup(self, opt, parser=None):
39 | if self.isTrain:
40 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
41 |
42 | if not self.isTrain or opt.load_model:
43 | self.load_networks(opt.which_epoch)
44 |
45 | # make models eval mode during test time
46 | def eval(self):
47 | for name in self.model_names:
48 | if isinstance(name, str):
49 | net = getattr(self, 'net' + name)
50 | net.eval()
51 |
52 | # used in test time, wrapping `forward` in no_grad() so we don't save
53 | # intermediate steps for backprop
54 | def test(self, compute_losses=False):
55 | with torch.no_grad():
56 | self.forward()
57 | if(compute_losses):
58 | self.compute_losses_G()
59 |
60 | # get image paths
61 | def get_image_paths(self):
62 | return self.image_paths
63 |
64 | def optimize_parameters(self):
65 | pass
66 |
67 | # update learning rate (called once every epoch)
68 | def update_learning_rate(self):
69 | for scheduler in self.schedulers:
70 | scheduler.step()
71 | lr = self.optimizers[0].param_groups[0]['lr']
72 | # print('learning rate = %.7f' % lr)
73 |
74 | # return visualization images. train.py will display these images, and save the images to an HTML file
75 | def get_current_visuals(self):
76 | visual_ret = OrderedDict()
77 | for name in self.visual_names:
78 | if isinstance(name, str):
79 | visual_ret[name] = getattr(self, name)
80 | return visual_ret
81 |
82 | # return training losses/errors. train.py will print out these errors as debugging information
83 | def get_current_losses(self):
84 | errors_ret = OrderedDict()
85 | for name in self.loss_names:
86 | if isinstance(name, str):
87 | # float(...) works for both scalar tensor and float number
88 | errors_ret[name] = float(getattr(self, 'loss_' + name))
89 | return errors_ret
90 |
91 | # save models to the disk
92 | def save_networks(self, which_epoch):
93 | for name in self.model_names:
94 | if isinstance(name, str):
95 | save_filename = '%s_net_%s.pth' % (which_epoch, name)
96 | save_path = os.path.join(self.save_dir, save_filename)
97 | net = getattr(self, 'net' + name)
98 |
99 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
100 | torch.save(net.module.cpu().state_dict(), save_path)
101 | net.cuda(self.gpu_ids[0])
102 | else:
103 | torch.save(net.cpu().state_dict(), save_path)
104 |
105 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
106 | key = keys[i]
107 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
108 | if module.__class__.__name__.startswith('InstanceNorm') and \
109 | (key == 'running_mean' or key == 'running_var'):
110 | if getattr(module, key) is None:
111 | state_dict.pop('.'.join(keys))
112 | if module.__class__.__name__.startswith('InstanceNorm') and \
113 | (key == 'num_batches_tracked'):
114 | state_dict.pop('.'.join(keys))
115 | else:
116 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
117 |
118 | # load models from the disk
119 | def load_networks(self, which_epoch):
120 | for name in self.model_names:
121 | if isinstance(name, str):
122 | load_filename = '%s_net_%s.pth' % (which_epoch, name)
123 | load_path = os.path.join(self.save_dir, load_filename)
124 | if os.path.isfile(load_path) is False:
125 | continue
126 | net = getattr(self, 'net' + name)
127 | if isinstance(net, torch.nn.DataParallel):
128 | net = net.module
129 | print('loading the model from %s' % load_path)
130 | # if you are using PyTorch newer than 0.4 (e.g., built from
131 | # GitHub source), you can remove str() on self.device
132 | state_dict = torch.load(load_path, map_location=str(self.device))
133 | if hasattr(state_dict, '_metadata'):
134 | del state_dict._metadata
135 |
136 | # patch InstanceNorm checkpoints prior to 0.4
137 | # for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
138 | # self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
139 | net.load_state_dict(state_dict, strict=False)
140 |
141 | # print network information
142 | def print_networks(self, verbose):
143 | print('---------- Networks initialized -------------')
144 | for name in self.model_names:
145 | if isinstance(name, str):
146 | net = getattr(self, 'net' + name)
147 | num_params = 0
148 | for param in net.parameters():
149 | num_params += param.numel()
150 | if verbose:
151 | print(net)
152 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
153 | print('-----------------------------------------------')
154 |
155 | # set requires_grad=False to avoid computation
156 | def set_requires_grad(self, nets, requires_grad=False):
157 | if not isinstance(nets, list):
158 | nets = [nets]
159 | for net in nets:
160 | if net is not None:
161 | for param in net.parameters():
162 | param.requires_grad = requires_grad
163 |
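# Illustrative sketch (not part of the original file): the checkpoint convention used by
# save_networks / load_networks above. The bare module's state_dict is saved (unwrapping
# nn.DataParallel), so a checkpoint loads the same way with or without multi-GPU wrapping.
# The filename below is a throwaway placeholder.
if __name__ == "__main__":
    net = torch.nn.Linear(4, 2)
    wrapped = torch.nn.DataParallel(net) if torch.cuda.is_available() else net
    to_save = wrapped.module if isinstance(wrapped, torch.nn.DataParallel) else wrapped
    torch.save(to_save.cpu().state_dict(), "latest_net_Example.pth")

    fresh = torch.nn.Linear(4, 2)
    state = torch.load("latest_net_Example.pth", map_location="cpu")
    fresh.load_state_dict(state, strict=False)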
--------------------------------------------------------------------------------
/models/InstaColor/models/fusion_model.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append("..")
3 |
4 | import torch
5 | from utils import color_format
6 | from models.InstaColor.utils import util
7 | from .base_model import BaseModel
8 | from . import networks
9 | import numpy as np
10 | from skimage import io
11 | from skimage import img_as_ubyte
12 |
13 |
14 |
15 | class FusionModel(BaseModel):
16 | def name(self):
17 | return 'FusionModel'
18 |
19 | @staticmethod
20 | def modify_commandline_options(parser, is_train=True):
21 | return parser
22 |
23 | def initialize(self, opt):
24 | BaseModel.initialize(self, opt)
25 | self.model_names = ['G', 'GF']
26 |
27 | # load/define networks
28 | num_in = opt.input_nc + opt.output_nc + 1
29 |
30 | self.netG = networks.define_G(num_in, opt.output_nc, opt.ngf,
31 | 'instance', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
32 | use_tanh=True, classification=False)
33 | self.netG.eval()
34 |
35 | self.netGF = networks.define_G(num_in, opt.output_nc, opt.ngf,
36 | 'fusion', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
37 | use_tanh=True, classification=False)
38 | self.netGF.eval()
39 |
40 | self.netGComp = networks.define_G(num_in, opt.output_nc, opt.ngf,
41 | 'siggraph', opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
42 | use_tanh=True, classification=opt.classification)
43 | self.netGComp.eval()
44 |
45 |
46 | def set_input(self, input):
47 | AtoB = self.opt.which_direction == 'AtoB'
48 | self.real_A = input['A' if AtoB else 'B'].to(self.device)
49 | self.real_B = input['B' if AtoB else 'A'].to(self.device)
50 | self.hint_B = input['hint_B'].to(self.device)
51 |
52 | self.mask_B = input['mask_B'].to(self.device)
53 | self.mask_B_nc = self.mask_B + self.opt.mask_cent
54 |
55 | self.real_B_enc = util.encode_ab_ind(self.real_B[:, :, ::4, ::4], self.opt)
56 |
57 | def set_fusion_input(self, input, box_info):
58 | AtoB = self.opt.which_direction == 'AtoB'
59 | self.full_real_A = input['A' if AtoB else 'B'].to(self.device)
60 | self.full_real_B = input['B' if AtoB else 'A'].to(self.device)
61 |
62 | self.full_hint_B = input['hint_B'].to(self.device)
63 | self.full_mask_B = input['mask_B'].to(self.device)
64 |
65 | self.full_mask_B_nc = self.full_mask_B + self.opt.mask_cent
66 | self.full_real_B_enc = util.encode_ab_ind(self.full_real_B[:, :, ::4, ::4], self.opt)
67 | self.box_info_list = box_info
68 |
69 | def set_forward_without_box(self, input):
70 | AtoB = self.opt.which_direction == 'AtoB'
71 | self.full_real_A = input['A' if AtoB else 'B'].to(self.device)
72 | self.full_real_B = input['B' if AtoB else 'A'].to(self.device)
73 | # self.image_paths = input['A_paths' if AtoB else 'B_paths']
74 | self.full_hint_B = input['hint_B'].to(self.device)
75 | self.full_mask_B = input['mask_B'].to(self.device)
76 | self.full_mask_B_nc = self.full_mask_B + self.opt.mask_cent
77 | self.full_real_B_enc = util.encode_ab_ind(self.full_real_B[:, :, ::4, ::4], self.opt)
78 |
79 | (_, self.comp_B_reg) = self.netGComp(self.full_real_A, self.full_hint_B, self.full_mask_B)
80 | self.fake_B_reg = self.comp_B_reg
81 |
82 | def forward(self):
83 | (_, feature_map) = self.netG(self.real_A, self.hint_B, self.mask_B)
84 | self.fake_B_reg = self.netGF(self.full_real_A, self.full_hint_B, self.full_mask_B, feature_map, self.box_info_list)
85 |
86 | def save_current_imgs(self, path):
87 | out_img = torch.clamp(color_format.lab2rgb(torch.cat((self.full_real_A.type(torch.cuda.FloatTensor), self.fake_B_reg.type(torch.cuda.FloatTensor)), dim=1), self.opt), 0.0, 1.0)
88 | out_img = np.transpose(out_img.cpu().data.numpy()[0], (1, 2, 0))
89 | io.imsave(path, img_as_ubyte(out_img))
90 |
91 | def setup_to_test(self, fusion_weight_path):
92 | GF_path = 'checkpoints/{0}/latest_net_GF.pth'.format(fusion_weight_path)
93 | print('load Fusion model from %s' % GF_path)
94 | GF_state_dict = torch.load(GF_path)
95 |
96 | # G_path = 'checkpoints/coco_finetuned_mask_256/latest_net_G.pth' # fine tuned on cocostuff
97 | G_path = 'checkpoints/{0}/latest_net_G.pth'.format(fusion_weight_path)
98 | G_state_dict = torch.load(G_path)
99 |
100 | # GComp_path = 'checkpoints/siggraph_retrained/latest_net_G.pth' # original net
101 | # GComp_path = 'checkpoints/coco_finetuned_mask_256/latest_net_GComp.pth' # fine tuned on cocostuff
102 | GComp_path = 'checkpoints/{0}/latest_net_GComp.pth'.format(fusion_weight_path)
103 | GComp_state_dict = torch.load(GComp_path)
104 |
105 | self.netGF.load_state_dict(GF_state_dict, strict=False)
106 | self.netG.module.load_state_dict(G_state_dict, strict=False)
107 | self.netGComp.module.load_state_dict(GComp_state_dict, strict=False)
108 | self.netGF.eval()
109 | self.netG.eval()
110 | self.netGComp.eval()
--------------------------------------------------------------------------------
/models/InstaColor/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from models.InstaColor.utils import util
4 | import torch
5 | import models.InstaColor.models as models
6 | import models.InstaColor.data as data
7 |
8 |
9 | class BaseOptions():
10 | def __init__(self):
11 | self.initialized = False
12 |
13 | def initialize(self, parser):
14 | parser.add_argument('--batch_size', type=int, default=25, help='input batch size')
15 | parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size')
16 | parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
17 | parser.add_argument('--input_nc', type=int, default=1, help='# of input image channels')
18 | parser.add_argument('--output_nc', type=int, default=2, help='# of output image channels')
19 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
20 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
21 | parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
22 | parser.add_argument('--which_model_netG', type=str, default='siggraph', help='selects model to use for netG')
23 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
24 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
25 | parser.add_argument('--dataset_mode', type=str, default='aligned', help='chooses how datasets are loaded. [unaligned | aligned | single]')
26 | parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
27 | parser.add_argument('--nThreads', default=4, type=int, help='# threads for loading data')
28 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
29 | parser.add_argument('--norm', type=str, default='batch', help='instance normalization or batch normalization')
30 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
31 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
32 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
33 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
34 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
35 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
36 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
37 | help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
38 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
39 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
40 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
41 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
42 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}')
43 | parser.add_argument('--ab_norm', type=float, default=110., help='colorization normalization factor')
44 | parser.add_argument('--ab_max', type=float, default=110., help='maximum ab value')
45 | parser.add_argument('--ab_quant', type=float, default=10., help='quantization factor')
46 | parser.add_argument('--l_norm', type=float, default=100., help='colorization normalization factor')
47 | parser.add_argument('--l_cent', type=float, default=50., help='colorization centering factor')
48 | parser.add_argument('--mask_cent', type=float, default=.5, help='mask centering factor')
49 | parser.add_argument('--sample_p', type=float, default=1.0, help='sampling geometric distribution, 1.0 means no hints')
50 | parser.add_argument('--sample_Ps', type=int, nargs='+', default=[1, 2, 3, 4, 5, 6, 7, 8, 9, ], help='patch sizes')
51 |
52 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
53 | parser.add_argument('--classification', action='store_true', help='backprop trunk using classification, otherwise use regression')
54 | parser.add_argument('--phase', type=str, default='val', help='train_small, train, val, test, etc')
55 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
56 | parser.add_argument('--how_many', type=int, default=5, help='how many test images to run')
57 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
58 |
59 | parser.add_argument('--load_model', action='store_true', help='load the latest model')
60 | parser.add_argument('--half', action='store_true', help='half precision model')
61 |
62 | self.initialized = True
63 | return parser
64 |
65 | def gather_options(self):
66 | # initialize parser with basic options
67 | if not self.initialized:
68 | parser = argparse.ArgumentParser(
69 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
70 | parser = self.initialize(parser)
71 |
72 | # get the basic options
73 | opt, _ = parser.parse_known_args()
74 |
75 | # modify model-related parser options
76 | model_name = opt.model
77 | model_option_setter = models.get_option_setter(model_name)
78 | parser = model_option_setter(parser, self.isTrain)
79 | opt, _ = parser.parse_known_args() # parse again with the new defaults
80 |
81 | # modify dataset-related parser options
82 | dataset_name = opt.dataset_mode
83 | dataset_option_setter = data.get_option_setter(dataset_name)
84 | parser = dataset_option_setter(parser, self.isTrain)
85 |
86 | self.parser = parser
87 |
88 | return parser.parse_args()
89 |
90 | def print_options(self, opt):
91 | message = ''
92 | message += '----------------- Options ---------------\n'
93 | for k, v in sorted(vars(opt).items()):
94 | comment = ''
95 | default = self.parser.get_default(k)
96 | if v != default:
97 | comment = '\t[default: %s]' % str(default)
98 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
99 | message += '----------------- End -------------------'
100 | print(message)
101 |
102 | # save to the disk
103 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
104 | util.mkdirs(expr_dir)
105 | file_name = os.path.join(expr_dir, 'opt.txt')
106 | with open(file_name, 'wt') as opt_file:
107 | opt_file.write(message)
108 | opt_file.write('\n')
109 |
110 | def parse(self):
111 |
112 | opt = self.gather_options()
113 | opt.isTrain = self.isTrain # train or test
114 |
115 | # process opt.suffix
116 | if opt.suffix:
117 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
118 | opt.name = opt.name + suffix
119 |
120 | # self.print_options(opt)
121 |
122 | # set gpu ids
123 | str_ids = opt.gpu_ids.split(',')
124 | opt.gpu_ids = []
125 | for str_id in str_ids:
126 | id = int(str_id)
127 | if id >= 0:
128 | opt.gpu_ids.append(id)
129 | if len(opt.gpu_ids) > 0:
130 | torch.cuda.set_device(opt.gpu_ids[0])
131 | opt.A = 2 * opt.ab_max / opt.ab_quant + 1
132 | opt.B = opt.A
133 |
134 | self.opt = opt
135 | return self.opt
136 |
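# Illustrative sketch (not part of the original file): the two-stage parsing used by
# gather_options above. parse_known_args() runs first with only the basic options, the
# chosen model/dataset then registers extra arguments, and parse_args() runs again.
if __name__ == "__main__":
    demo = argparse.ArgumentParser()
    demo.add_argument('--model', type=str, default='fusion')
    known, _ = demo.parse_known_args()           # first pass: which model was requested?

    if known.model == 'fusion':                  # stand-in for get_option_setter(...)
        demo.add_argument('--lambda_GAN', type=float, default=0.)

    print(demo.parse_args())                     # second pass with the new defaults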
--------------------------------------------------------------------------------
/models/InstaColor/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TestOptions(BaseOptions):
4 | def initialize(self, parser):
5 | BaseOptions.initialize(self, parser)
6 | parser.add_argument('--test_img_dir', type=str, default='example', help='testing images folder')
7 | parser.add_argument('--results_img_dir', type=str, default='test/results', help='save the results image folder')
8 | parser.add_argument('--name', type=str, default='test_fusion', help='name of the experiment. It decides where to store samples and models')
9 | parser.add_argument('--model', type=str, default='fusion',
10 | help='chooses which model to use. cycle_gan, pix2pix, test')
11 | parser.add_argument('--display_freq', type=int, default=2000, help='frequency of showing training results on screen')
12 | parser.add_argument('--display_ncols', type=int, default=5, help='if positive, display all images in a single visdom web panel with a certain number of images per row.')
13 | parser.add_argument('--update_html_freq', type=int, default=10000, help='frequency of saving training results to html')
14 | parser.add_argument('--print_freq', type=int, default=2000, help='frequency of showing training results on console')
15 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
16 | parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
17 | parser.add_argument('--epoch_count', type=int, default=0, help='the starting epoch count; models are saved at <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
18 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
19 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
20 | parser.add_argument('--beta1', type=float, default=0.9, help='momentum term of adam')
21 | parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
22 | parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
23 | parser.add_argument('--lambda_GAN', type=float, default=0., help='weight for GAN loss')
24 | parser.add_argument('--lambda_A', type=float, default=1., help='weight for cycle loss (A -> B -> A)')
25 | parser.add_argument('--lambda_B', type=float, default=1., help='weight for cycle loss (B -> A -> B)')
26 | parser.add_argument('--lambda_identity', type=float, default=0.5,
27 | help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss.'
28 | 'For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
29 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
30 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
31 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
32 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
33 | parser.add_argument('--avg_loss_alpha', type=float, default=.986, help='exponential averaging weight for displaying loss')
34 | self.isTrain = False
35 | return parser
--------------------------------------------------------------------------------
/models/InstaColor/utils/datasets.py:
--------------------------------------------------------------------------------
1 | from os import listdir
2 | from os.path import isfile, join
3 | from random import sample
4 |
5 | import numpy as np
6 | import torch
7 | import torch.utils.data as Data
8 | import torchvision.transforms as transforms
9 |
10 | from models.InstaColor.utils.image_utils import *
11 |
12 |
13 | class Fusion_Testing_Dataset(Data.Dataset):
14 | def __init__(self, opt, filenames, box_num=8):
15 | self.PRED_BBOX_DIR = '{0}_bbox'.format(opt.test_img_dir)
16 | self.IMAGE_DIR = opt.test_img_dir
17 | self.IMAGE_ID_LIST = filenames
18 |
19 | self.transforms = transforms.Compose([transforms.Resize((opt.fineSize, opt.fineSize), interpolation=2),
20 | transforms.ToTensor()])
21 | self.final_size = opt.fineSize
22 | self.box_num = box_num
23 |
24 | def __getitem__(self, index):
25 | pred_info_path = join(self.PRED_BBOX_DIR, self.IMAGE_ID_LIST[index].split('.')[0] + '.npz')
26 | output_image_path = join(self.IMAGE_DIR, self.IMAGE_ID_LIST[index])
27 | pred_bbox = gen_maskrcnn_bbox_fromPred(pred_info_path, self.box_num)
28 |
29 | img_list = []
30 | pil_img = read_to_pil(output_image_path)
31 | img_list.append(self.transforms(pil_img))
32 |
33 | cropped_img_list = []
34 | index_list = range(len(pred_bbox))
35 | box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros((4, len(index_list), 6))
36 | for i in index_list:
37 | startx, starty, endx, endy = pred_bbox[i]
38 | box_info[i] = np.array(get_box_info(pred_bbox[i], pil_img.size, self.final_size))
39 | box_info_2x[i] = np.array(get_box_info(pred_bbox[i], pil_img.size, self.final_size // 2))
40 | box_info_4x[i] = np.array(get_box_info(pred_bbox[i], pil_img.size, self.final_size // 4))
41 | box_info_8x[i] = np.array(get_box_info(pred_bbox[i], pil_img.size, self.final_size // 8))
42 | cropped_img = self.transforms(pil_img.crop((startx, starty, endx, endy)))
43 | cropped_img_list.append(cropped_img)
44 | output = {}
45 | output['full_img'] = torch.stack(img_list)
46 | output['file_id'] = self.IMAGE_ID_LIST[index].split('.')[0]
47 | if len(pred_bbox) > 0:
48 | output['cropped_img'] = torch.stack(cropped_img_list)
49 | output['box_info'] = torch.from_numpy(box_info).type(torch.long)
50 | output['box_info_2x'] = torch.from_numpy(box_info_2x).type(torch.long)
51 | output['box_info_4x'] = torch.from_numpy(box_info_4x).type(torch.long)
52 | output['box_info_8x'] = torch.from_numpy(box_info_8x).type(torch.long)
53 | output['empty_box'] = False
54 | else:
55 | output['empty_box'] = True
56 | return output
57 |
58 | def __len__(self):
59 | return len(self.IMAGE_ID_LIST)
60 |
61 |
62 | class Training_Full_Dataset(Data.Dataset):
63 | '''
64 | Training on COCOStuff dataset. [train2017.zip]
65 |
66 | Download the training set from https://github.com/nightrome/cocostuff
67 | '''
68 | def __init__(self, opt):
69 | self.IMAGE_DIR = opt.train_img_dir
70 | self.transforms = transforms.Compose([transforms.Resize((opt.fineSize, opt.fineSize), interpolation=2),
71 | transforms.ToTensor()])
72 | self.IMAGE_ID_LIST = [f for f in listdir(self.IMAGE_DIR) if isfile(join(self.IMAGE_DIR, f))]
73 |
74 | def __getitem__(self, index):
75 | output_image_path = join(self.IMAGE_DIR, self.IMAGE_ID_LIST[index])
76 | rgb_img, gray_img = gen_gray_color_pil(output_image_path)
77 | output = {}
78 | output['rgb_img'] = self.transforms(rgb_img)
79 | output['gray_img'] = self.transforms(gray_img)
80 | return output
81 |
82 | def __len__(self):
83 | return len(self.IMAGE_ID_LIST)
84 |
85 |
86 | class Training_Instance_Dataset(Data.Dataset):
87 | '''
88 | Training on COCOStuff dataset. [train2017.zip]
89 |
90 | Download the training set from https://github.com/nightrome/cocostuff
91 |
92 | Make sure you've predicted all the images' bounding boxes using inference_bbox.py
93 |
94 | It is better to filter out images that don't have any bounding box.
95 | '''
96 | def __init__(self, opt):
97 | self.PRED_BBOX_DIR = '{0}_bbox'.format(opt.train_img_dir)
98 | self.IMAGE_DIR = opt.train_img_dir
99 | self.IMAGE_ID_LIST = [f for f in listdir(self.IMAGE_DIR) if isfile(join(self.IMAGE_DIR, f))]
100 | self.transforms = transforms.Compose([
101 | transforms.Resize((opt.fineSize, opt.fineSize), interpolation=2),
102 | transforms.ToTensor()
103 | ])
104 |
105 | def __getitem__(self, index):
106 | pred_info_path = join(self.PRED_BBOX_DIR, self.IMAGE_ID_LIST[index].split('.')[0] + '.npz')
107 | output_image_path = join(self.IMAGE_DIR, self.IMAGE_ID_LIST[index])
108 | pred_bbox = gen_maskrcnn_bbox_fromPred(pred_info_path)
109 |
110 | rgb_img, gray_img = gen_gray_color_pil(output_image_path)
111 |
112 | index_list = range(len(pred_bbox))
113 | index_list = sample(index_list, 1)
114 | startx, starty, endx, endy = pred_bbox[index_list[0]]
115 | output = {}
116 | output['rgb_img'] = self.transforms(rgb_img.crop((startx, starty, endx, endy)))
117 | output['gray_img'] = self.transforms(gray_img.crop((startx, starty, endx, endy)))
118 | return output
119 |
120 | def __len__(self):
121 | return len(self.IMAGE_ID_LIST)
122 |
123 |
124 | class Training_Fusion_Dataset(Data.Dataset):
125 | '''
126 | Training on COCOStuff dataset. [train2017.zip]
127 |
128 | Download the training set from https://github.com/nightrome/cocostuff
129 |
130 | Make sure you've predicted all the images' bounding boxes using inference_bbox.py
131 |
132 |     It is better to filter out images that do not contain any bounding boxes.
133 | '''
134 | def __init__(self, opt, box_num=8):
135 | self.PRED_BBOX_DIR = '{0}_bbox'.format(opt.train_img_dir)
136 | self.IMAGE_DIR = opt.train_img_dir
137 | self.IMAGE_ID_LIST = [f for f in listdir(self.IMAGE_DIR) if isfile(join(self.IMAGE_DIR, f))]
138 |
139 | self.transforms = transforms.Compose([transforms.Resize((opt.fineSize, opt.fineSize), interpolation=2),
140 | transforms.ToTensor()])
141 | self.final_size = opt.fineSize
142 | self.box_num = box_num
143 |
144 | def __getitem__(self, index):
145 | pred_info_path = join(self.PRED_BBOX_DIR, self.IMAGE_ID_LIST[index].split('.')[0] + '.npz')
146 | output_image_path = join(self.IMAGE_DIR, self.IMAGE_ID_LIST[index])
147 | pred_bbox = gen_maskrcnn_bbox_fromPred(pred_info_path, self.box_num)
148 |
149 | full_rgb_list = []
150 | full_gray_list = []
151 | rgb_img, gray_image = gen_gray_color_pil(output_image_path)
152 | full_rgb_list.append(self.transforms(rgb_img))
153 | full_gray_list.append(self.transforms(gray_image))
154 |
155 | cropped_rgb_list = []
156 | cropped_gray_list = []
157 | index_list = range(len(pred_bbox))
158 | box_info, box_info_2x, box_info_4x, box_info_8x = np.zeros((4, len(index_list), 6))
159 | for i in range(len(index_list)):
160 | startx, starty, endx, endy = pred_bbox[i]
161 | box_info[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size))
162 | box_info_2x[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size // 2))
163 | box_info_4x[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size // 4))
164 | box_info_8x[i] = np.array(get_box_info(pred_bbox[i], rgb_img.size, self.final_size // 8))
165 | cropped_rgb_list.append(self.transforms(rgb_img.crop((startx, starty, endx, endy))))
166 | cropped_gray_list.append(self.transforms(gray_image.crop((startx, starty, endx, endy))))
167 | output = {}
168 | output['cropped_rgb'] = torch.stack(cropped_rgb_list)
169 | output['cropped_gray'] = torch.stack(cropped_gray_list)
170 | output['full_rgb'] = torch.stack(full_rgb_list)
171 | output['full_gray'] = torch.stack(full_gray_list)
172 | output['box_info'] = torch.from_numpy(box_info).type(torch.long)
173 | output['box_info_2x'] = torch.from_numpy(box_info_2x).type(torch.long)
174 | output['box_info_4x'] = torch.from_numpy(box_info_4x).type(torch.long)
175 | output['box_info_8x'] = torch.from_numpy(box_info_8x).type(torch.long)
176 | output['file_id'] = self.IMAGE_ID_LIST[index]
177 | return output
178 |
179 | def __len__(self):
180 | return len(self.IMAGE_ID_LIST)
--------------------------------------------------------------------------------
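A minimal usage sketch for the fusion dataset defined above, assuming `inference_bbox.py` has already produced the `{train_img_dir}_bbox` folder of `.npz` predictions and that the repo root is on `sys.path`; the `opt` namespace and paths below are illustrative stand-ins for the real parsed options:

from argparse import Namespace
import torch.utils.data as Data
from models.InstaColor.utils.datasets import Training_Fusion_Dataset

# Hypothetical option values; only train_img_dir and fineSize are read by these datasets.
opt = Namespace(train_img_dir="datasets/cocostuff/train2017", fineSize=256)

dataset = Training_Fusion_Dataset(opt, box_num=8)
# batch_size=1 because the number of detected boxes K varies from image to image
loader = Data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)

for batch in loader:
    # full_rgb: 1x1x3xHxW, cropped_rgb: 1xKx3xHxW, box_info: 1xKx6
    print(batch['file_id'], batch['cropped_rgb'].shape, batch['box_info'].shape)
    break

Batching beyond size 1 would need a custom collate function, since K differs between images.
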
/models/InstaColor/utils/download.py:
--------------------------------------------------------------------------------
1 | #taken from this StackOverflow answer: https://stackoverflow.com/a/39225039
2 | import requests
3 | from os.path import join, isdir
4 | import os
5 | 
6 | def download_file_from_google_drive(id, destination):
7 |     URL = "https://docs.google.com/uc?export=download"
8 | 
9 |     session = requests.Session()
10 | 
11 |     response = session.get(URL, params = { 'id' : id }, stream = True)
12 |     token = get_confirm_token(response)
13 | 
14 |     if token:
15 |         params = { 'id' : id, 'confirm' : token }
16 |         response = session.get(URL, params = params, stream = True)
17 | 
18 |     save_response_content(response, destination)
19 | 
20 | 
21 | def get_confirm_token(response):
22 |     for key, value in response.cookies.items():
23 |         if key.startswith('download_warning'):
24 |             return value
25 | 
26 |     return None
27 | 
28 | 
29 | # chunked download helper, also from the StackOverflow answer above
30 | def save_response_content(response, destination):
31 |     CHUNK_SIZE = 32768
32 |     with open(destination, "wb") as f:
33 |         for chunk in response.iter_content(CHUNK_SIZE):
34 |             if chunk:  # filter out keep-alive new chunks
35 |                 f.write(chunk)
36 | 
37 | 
38 | def download_coco_dataset(dataset_dir: str):
39 |     print('download cocostuff training dataset')
40 |     url = "http://images.cocodataset.org/zips/train2017.zip"
41 |     response = requests.get(url, stream = True)
42 |     if isdir(join(dataset_dir, "cocostuff")) is False:
43 |         os.makedirs(join(dataset_dir, "cocostuff"))
44 |     save_response_content(response, join(dataset_dir, "cocostuff", "train.zip"))
45 | 
--------------------------------------------------------------------------------
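A hypothetical invocation of the helper above; the target directory is arbitrary, and the COCO train2017 archive is a very large download:

from models.InstaColor.utils.download import download_coco_dataset

download_coco_dataset("datasets")   # writes datasets/cocostuff/train.zip
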
/models/InstaColor/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from skimage import color
4 |
5 |
6 | def gen_gray_color_pil(color_img_path):
7 | '''
8 | return: RGB and GRAY pillow image object
9 | '''
10 | rgb_img = Image.open(color_img_path)
11 | if len(np.asarray(rgb_img).shape) == 2:
12 | rgb_img = np.stack([np.asarray(rgb_img), np.asarray(rgb_img), np.asarray(rgb_img)], 2)
13 | rgb_img = Image.fromarray(rgb_img)
14 | gray_img = np.round(color.rgb2gray(np.asarray(rgb_img)) * 255.0).astype(np.uint8)
15 | gray_img = np.stack([gray_img, gray_img, gray_img], -1)
16 | gray_img = Image.fromarray(gray_img)
17 | return rgb_img, gray_img
18 |
19 | def read_to_pil(img_path):
20 | '''
21 | return: pillow image object HxWx3
22 | '''
23 | out_img = Image.open(img_path)
24 | if len(np.asarray(out_img).shape) == 2:
25 | out_img = np.stack([np.asarray(out_img), np.asarray(out_img), np.asarray(out_img)], 2)
26 | out_img = Image.fromarray(out_img)
27 | return out_img
28 |
29 | def gen_maskrcnn_bbox_fromPred(pred_data_path, box_num_upbound=-1):
30 | '''
31 | ## Arguments:
32 |     - pred_data_path: path to the saved Detectron2 prediction results (.npz file)
33 |     - box_num_upbound: maximum number of bounding boxes to keep; -1 (the default) keeps all instances
34 | '''
35 | pred_data = np.load(pred_data_path)
36 | assert 'bbox' in pred_data
37 | assert 'scores' in pred_data
38 | pred_bbox = pred_data['bbox'].astype(np.int32)
39 | if box_num_upbound > 0 and pred_bbox.shape[0] > box_num_upbound:
40 | pred_scores = pred_data['scores']
41 | index_mask = np.argsort(pred_scores, axis=0)[pred_scores.shape[0] - box_num_upbound: pred_scores.shape[0]]
42 | pred_bbox = pred_bbox[index_mask]
43 | # pred_scores = pred_data['scores']
44 | # index_mask = pred_scores > 0.9
45 | # pred_bbox = pred_bbox[index_mask].astype(np.int32)
46 | return pred_bbox
47 |
48 | def get_box_info(pred_bbox, original_shape, final_size):
49 | assert len(pred_bbox) == 4
50 | resize_startx = int(pred_bbox[0] / original_shape[0] * final_size)
51 | resize_starty = int(pred_bbox[1] / original_shape[1] * final_size)
52 | resize_endx = int(pred_bbox[2] / original_shape[0] * final_size)
53 | resize_endy = int(pred_bbox[3] / original_shape[1] * final_size)
54 | rh = resize_endx - resize_startx
55 | rw = resize_endy - resize_starty
56 | if rh < 1:
57 | if final_size - resize_endx > 1:
58 | resize_endx += 1
59 | else:
60 | resize_startx -= 1
61 | rh = 1
62 | if rw < 1:
63 | if final_size - resize_endy > 1:
64 | resize_endy += 1
65 | else:
66 | resize_starty -= 1
67 | rw = 1
68 | L_pad = resize_startx
69 | R_pad = final_size - resize_endx
70 | T_pad = resize_starty
71 | B_pad = final_size - resize_endy
72 | return [L_pad, R_pad, T_pad, B_pad, rh, rw]
--------------------------------------------------------------------------------
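A small worked example of `get_box_info` above; the image size and box are made up. A box in original pixel coordinates is rescaled onto the `final_size` grid and reported as left/right/top/bottom padding plus the resized box extent:

from models.InstaColor.utils.image_utils import get_box_info

image_size = (640, 480)        # PIL's (width, height)
box = [64, 48, 320, 240]       # (startx, starty, endx, endy) in original pixels
print(get_box_info(box, image_size, 256))
# -> [25, 128, 25, 128, 103, 103], i.e. [L_pad, R_pad, T_pad, B_pad, rh, rw] on the 256x256 grid
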
/models/InstaColor/utils/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import sys
4 | sys.path.append("../..")
5 |
6 |
7 | import torch
8 | import numpy as np
9 | from PIL import Image
10 | import os
11 | from collections import OrderedDict
12 | from IPython import embed
13 | import cv2
14 | from utils.color_format import rgb2lab
15 |
16 | # Converts a Tensor into an image array (numpy)
17 | # |imtype|: the desired type of the converted numpy array
18 | def tensor2im(input_image, imtype=np.uint8):
19 | if isinstance(input_image, torch.Tensor):
20 | image_tensor = input_image.data
21 | else:
22 | return input_image
23 | image_numpy = image_tensor[0].cpu().float().numpy()
24 | if image_numpy.shape[0] == 1:
25 | image_numpy = np.tile(image_numpy, (3, 1, 1))
26 | image_numpy = np.clip((np.transpose(image_numpy, (1, 2, 0)) ),0, 1) * 255.0
27 | return image_numpy.astype(imtype)
28 |
29 |
30 | def diagnose_network(net, name='network'):
31 | mean = 0.0
32 | count = 0
33 | for param in net.parameters():
34 | if param.grad is not None:
35 | mean += torch.mean(torch.abs(param.grad.data))
36 | count += 1
37 | if count > 0:
38 | mean = mean / count
39 | print(name)
40 | print(mean)
41 |
42 |
43 | def save_image(image_numpy, image_path):
44 | image_pil = Image.fromarray(image_numpy)
45 | image_pil.save(image_path)
46 |
47 |
48 | def print_numpy(x, val=True, shp=False):
49 | x = x.astype(np.float64)
50 | if shp:
51 | print('shape,', x.shape)
52 | if val:
53 | x = x.flatten()
54 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
55 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
56 |
57 |
58 | def mkdirs(paths):
59 | if isinstance(paths, list) and not isinstance(paths, str):
60 | for path in paths:
61 | mkdir(path)
62 | else:
63 | mkdir(paths)
64 |
65 |
66 | def mkdir(path):
67 | if not os.path.exists(path):
68 | os.makedirs(path)
69 |
70 |
71 | def get_subset_dict(in_dict,keys):
72 | if(len(keys)):
73 | subset = OrderedDict()
74 | for key in keys:
75 | subset[key] = in_dict[key]
76 | else:
77 | subset = in_dict
78 | return subset
79 |
80 |
81 |
82 | def get_colorization_data(data_raw, opt, ab_thresh=5., p=.125, num_points=None):
83 | data = {}
84 | data_lab = rgb2lab(data_raw[0], opt)
85 | data['A'] = data_lab[:,[0,],:,:]
86 | data['B'] = data_lab[:,1:,:,:]
87 |
88 | if(ab_thresh > 0): # mask out grayscale images
89 | thresh = 1.*ab_thresh/opt.ab_norm
90 | mask = torch.sum(torch.abs(torch.max(torch.max(data['B'],dim=3)[0],dim=2)[0]-torch.min(torch.min(data['B'],dim=3)[0],dim=2)[0]),dim=1) >= thresh
91 | data['A'] = data['A'][mask,:,:,:]
92 | data['B'] = data['B'][mask,:,:,:]
93 | # print('Removed %i points'%torch.sum(mask==0).numpy())
94 | if(torch.sum(mask)==0):
95 | return None
96 |
97 | return add_color_patches_rand_gt(data, opt, p=p, num_points=num_points)
98 |
99 | def add_color_patches_rand_gt(data,opt,p=.125,num_points=None,use_avg=True,samp='normal'):
100 | # Add random color points sampled from ground truth based on:
101 | # Number of points
102 | # - if num_points is 0, then sample from geometric distribution, drawn from probability p
103 | # - if num_points > 0, then sample that number of points
104 | # Location of points
105 | # - if samp is 'normal', draw from N(0.5, 0.25) of image
106 | # - otherwise, draw from U[0, 1] of image
107 | N,C,H,W = data['B'].shape
108 |
109 | data['hint_B'] = torch.zeros_like(data['B'])
110 | data['mask_B'] = torch.zeros_like(data['A'])
111 |
112 | for nn in range(N):
113 | pp = 0
114 | cont_cond = True
115 | while(cont_cond):
116 | if(num_points is None): # draw from geometric
117 | # embed()
118 | cont_cond = np.random.rand() < (1-p)
119 | else: # add certain number of points
120 | cont_cond = pp < num_points
121 | if(not cont_cond): # skip out of loop if condition not met
122 | continue
123 |             # print('add hint')
124 | P = np.random.choice(opt.sample_Ps) # patch size
125 |
126 | # sample location
127 | if(samp=='normal'): # geometric distribution
128 | h = int(np.clip(np.random.normal( (H-P+1)/2., (H-P+1)/4.), 0, H-P))
129 | w = int(np.clip(np.random.normal( (W-P+1)/2., (W-P+1)/4.), 0, W-P))
130 | else: # uniform distribution
131 | h = np.random.randint(H-P+1)
132 | w = np.random.randint(W-P+1)
133 |
134 | # add color point
135 | if(use_avg):
136 | # embed()
137 | data['hint_B'][nn,:,h:h+P,w:w+P] = torch.mean(torch.mean(data['B'][nn,:,h:h+P,w:w+P],dim=2,keepdim=True),dim=1,keepdim=True).view(1,C,1,1)
138 | else:
139 | data['hint_B'][nn,:,h:h+P,w:w+P] = data['B'][nn,:,h:h+P,w:w+P]
140 |
141 | data['mask_B'][nn,:,h:h+P,w:w+P] = 1
142 |
143 | # increment counter
144 | pp+=1
145 |
146 | data['mask_B']-=opt.mask_cent
147 |
148 | return data
149 |
150 | def add_color_patch(data,mask,opt,P=1,hw=[128,128],ab=[0,0]):
151 | # Add a color patch at (h,w) with color (a,b)
152 | data[:,0,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1.*ab[0]/opt.ab_norm
153 | data[:,1,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1.*ab[1]/opt.ab_norm
154 | mask[:,:,hw[0]:hw[0]+P,hw[1]:hw[1]+P] = 1-opt.mask_cent
155 |
156 | return (data,mask)
157 |
158 | def crop_mult(data,mult=16,HWmax=[800,1200]):
159 | # crop image to a multiple
160 | H,W = data.shape[2:]
161 |     Hnew = int(min(H//mult*mult,HWmax[0]))
162 |     Wnew = int(min(W//mult*mult,HWmax[1]))
163 |     h = (H-Hnew)//2
164 |     w = (W-Wnew)//2
165 |
166 | return data[:,:,h:h+Hnew,w:w+Wnew]
167 |
168 | def encode_ab_ind(data_ab, opt):
169 | # Encode ab value into an index
170 | # INPUTS
171 | # data_ab Nx2xHxW \in [-1,1]
172 | # OUTPUTS
173 | # data_q Nx1xHxW \in [0,Q)
174 |
175 | data_ab_rs = torch.round((data_ab*opt.ab_norm + opt.ab_max)/opt.ab_quant) # normalized bin number
176 | data_q = data_ab_rs[:,[0],:,:]*opt.A + data_ab_rs[:,[1],:,:]
177 | return data_q
178 |
179 | def decode_ind_ab(data_q, opt):
180 | # Decode index into ab value
181 | # INPUTS
182 | # data_q Nx1xHxW \in [0,Q)
183 | # OUTPUTS
184 | # data_ab Nx2xHxW \in [-1,1]
185 |
186 |     data_a = data_q//opt.A  # floor division recovers the a-bin index
187 | data_b = data_q - data_a*opt.A
188 | data_ab = torch.cat((data_a,data_b),dim=1)
189 |
190 | if(data_q.is_cuda):
191 | type_out = torch.cuda.FloatTensor
192 | else:
193 | type_out = torch.FloatTensor
194 | data_ab = ((data_ab.type(type_out)*opt.ab_quant) - opt.ab_max)/opt.ab_norm
195 |
196 | return data_ab
197 |
198 | def decode_max_ab(data_ab_quant, opt):
199 | # Decode probability distribution by using bin with highest probability
200 | # INPUTS
201 | # data_ab_quant NxQxHxW \in [0,1]
202 | # OUTPUTS
203 | # data_ab Nx2xHxW \in [-1,1]
204 |
205 | data_q = torch.argmax(data_ab_quant,dim=1)[:,None,:,:]
206 | return decode_ind_ab(data_q, opt)
207 |
208 | def decode_mean(data_ab_quant, opt):
209 | # Decode probability distribution by taking mean over all bins
210 | # INPUTS
211 | # data_ab_quant NxQxHxW \in [0,1]
212 | # OUTPUTS
213 | # data_ab_inf Nx2xHxW \in [-1,1]
214 |
215 | (N,Q,H,W) = data_ab_quant.shape
216 | a_range = torch.range(-opt.ab_max, opt.ab_max, step=opt.ab_quant).to(data_ab_quant.device)[None,:,None,None]
217 | a_range = a_range.type(data_ab_quant.type())
218 |
219 | # reshape to AB space
220 | data_ab_quant = data_ab_quant.view((N,int(opt.A),int(opt.A),H,W))
221 | data_a_total = torch.sum(data_ab_quant,dim=2)
222 | data_b_total = torch.sum(data_ab_quant,dim=1)
223 |
224 | # matrix multiply
225 | data_a_inf = torch.sum(data_a_total * a_range,dim=1,keepdim=True)
226 | data_b_inf = torch.sum(data_b_total * a_range,dim=1,keepdim=True)
227 |
228 | data_ab_inf = torch.cat((data_a_inf,data_b_inf),dim=1)/opt.ab_norm
229 |
230 | return data_ab_inf
--------------------------------------------------------------------------------
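A sketch of the ab encode/decode round trip described in the comments above. The option values are the defaults this codebase usually runs with (ab_norm = ab_max = 110, ab_quant = 10, hence A = 2*110/10 + 1 = 23); they are assumed here rather than read from the options parser, and the repo root is assumed to be on sys.path:

from argparse import Namespace
import torch
from models.InstaColor.utils.util import encode_ab_ind, decode_ind_ab

opt = Namespace(ab_norm=110., ab_max=110., ab_quant=10., A=23)

# one pixel with normalized ab values in [-1, 1]
data_ab = torch.tensor([0.4, -0.2]).view(1, 2, 1, 1)

q = encode_ab_ind(data_ab, opt)     # single bin index in [0, A*A)
ab_back = decode_ind_ab(q, opt)     # quantized ab, off by at most half a bin width (ab_quant/ab_norm)
print(q.view(-1).item(), ab_back.view(-1).tolist())
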
/notebooks/DEVC.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Apply Deep Examplar-based Video Corization (DEVC)"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 1,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "%load_ext autoreload\n",
17 | "%autoreload 2"
18 | ]
19 | },
20 | {
21 | "cell_type": "code",
22 | "execution_count": 2,
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "import os\n",
27 | "import sys\n",
28 | "sys.path.append(\"../\")\n",
29 | "import numpy as np\n",
30 | "from models.DEVC import DEVC"
31 | ]
32 | },
33 | {
34 | "cell_type": "markdown",
35 | "metadata": {},
36 | "source": [
37 | "### Configure Options"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": null,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "from models.DEVC.options.test_options import TestOptions\n",
47 | "sys.argv = [sys.argv[0]]\n",
48 | "opt = TestOptions().parse() "
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "### Setup Models"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "model = DEVC(pretrained=True)"
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "metadata": {},
70 | "source": [
71 | "### Setup Test File"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "from utils import v2i\n",
81 | "\n",
82 | "v2i.convert_video_to_frames(\"../test/input.mp4\", \"../test/frames\")"
83 | ]
84 | },
85 | {
86 | "cell_type": "markdown",
87 | "metadata": {},
88 | "source": [
89 | "### Run Model on Test File"
90 | ]
91 | },
92 | {
93 | "cell_type": "code",
94 | "execution_count": null,
95 | "metadata": {},
96 | "outputs": [],
97 | "source": [
98 | "model.test(\"../test/frames/\", \"../test/output.mp4\", opt)"
99 | ]
100 | },
101 | {
102 | "cell_type": "code",
103 | "execution_count": null,
104 | "metadata": {},
105 | "outputs": [],
106 | "source": [
107 | "from utils.util import apply_metric_to_video\n",
108 | "from utils.metrics import *\n",
109 | "apply_metric_to_video(\"../test/input.mp4\", \"../test/output.mp4\", [PSNR, SSIM, LPIPS, cosine_similarity])"
110 | ]
111 | },
112 | {
113 | "cell_type": "code",
114 | "execution_count": null,
115 | "metadata": {},
116 | "outputs": [],
117 | "source": []
118 | }
119 | ],
120 | "metadata": {
121 | "interpreter": {
122 | "hash": "3dc88b2683938781427c3c64a4f5f2f03c4267c2f876dcaeb3e56e854a9d6a9e"
123 | },
124 | "kernelspec": {
125 | "display_name": "Python 3.7.6 ('instacolorization')",
126 | "language": "python",
127 | "name": "python3"
128 | },
129 | "language_info": {
130 | "codemirror_mode": {
131 | "name": "ipython",
132 | "version": 3
133 | },
134 | "file_extension": ".py",
135 | "mimetype": "text/x-python",
136 | "name": "python",
137 | "nbconvert_exporter": "python",
138 | "pygments_lexer": "ipython3",
139 | "version": "3.7.6"
140 | },
141 | "orig_nbformat": 4
142 | },
143 | "nbformat": 4,
144 | "nbformat_minor": 2
145 | }
146 |
--------------------------------------------------------------------------------
/notebooks/colorizer.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Colorful Image Colorization\n",
8 | "\n",
9 | "instacolor's environment is good for this algorithm"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 2,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import matplotlib.pyplot as plt\n",
19 | "import sys, os\n",
20 | "sys.path.append(\"../\")\n",
21 | "from models.colorizers import siggraph17\n",
22 | "from models.colorizers.util import *"
23 | ]
24 | },
25 | {
26 | "cell_type": "code",
27 | "execution_count": 3,
28 | "metadata": {},
29 | "outputs": [],
30 | "source": [
31 | "colorizer_siggraph17 = siggraph17(pretrained=True).eval()\n",
32 | "_ = colorizer_siggraph17.cuda()"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 4,
38 | "metadata": {},
39 | "outputs": [
40 | {
41 | "data": {
42 | "text/plain": [
43 | "['frame00000.jpg', 'frame00001.jpg']"
44 | ]
45 | },
46 | "execution_count": 4,
47 | "metadata": {},
48 | "output_type": "execute_result"
49 | }
50 | ],
51 | "source": [
52 | "VIDEO_NAME = \"yosemite-short\"\n",
53 | "INPUT_PATH = f\"../test/{VIDEO_NAME}-original/\"\n",
54 | "OUTPUT_PATH = f\"../test/{VIDEO_NAME}-colorized/\"\n",
55 | "if not os.path.isdir(OUTPUT_PATH):\n",
56 | " os.makedirs(OUTPUT_PATH)\n",
57 | "\n",
58 | "frames = os.listdir(INPUT_PATH)\n",
59 | "frames.sort()\n",
60 | "frames[:2]"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": 5,
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "def colorize(path):\n",
70 | " img = load_img(path)\n",
71 | " (tens_l_orig, tens_l_rs) = preprocess_img(img, HW=(256,256))\n",
72 | " tens_l_rs = tens_l_rs.cuda()\n",
73 | " out_img_siggraph17 = postprocess_tens(tens_l_orig, colorizer_siggraph17(tens_l_rs).cpu())\n",
74 | " return out_img_siggraph17"
75 | ]
76 | },
77 | {
78 | "cell_type": "code",
79 | "execution_count": 6,
80 | "metadata": {},
81 | "outputs": [
82 | {
83 | "name": "stderr",
84 | "output_type": "stream",
85 | "text": [
86 | "/home/tonyx/Utils/anaconda3/envs/instacolor/lib/python3.9/site-packages/torch/nn/functional.py:3631: UserWarning: Default upsampling behavior when mode=bilinear is changed to align_corners=False since 0.4.0. Please specify align_corners=True if the old behavior is desired. See the documentation of nn.Upsample for details.\n",
87 | " warnings.warn(\n"
88 | ]
89 | }
90 | ],
91 | "source": [
92 | "for frame in frames:\n",
93 | " colorized = colorize(os.path.join(INPUT_PATH, frame))\n",
94 | " plt.imsave(os.path.join(OUTPUT_PATH, frame), colorized)"
95 | ]
96 | }
97 | ],
98 | "metadata": {
99 | "interpreter": {
100 | "hash": "d66f20e7d73dbe1af3ef7c28c682f3f96a708e845fb78d9a29b2de937addae0a"
101 | },
102 | "kernelspec": {
103 | "display_name": "Python 3.9.7 ('instacolor')",
104 | "language": "python",
105 | "name": "python3"
106 | },
107 | "language_info": {
108 | "codemirror_mode": {
109 | "name": "ipython",
110 | "version": 3
111 | },
112 | "file_extension": ".py",
113 | "mimetype": "text/x-python",
114 | "name": "python",
115 | "nbconvert_exporter": "python",
116 | "pygments_lexer": "ipython3",
117 | "version": "3.9.7"
118 | },
119 | "orig_nbformat": 4
120 | },
121 | "nbformat": 4,
122 | "nbformat_minor": 2
123 | }
124 |
--------------------------------------------------------------------------------
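The loop above writes colorized frames into OUTPUT_PATH; to view the result as a clip, they can be stitched back together with the helper in utils/v2i.py (same sys.path setup as the notebook; the output filename is arbitrary and ffmpeg must be available on the PATH):

from utils.v2i import convert_frames_to_video

# Re-assemble the colorized frames written by the notebook into an mp4.
convert_frames_to_video(OUTPUT_PATH, f"../test/{VIDEO_NAME}-colorized.mp4")
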
/notebooks/video_prepro.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Video Pre-Processing"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": 3,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "import os, sys\n",
17 | "import cv2\n",
18 | "import numpy as np\n",
19 | "sys.path.append(\"../\")\n",
20 | "from utils.v2i import convert_video_to_frames, convert_frames_to_video\n",
21 | "from PIL import Image, ImageOps"
22 | ]
23 | },
24 | {
25 | "cell_type": "markdown",
26 | "metadata": {},
27 | "source": [
28 | "## MP4 to Frames"
29 | ]
30 | },
31 | {
32 | "cell_type": "code",
33 | "execution_count": 6,
34 | "metadata": {},
35 | "outputs": [],
36 | "source": [
37 | "VIDEO_NAME = \"yosemite-short\"\n",
38 | "DIR = \"../test\"\n",
39 | "\n",
40 | "convert_video_to_frames(os.path.join(DIR, VIDEO_NAME + \".mp4\"), os.path.join(DIR, VIDEO_NAME + \"-original\"))\n",
41 | "\n",
42 | "frames = sorted(os.listdir(os.path.join(DIR, VIDEO_NAME + \"-original\")))\n",
43 | "\n",
44 | "for frame in frames:\n",
45 | " f = Image.open(os.path.join(DIR, VIDEO_NAME + \"-original\", frame))\n",
46 | " gray_image = ImageOps.grayscale(f)\n",
47 | " gray_image.save(os.path.join(DIR, VIDEO_NAME + \"-original\", frame))\n"
48 | ]
49 | }
50 | ],
51 | "metadata": {
52 | "interpreter": {
53 | "hash": "189fd69765f0fd89b2713b9afa90d91317e65ce7ee286ac845775eb9093f307d"
54 | },
55 | "kernelspec": {
56 | "display_name": "Python 3.9.7 ('dvp')",
57 | "language": "python",
58 | "name": "python3"
59 | },
60 | "language_info": {
61 | "codemirror_mode": {
62 | "name": "ipython",
63 | "version": 3
64 | },
65 | "file_extension": ".py",
66 | "mimetype": "text/x-python",
67 | "name": "python",
68 | "nbconvert_exporter": "python",
69 | "pygments_lexer": "ipython3",
70 | "version": "3.9.7"
71 | },
72 | "orig_nbformat": 4
73 | },
74 | "nbformat": 4,
75 | "nbformat_minor": 2
76 | }
77 |
--------------------------------------------------------------------------------
/notebooks/webdemo.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## Colorization Webdemo"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "#### Install streamlit"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": null,
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "!pip install streamlit -q"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "metadata": {},
29 | "source": [
30 | "#### clone git repository"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": null,
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "!git clone https://github.com/Vince-Ai/Colorization.git"
40 | ]
41 | },
42 | {
43 | "cell_type": "code",
44 | "execution_count": null,
45 | "metadata": {},
46 | "outputs": [],
47 | "source": [
48 | "%cd Colorization"
49 | ]
50 | },
51 | {
52 | "cell_type": "markdown",
53 | "metadata": {},
54 | "source": [
55 | "#### Install requirements"
56 | ]
57 | },
58 | {
59 | "cell_type": "code",
60 | "execution_count": null,
61 | "metadata": {},
62 | "outputs": [],
63 | "source": [
64 | "!pip install -r files/general.txt"
65 | ]
66 | },
67 | {
68 | "cell_type": "markdown",
69 | "metadata": {},
70 | "source": [
71 | "#### Start server"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": null,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "!streamlit run webdemo.py & npx localtunnel --port 8501"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {},
86 | "source": [
87 | "#### Click on the link ends with loca.it and enjoy!"
88 | ]
89 | }
90 | ],
91 | "metadata": {
92 | "language_info": {
93 | "name": "python"
94 | },
95 | "orig_nbformat": 4
96 | },
97 | "nbformat": 4,
98 | "nbformat_minor": 2
99 | }
100 |
--------------------------------------------------------------------------------
/sample.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wensi-ai/Colorizer/ef382f44885f9d083c02a83952db9b4a6b4149c1/sample.mp4
--------------------------------------------------------------------------------
/utils/color_format.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def rgb2xyz(rgb): # rgb from [0,1]
5 | # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423],
6 | # [0.212671, 0.715160, 0.072169],
7 | # [0.019334, 0.119193, 0.950227]])
8 |
9 | mask = (rgb > .04045).type(torch.FloatTensor)
10 | if(rgb.is_cuda):
11 | mask = mask.cuda()
12 |
13 | rgb = (((rgb+.055)/1.055)**2.4)*mask + rgb/12.92*(1-mask)
14 |
15 | x = .412453*rgb[:,0,:,:]+.357580*rgb[:,1,:,:]+.180423*rgb[:,2,:,:]
16 | y = .212671*rgb[:,0,:,:]+.715160*rgb[:,1,:,:]+.072169*rgb[:,2,:,:]
17 | z = .019334*rgb[:,0,:,:]+.119193*rgb[:,1,:,:]+.950227*rgb[:,2,:,:]
18 |
19 | out = torch.cat((x[:,None,:,:],y[:,None,:,:],z[:,None,:,:]),dim=1)
20 | return out
21 |
22 |
23 | def xyz2rgb(xyz):
24 | # array([[ 3.24048134, -1.53715152, -0.49853633],
25 | # [-0.96925495, 1.87599 , 0.04155593],
26 | # [ 0.05564664, -0.20404134, 1.05731107]])
27 |
28 | r = 3.24048134*xyz[:,0,:,:]-1.53715152*xyz[:,1,:,:]-0.49853633*xyz[:,2,:,:]
29 | g = -0.96925495*xyz[:,0,:,:]+1.87599*xyz[:,1,:,:]+.04155593*xyz[:,2,:,:]
30 | b = .05564664*xyz[:,0,:,:]-.20404134*xyz[:,1,:,:]+1.05731107*xyz[:,2,:,:]
31 |
32 | rgb = torch.cat((r[:,None,:,:],g[:,None,:,:],b[:,None,:,:]),dim=1)
33 | rgb = torch.max(rgb,torch.zeros_like(rgb)) # sometimes reaches a small negative number, which causes NaNs
34 |
35 | mask = (rgb > .0031308).type(torch.FloatTensor)
36 | if(rgb.is_cuda):
37 | mask = mask.cuda()
38 |
39 | rgb = (1.055*(rgb**(1./2.4)) - 0.055)*mask + 12.92*rgb*(1-mask)
40 | return rgb
41 |
42 |
43 | def xyz2lab(xyz):
44 | # 0.95047, 1., 1.08883 # white
45 | sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
46 | if(xyz.is_cuda):
47 | sc = sc.cuda()
48 |
49 | xyz_scale = xyz/sc
50 |
51 | mask = (xyz_scale > .008856).type(torch.FloatTensor)
52 | if(xyz_scale.is_cuda):
53 | mask = mask.cuda()
54 |
55 | xyz_int = xyz_scale**(1/3.)*mask + (7.787*xyz_scale + 16./116.)*(1-mask)
56 |
57 | L = 116.*xyz_int[:,1,:,:]-16.
58 | a = 500.*(xyz_int[:,0,:,:]-xyz_int[:,1,:,:])
59 | b = 200.*(xyz_int[:,1,:,:]-xyz_int[:,2,:,:])
60 |
61 | out = torch.cat((L[:,None,:,:],a[:,None,:,:],b[:,None,:,:]),dim=1)
62 | return out
63 |
64 |
65 | def lab2xyz(lab):
66 | y_int = (lab[:,0,:,:]+16.)/116.
67 | x_int = (lab[:,1,:,:]/500.) + y_int
68 | z_int = y_int - (lab[:,2,:,:]/200.)
69 | if(z_int.is_cuda):
70 | z_int = torch.max(torch.Tensor((0,)).cuda(), z_int)
71 | else:
72 | z_int = torch.max(torch.Tensor((0,)), z_int)
73 |
74 | out = torch.cat((x_int[:,None,:,:],y_int[:,None,:,:],z_int[:,None,:,:]),dim=1)
75 | mask = (out > .2068966).type(torch.FloatTensor)
76 | if(out.is_cuda):
77 | mask = mask.cuda()
78 |
79 | out = (out**3.)*mask + (out - 16./116.)/7.787*(1-mask)
80 |
81 | sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
82 | sc = sc.to(out.device)
83 |
84 | out = out*sc
85 | return out
86 |
87 |
88 | def rgb2lab(rgb, opt):
89 | lab = xyz2lab(rgb2xyz(rgb))
90 | l_rs = (lab[:,[0],:,:]-opt.l_cent)/opt.l_norm
91 | ab_rs = lab[:,1:,:,:]/opt.ab_norm
92 | out = torch.cat((l_rs,ab_rs),dim=1)
93 | return out
94 |
95 |
96 | def lab2rgb(lab_rs, opt):
97 | l = lab_rs[:,[0],:,:]*opt.l_norm + opt.l_cent
98 | ab = lab_rs[:,1:,:,:]*opt.ab_norm
99 | lab = torch.cat((l,ab),dim=1)
100 | out = xyz2rgb(lab2xyz(lab))
101 | return out
--------------------------------------------------------------------------------
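A round-trip sanity check for the conversions above. Only l_cent, l_norm, and ab_norm are read from opt; the values below are the defaults typically used elsewhere in this repo and are assumed here:

from argparse import Namespace
import torch
from utils.color_format import rgb2lab, lab2rgb

opt = Namespace(l_cent=50., l_norm=100., ab_norm=110.)

rgb = torch.rand(1, 3, 8, 8)      # NxCxHxW, values in [0, 1]
lab = rgb2lab(rgb, opt)           # L recentred by l_cent/l_norm, ab divided by ab_norm
rgb_back = lab2rgb(lab, opt)

print(torch.max(torch.abs(rgb - rgb_back)).item())   # should be close to zero
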
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def mse_loss(input, target=0):
4 | return torch.mean((input - target) ** 2)
5 |
6 |
7 | def l1_loss(input, target=0):
8 | return torch.mean(torch.abs(input - target))
9 |
10 |
11 | def weighted_mse_loss(input, target, weights):
12 | out = (input - target) ** 2
13 | out = out * weights.expand_as(out)
14 | return out.mean()
15 |
16 |
17 | def weighted_l1_loss(input, target, weights):
18 | out = torch.abs(input - target)
19 | out = out * weights.expand_as(out)
20 | return out.mean()
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import torch
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from math import exp
6 | import lpips
7 |
8 | def PSNR(img1, img2):
9 | SE_map = (1. * img1 - img2) ** 2
10 | cur_MSE = torch.mean(SE_map)
11 | return float(20 * torch.log10(1. / torch.sqrt(cur_MSE)))
12 |
13 |
14 | def SSIM(img1, img2, window_size=11, size_average=True):
15 | (_, channel, _, _) = img1.size()
16 |
17 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 /float(2 * 1.5 ** 2)) for x in range(window_size)])
18 | gauss = gauss / gauss.sum()
19 | _1D_window = gauss.unsqueeze(1)
20 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
21 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
22 |
23 | if img1.is_cuda:
24 | window = window.cuda(img1.get_device())
25 | window = window.type_as(img1)
26 |
27 | mu1 = F.conv2d(img1, window, padding = window_size // 2, groups = channel)
28 | mu2 = F.conv2d(img2, window, padding = window_size // 2, groups = channel)
29 |
30 | mu1_sq = mu1.pow(2)
31 | mu2_sq = mu2.pow(2)
32 | mu1_mu2 = mu1*mu2
33 |
34 | sigma1_sq = F.conv2d(img1 * img1, window, padding = window_size // 2, groups = channel) - mu1_sq
35 | sigma2_sq = F.conv2d(img2 * img2, window, padding = window_size // 2, groups = channel) - mu2_sq
36 | sigma12 = F.conv2d(img1 * img2, window, padding = window_size // 2, groups = channel) - mu1_mu2
37 |
38 | C1 = 0.01 ** 2
39 | C2 = 0.03 ** 2
40 |
41 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
42 |
43 | if size_average:
44 | return float(ssim_map.mean())
45 | else:
46 | return ssim_map.mean(1).mean(1).mean(1)
47 |
48 |
49 | def LPIPS(img1, img2):
50 | # Normalize to [-1, 1]
51 | img1 = 2 * (img1 / 255.) - 1
52 | img2 = 2 * (img2 / 255.) - 1
53 | loss_fn = lpips.LPIPS(net='alex')
54 | result = float(torch.mean(loss_fn.forward(img1, img2)))
55 | return result
56 |
57 |
58 | def PCVC(img1, img2):
59 | pass
60 |
61 |
62 | def colorfulness(imgs):
63 | """
64 | according to the paper: Measuring colourfulness in natural images
65 | input is batches of ab tensors in lab space
66 | """
67 | N, C, H, W = imgs.shape
68 | a = imgs[:, 0:1, :, :]
69 | b = imgs[:, 1:2, :, :]
70 |
71 | a = a.view(N, -1)
72 | b = b.view(N, -1)
73 |
74 | sigma_a = torch.std(a, dim=-1)
75 | sigma_b = torch.std(b, dim=-1)
76 |
77 | mean_a = torch.mean(a, dim=-1)
78 | mean_b = torch.mean(b, dim=-1)
79 |
80 | return torch.sqrt(sigma_a ** 2 + sigma_b ** 2) + 0.37 * torch.sqrt(mean_a ** 2 + mean_b ** 2)
81 |
82 |
83 | def cosine_similarity(img1, img2):
84 | input_norm = torch.norm(img1, 2, 1, keepdim=True) + sys.float_info.epsilon
85 | target_norm = torch.norm(img2, 2, 1, keepdim=True) + sys.float_info.epsilon
86 | normalized_input = torch.div(img1, input_norm)
87 | normalized_target = torch.div(img2, target_norm)
88 | cos_similarity = torch.mul(normalized_input, normalized_target)
89 | return float(torch.mean(cos_similarity))
--------------------------------------------------------------------------------
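A quick sketch of the per-frame metrics above on random batches (shapes and values are illustrative; LPIPS is skipped because constructing lpips.LPIPS pulls in pretrained network weights):

import torch
from utils.metrics import PSNR, SSIM, colorfulness

img1 = torch.rand(2, 3, 64, 64)                            # values in [0, 1] match PSNR's peak of 1.0
img2 = (img1 + 0.05 * torch.randn_like(img1)).clamp(0, 1)  # a slightly perturbed copy

print("PSNR:", PSNR(img1, img2))
print("SSIM:", SSIM(img1, img2))
print("colorfulness:", colorfulness(torch.randn(2, 2, 64, 64)))   # expects Nx2xHxW ab channels
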
/utils/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | import requests
4 | import torch
5 | from torchvision import io
6 | import torchvideo.transforms as transforms
7 |
8 | def download_zipfile(url: str, destination: str='.', unzip: bool=True):
9 | response = requests.get(url)
10 | CHUNK_SIZE = 32768
11 | with open(destination, "wb") as f:
12 | for chunk in response.iter_content(CHUNK_SIZE):
13 | if chunk: # filter out keep-alive new chunks
14 | f.write(chunk)
15 | if unzip:
16 | with zipfile.ZipFile(destination, 'r') as zip_ref:
17 | zip_ref.extractall(".")
18 |
19 | def mkdir(dir_path):
20 | if not os.path.exists(dir_path):
21 | os.makedirs(dir_path)
22 |
23 |
24 | def apply_metric_to_video(video_path1, video_path2, metrics):
25 | batch_size = 50
26 |
27 | video1, _, _ = io.read_video(video_path1)
28 | video2, _, _ = io.read_video(video_path2)
29 | video1 = video1.cpu().detach().numpy()
30 | transform = transforms.Compose([
31 | transforms.NDArrayToPILVideo(),
32 | transforms.ResizeVideo((video2.shape[1], video2.shape[2])),
33 | transforms.CollectFrames(),
34 | transforms.PILVideoToTensor(rescale=False, ordering='TCHW')]
35 | )
36 | length = min(video1.shape[0], video2.shape[0])
37 | video1 = transform(video1)[:length].float()
38 | video2 = video2.permute((0, 3, 1, 2))[:length].float()
39 | results = []
40 | for metric in metrics:
41 | cur_metric_result = 0
42 | for i in range(0, video1.shape[0], batch_size):
43 | cur_metric_result += min(video1.shape[0] - i, batch_size) * metric(video1[i: min(video1.shape[0], i + batch_size)], video2[i: min(video1.shape[0], i + batch_size)])
44 | results.append(cur_metric_result / video1.shape[0])
45 | return results
46 |
--------------------------------------------------------------------------------
/utils/v2i.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from typing import List
4 | import skvideo.io
5 |
6 | def convert_video_to_frames(video_path: str, frames_path: str):
7 | """Convert video input to a set of frames in jpg format"""
8 | if os.path.isdir(frames_path) is False:
9 | os.makedirs(frames_path)
10 | vidcap = cv2.VideoCapture(video_path)
11 |     count = 0
12 |     # read one frame ahead; the loop below writes it and then reads the next
13 |     success, image = vidcap.read()
14 | 
15 | while success:
16 | cv2.imwrite(f"{frames_path}/frame" + str(count).zfill(5) + ".jpg", image) # save frame as JPEG file
17 | if cv2.waitKey(10) == 27: # exit if Escape is hit
18 | break
19 | success, image = vidcap.read()
20 | count += 1
21 |
22 |
23 | def convert_frames_to_video(frames_path: str, output_video_path: str, images: List[str]=None):
24 | """Convert video input to a set of frames in jpg format"""
25 | if images is None:
26 | images = sorted(os.listdir(frames_path))
27 | frame = cv2.imread(os.path.join(frames_path, images[0]))
28 | height, width, _ = frame.shape
29 | fourcc = cv2.VideoWriter_fourcc(*'MP4V')
30 | video = cv2.VideoWriter("temp.mp4", fourcc, 30.0, (width, height))
31 |
32 | for image in images:
33 | video.write(cv2.imread(os.path.join(frames_path, image)))
34 |
35 | cv2.destroyAllWindows()
36 | video.release()
37 |
38 | os.system(f"ffmpeg -y -i temp.mp4 -vcodec libx264 {output_video_path}")
39 | os.remove("temp.mp4")
--------------------------------------------------------------------------------
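An end-to-end sketch using the sample clip at the repo root; the output paths are arbitrary, and convert_frames_to_video shells out to ffmpeg, which must be installed:

from utils.v2i import convert_video_to_frames, convert_frames_to_video

convert_video_to_frames("sample.mp4", "test/sample_frames")               # one jpg per frame
convert_frames_to_video("test/sample_frames", "test/sample_rebuilt.mp4")  # re-encode with ffmpeg
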
/webdemo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append(".")
3 |
4 | import os
5 | import numpy as np
6 | import matplotlib.pyplot as plt
7 | import streamlit as st
8 | from importlib import import_module
9 | from utils import v2i
10 | from utils.util import apply_metric_to_video, mkdir
11 | from utils.metrics import *
12 | import shutil
13 |
14 | mkdir("test")
15 |
16 | @st.cache
17 | def colorize(model, model_name, opt):
18 | sys.argv = [sys.argv[0]]
19 | model.test("test/frames", f"test/output_{model_name}.mp4", opt)
20 | video = open(f"test/output_{model_name}.mp4", 'rb').read()
21 | return video
22 |
23 | @st.cache
24 | def dvp(model, model_name, opt):
25 | sys.argv = [sys.argv[0]]
26 | model.test("test/frames", "test/results", f"test/output_{model_name}.mp4", opt)
27 | video = open(f"test/output_{model_name}.mp4", 'rb').read()
28 | return video
29 |
30 | @st.cache
31 | def metric(video_out):
32 | results = {}
33 | ret = apply_metric_to_video("test/temp.mp4", video_out, [PSNR, SSIM, LPIPS, cosine_similarity])
34 | results["PSNR"] = ret[0]
35 | results["SSIM"] = ret[1]
36 | results["LPIPS"] = ret[2]
37 | results["cosine_similarity"] = ret[3]
38 | return results
39 |
40 | st.title("Video Colorization Web Demo")
41 |
42 | # Add a selectbox to the sidebar:
43 | model_names = st.sidebar.multiselect(
44 | 'Select colorization model(s):',
45 | ('DEVC', 'InstaColor', 'Colorful')
46 | )
47 |
48 | # Add a slider to the sidebar:
49 | input_method = st.sidebar.selectbox(
50 | 'Select video input:',
51 | ("sample video", "random from dataset", "upload local video")
52 | )
53 | # get input video
54 | if input_method == "upload local video":
55 | video_orig = st.sidebar.file_uploader("Choose a file")
56 | if video_orig:
57 | video_orig = video_orig.getvalue()
58 | elif input_method == "sample video":
59 | video_orig = open('sample.mp4', 'rb').read()
60 | else:
61 | video_orig = None
62 |
63 | # display video
64 | st.subheader("Original video:")
65 | st.video(video_orig, format="video/mp4", start_time=0)
66 | st.subheader("Colorized video:")
67 |
68 | # setup metric
69 | psnr, ssim, lpips, cs = [], [], [], []
70 |
71 | if st.sidebar.button('Colorize'):
72 | with open("test/temp.mp4", "wb") as f:
73 | f.write(video_orig)
74 | video_in = v2i.convert_video_to_frames("test/temp.mp4", "./test/frames")
75 | for i, model_name in enumerate(model_names):
76 | # import model
77 | model_module = import_module(f"models.{model_name}")
78 | Model = getattr(model_module, model_name)
79 | # import and parse test options
80 | try:
81 | opt_module = import_module(f"models.{model_name}.options.test_options")
82 | TestOptions = getattr(opt_module, "TestOptions")
83 | opt = TestOptions().parse()
84 | except AttributeError: # no option needed
85 | opt = None
86 | # run test on model
87 | model = Model()
88 | if model_name == "DVP":
89 | video_out = dvp(model, model_name, opt)
90 | else:
91 | video_out = colorize(model, model_name, opt)
92 | st.write(model_name + ": ")
93 | st.video(video_out, format="video/mp4", start_time=0)
94 | print("video displayed")
95 | # get metrics
96 | results = metric(f"test/output_{model_name}.mp4")
97 | psnr.append(results["PSNR"])
98 | ssim.append(results["SSIM"])
99 | lpips.append(results["LPIPS"])
100 | cs.append(results["cosine_similarity"])
101 | print(f"{model_name} metric complete!")
102 | os.remove("test/temp.mp4")
103 |
104 | shutil.rmtree("test/frames")
105 | shutil.rmtree("test/results")
106 |
107 | # display metrics
108 | st.subheader("Metrics:")
109 | idx = model_names
110 | fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(20, 30))
111 | # plt.subplots_adjust(wspace=1, hspace=1)
112 | ax1.bar(idx, psnr, color=['grey' if (x < max(psnr)) else 'red' for x in psnr ], width=0.4)
113 | ax2.bar(idx, ssim, color=['grey' if (x < max(ssim)) else 'red' for x in ssim ], width=0.4)
114 | ax3.bar(idx, lpips, color=['grey' if (x > min(lpips)) else 'red' for x in lpips ], width=0.4)
115 | ax4.bar(idx, cs, color=['grey' if (x < max(cs)) else 'red' for x in cs ], width=0.4)
116 | ax1.set_title('PSNR (higher is better)')
117 | ax2.set_title('SSIM (higher is better)')
118 | ax3.set_title('LPIPS (lower is better)')
119 | ax4.set_title('Cosine Similarity (higher is better)')
120 | st.pyplot(fig)
--------------------------------------------------------------------------------