├── .gitattributes
├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .streamlit
│   └── config.toml
├── 01_📼_Upload_Video_File.py
├── README.md
├── __init__.py
├── interface.gif
├── models
│   ├── __init__.py
│   └── deep_colorization
│       ├── __init__.py
│       └── colorizers
│           ├── __init__.py
│           ├── base_color.py
│           ├── eccv16.py
│           ├── siggraph17.py
│           └── util.py
├── pages
│   ├── 02_🎥_Input_Youtube_Link.py
│   ├── 03_🖼️_Input_Images.py
│   └── __init__.py
├── requirements.txt
├── setup.cfg
└── utils.py

/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
--------------------------------------------------------------------------------
/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | name: CI
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | pre-commit:
7 | runs-on: ubuntu-latest
8 | name: Does the code respect Python standards?
9 | steps:
10 | - uses: actions/checkout@v2
11 | - name: Set up Python
12 | uses: actions/setup-python@v2
13 | with:
14 | python-version: '3.10'
15 | - name: Install pre-commit & requirements
16 | run: |
17 | pip install pre-commit pylint
18 | pip install -r requirements.txt
19 | - name: Run pre-commit
20 | run: pre-commit run --all-files
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | .idea/
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 | cover/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | .pybuilder/
78 | target/
79 |
80 | # Jupyter Notebook
81 | .ipynb_checkpoints
82 |
83 | # IPython
84 | profile_default/
85 | ipython_config.py
86 |
87 | # pyenv
88 | # For a library or package, you might want to ignore these files since the code is
89 | # intended to run in multiple environments; otherwise, check them in:
90 | # .python-version
91 |
92 | # pipenv
93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
96 | # install all needed dependencies.
97 | #Pipfile.lock
98 |
99 | # poetry
100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
101 | # This is especially recommended for binary packages to ensure reproducibility, and is more
102 | # commonly ignored for libraries.
103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
104 | #poetry.lock
105 |
106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
107 | __pypackages__/
108 |
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 |
113 | # SageMath parsed files
114 | *.sage.py
115 |
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 |
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 |
129 | # Rope project settings
130 | .ropeproject
131 |
132 | # mkdocs documentation
133 | /site
134 |
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 |
140 | # Pyre type checker
141 | .pyre/
142 |
143 | # pytype static type analyzer
144 | .pytype/
145 |
146 | # Cython debug symbols
147 | cython_debug/
148 |
149 | # PyCharm
150 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
151 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
152 | # and can be added to the global gitignore or merged into this file. For a more nuclear
153 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
154 | #.idea/
155 | .idea/vcs.xml
156 | '.idea'
157 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 24.4.0
4 | hooks:
5 | - id: black
6 | args: ['--line-length=120', '--verbose']
7 | exclude: '^models/'
8 |
9 | - repo: https://github.com/pycqa/flake8
10 | rev: '7.0.0'
11 | hooks:
12 | - id: flake8
13 | exclude: '^models/'
14 |
15 | - repo: https://github.com/pre-commit/mirrors-pylint
16 | rev: v3.0.0a5
17 | hooks:
18 | - id: pylint
19 | name: pylint
20 | entry: pylint
21 | language: system
22 | args: ['.', '--rcfile=setup.cfg', '--fail-under=8']
23 | exclude: '^models/'
24 | types: [python]
--------------------------------------------------------------------------------
/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [theme]
2 | primaryColor="#F63366"
3 | backgroundColor="#FFFFFF"
4 | secondaryBackgroundColor="#F0F2F6"
5 | textColor="#262730"
6 | font="sans serif"
7 | [server]
8 | maxUploadSize=1028
--------------------------------------------------------------------------------
/01_📼_Upload_Video_File.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | import time
4 |
5 | import cv2
6 | import moviepy.editor as mp
7 | import numpy as np
8 | import streamlit as st
9 |
10 | from tqdm import tqdm
11 |
12 | from utils import format_time, colorize_frame, change_model, load_model, setup_columns, set_page_config
13 |
14 | set_page_config()
15 | loaded_model = load_model()
16 | col2 = setup_columns()
17 | current_model = None
18 |
19 | with col2:
20 | st.write(
21 | """
22 | ## B&W Videos Colorizer
23 | ##### Upload a black and white video and get a colorized version of it.
24 | ###### ➠ This space is using CPU Basic, so it might take a while to colorize a video.
25 | ###### ➠ If you want more models and a GPU available, please support this space by donating."""
26 | )
27 |
28 |
29 | def main():
30 | """
31 | Main function to run this page
32 | """
33 | model = st.selectbox(
34 | "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for your "
35 | "task)",
36 | ["ECCV16", "SIGGRAPH17"],
37 | index=0,
38 | )
39 |
40 | loaded_model = change_model(current_model, model)
41 | st.write(f"Model is now {model}")
42 |
43 | uploaded_file = st.file_uploader("Upload your video here...", type=["mp4", "mov", "avi", "mkv"])
44 |
45 | if st.button("Colorize"):
46 | if uploaded_file is not None:
47 | file_extension = os.path.splitext(uploaded_file.name)[1].lower()
48 | if file_extension in [".mp4", ".avi", ".mov", ".mkv"]:
49 | # Save the video file to a temporary location
50 | temp_file = tempfile.NamedTemporaryFile(delete=False)
51 | temp_file.write(uploaded_file.read())
52 |
53 | audio = mp.AudioFileClip(temp_file.name)
54 |
55 | # Open the video using cv2.VideoCapture
56 | video = cv2.VideoCapture(temp_file.name)
57 |
58 | # Get video information
59 | fps = video.get(cv2.CAP_PROP_FPS)
60 |
61 | col1, col2 = st.columns([0.5, 0.5])
62 | with col1:
63 | st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
64 | st.video(temp_file.name)
65 |
66 | with col2:
67 | st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
68 |
69 | with st.spinner("Colorizing frames..."):
70 | # Colorize video frames and store in a list
71 | output_frames = []
72 | total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
73 | progress_bar = st.progress(0) # Create a progress bar
74 |
75 | start_time = time.time()
76 | time_text = st.text("Time Remaining: ") # Initialize text value
77 |
78 | for _ in tqdm(range(total_frames), unit="frame", desc="Progress"):
79 | ret, frame = video.read()
80 | if not ret:
81 | break
82 |
83 | colorized_frame = colorize_frame(frame, loaded_model)
84 | output_frames.append((colorized_frame * 255).astype(np.uint8))
85 |
86 | elapsed_time = time.time() - start_time
87 | frames_completed = len(output_frames)
88 | frames_remaining = total_frames - frames_completed
89 | time_remaining = (frames_remaining / frames_completed) * elapsed_time
90 |
91 | progress_bar.progress(frames_completed / total_frames) # Update progress bar
92 |
93 | if frames_completed < total_frames:
94 | time_text.text(f"Time Remaining: {format_time(time_remaining)}") # Update text value
95 | else:
96 | time_text.empty() # Remove text value
97 | progress_bar.empty()
98 |
99 | with st.spinner("Merging frames to video..."):
100 | frame_size = output_frames[0].shape[:2]
101 | output_filename = "output.mp4"
102 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Codec for MP4 video
103 | out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_size[1], frame_size[0]))
104 |
105 | # Write the colorized frames to the output video
106 | for frame in output_frames:
107 | frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
108 |
109 | out.write(frame_bgr)
110 |
111 | out.release()
112 |
113 | # Convert the output video to a format compatible with Streamlit
114 | converted_filename = "converted_output.mp4"
115 | clip = mp.VideoFileClip(output_filename)
116 | clip = clip.set_audio(audio)
117 |
118 | clip.write_videofile(converted_filename, codec="libx264")
119 |
120 | # Display the converted video using st.video()
121 | st.video(converted_filename)
122 | st.balloons()
123 |
124 | # Add a download button for the colorized video
125 | st.download_button(
126 | label="Download Colorized Video",
127 | data=open(converted_filename, "rb").read(),
128 | file_name="colorized_video.mp4",
129 | )
130 |
131 | # Close the video and the temporary file after processing
132 | video.release()
133 | temp_file.close()
134 |
135 |
136 | if __name__ == "__main__":
137 | main()
138 | st.markdown(
139 | "###### Made with :heart: by [Clément Delteil](https://www.linkedin.com/in/clementdelteil/) [](https://www.buymeacoffee.com/clementdelteil)"
141 | )
142 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image Video Colorization
2 | This project was conceived as part of the "Sujets Spéciaux" course at UQAC under the supervision of [Kevin Bouchard](http://www.kevin-bouchard.ca/).
3 |
4 | ## Objective
5 | The aim was to deploy an image and video colorization application. To this end, I built on the GitHub project of [Richard Zhang and his co-authors](https://github.com/richzhang/colorization), with models and papers presented at ECCV16 and SIGGRAPH17. Their solutions are based on convolutional neural networks.
6 |
7 | The interface for the models was designed on Streamlit and deployed in a [Space on Hugging Face](https://huggingface.co/spaces/Wazzzabeee/image-video-colorization).
8 |
9 | The tutorial with all the design steps is available on my [Medium blog](https://medium.com/geekculture/creating-a-web-app-to-colorize-images-and-youtube-videos-80f5be2d0f68).
10 |
11 | The following features are available:
12 | - Colorization of a batch of images.
13 | - Colorization of MP4, MKV and AVI video files.
14 | - Colorization of YouTube videos.
15 |
16 | ## Interface
17 | <img src="interface.gif" alt="interface" />
18 |
19 | ## Running Locally
20 | If you want to process longer videos and you're limited by the Hugging Face Space's memory limits, you can run this app locally.
21 |
22 | `ffmpeg` is needed to run this app (`ffmpeg.exe` on Windows). On macOS you can install it using `brew install ffmpeg`; then update the `IMAGEIO_FFMPEG_EXE` environment variable accordingly (e.g. `export IMAGEIO_FFMPEG_EXE=$(which ffmpeg)`).
23 |
24 | ```bash
25 | git clone https://github.com/Wazzabeee/image-video-colorization
26 | cd image-video-colorization
27 | pip install -r requirements.txt
28 | streamlit run 01_📼_Upload_Video_File.py
29 | ```
30 |
31 | ## Todos
32 | Other models based on GANs will probably be implemented in the future if my application for a community grant to gain access to a GPU on Hugging Face is successful.
33 |
34 | ## References
35 | 1. Richard Zhang, Phillip Isola, and Alexei A. Efros. "Colorful Image Colorization". In: ECCV. 2016.
36 | 2. Richard Zhang et al. "Real-Time User-Guided Image Colorization with Learned Deep Priors". In: ACM
37 | Transactions on Graphics (TOG) 9.4 (2017).
38 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wazzabeee/image-video-colorization/069f937320c1cbcd04579724048228600596788a/__init__.py
--------------------------------------------------------------------------------
/interface.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wazzabeee/image-video-colorization/069f937320c1cbcd04579724048228600596788a/interface.gif
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wazzabeee/image-video-colorization/069f937320c1cbcd04579724048228600596788a/models/__init__.py
--------------------------------------------------------------------------------
/models/deep_colorization/__init__.py:
--------------------------------------------------------------------------------
1 | # inside models/deep_colorization/__init__.py
2 | from .colorizers import eccv16, siggraph17, load_img, preprocess_img, postprocess_tens
3 |
--------------------------------------------------------------------------------
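
The re-exports above expose the whole colorization pipeline from a single package. A minimal usage sketch, mirroring what `utils.colorize_image` does (assumptions: `my_photo.jpg` is a hypothetical input path, and `pretrained=True` downloads the ECCV16 weights on first use):

```python
import numpy as np
from PIL import Image

from models.deep_colorization import eccv16, load_img, preprocess_img, postprocess_tens

colorizer = eccv16(pretrained=True).eval()  # fetches the pretrained weights once

img = load_img("my_photo.jpg")  # hypothetical path; returns an HxWx3 uint8 RGB array
tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))  # full-res and resized L channels
out_img = postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu())  # float RGB in [0, 1]

Image.fromarray((out_img * 255).astype(np.uint8)).save("my_photo_colorized.jpg")
```
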
/models/deep_colorization/colorizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_color import BaseColor
2 | from .eccv16 import ECCVGenerator, eccv16
3 | from .siggraph17 import SIGGRAPHGenerator, siggraph17
4 | from .util import load_img, resize_img, preprocess_img, postprocess_tens
5 |
--------------------------------------------------------------------------------
/models/deep_colorization/colorizers/base_color.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 |
4 | class BaseColor(nn.Module):
5 | def __init__(self):
6 | super(BaseColor, self).__init__()
7 |
8 | self.l_cent = 50.0
9 | self.l_norm = 100.0
10 | self.ab_norm = 110.0
11 |
12 | def normalize_l(self, in_l):
13 | return (in_l - self.l_cent) / self.l_norm
14 |
15 | def unnormalize_l(self, in_l):
16 | return in_l * self.l_norm + self.l_cent
17 |
18 | def normalize_ab(self, in_ab):
19 | return in_ab / self.ab_norm
20 |
21 | def unnormalize_ab(self, in_ab):
22 | return in_ab * self.ab_norm
23 |
--------------------------------------------------------------------------------
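
`BaseColor` only centers and rescales CIE Lab channels before and after the networks: L (lightness, range [0, 100]) is shifted by `l_cent` and divided by `l_norm`, and the a/b chroma channels are divided by `ab_norm`. A quick sanity check of the mapping (plain floats stand in for tensors here):

```python
from models.deep_colorization.colorizers import BaseColor

bc = BaseColor()
print(bc.normalize_l(0.0), bc.normalize_l(100.0))  # -0.5 0.5 -> L is mapped to [-0.5, 0.5]
print(bc.normalize_ab(110.0))                      # 1.0 -> typical ab values land in about [-1, 1]
print(bc.unnormalize_ab(bc.normalize_ab(55.0)))    # 55.0 -> the round trip is exact
```
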
/models/deep_colorization/colorizers/eccv16.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from IPython import embed
5 |
6 | from .base_color import *
7 |
8 |
9 | class ECCVGenerator(BaseColor):
10 | def __init__(self, norm_layer=nn.BatchNorm2d):
11 | super(ECCVGenerator, self).__init__()
12 |
13 | model1 = [
14 | nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),
15 | ]
16 | model1 += [
17 | nn.ReLU(True),
18 | ]
19 | model1 += [
20 | nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),
21 | ]
22 | model1 += [
23 | nn.ReLU(True),
24 | ]
25 | model1 += [
26 | norm_layer(64),
27 | ]
28 |
29 | model2 = [
30 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
31 | ]
32 | model2 += [
33 | nn.ReLU(True),
34 | ]
35 | model2 += [
36 | nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),
37 | ]
38 | model2 += [
39 | nn.ReLU(True),
40 | ]
41 | model2 += [
42 | norm_layer(128),
43 | ]
44 |
45 | model3 = [
46 | nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),
47 | ]
48 | model3 += [
49 | nn.ReLU(True),
50 | ]
51 | model3 += [
52 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
53 | ]
54 | model3 += [
55 | nn.ReLU(True),
56 | ]
57 | model3 += [
58 | nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),
59 | ]
60 | model3 += [
61 | nn.ReLU(True),
62 | ]
63 | model3 += [
64 | norm_layer(256),
65 | ]
66 |
67 | model4 = [
68 | nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),
69 | ]
70 | model4 += [
71 | nn.ReLU(True),
72 | ]
73 | model4 += [
74 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
75 | ]
76 | model4 += [
77 | nn.ReLU(True),
78 | ]
79 | model4 += [
80 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
81 | ]
82 | model4 += [
83 | nn.ReLU(True),
84 | ]
85 | model4 += [
86 | norm_layer(512),
87 | ]
88 |
89 | model5 = [
90 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
91 | ]
92 | model5 += [
93 | nn.ReLU(True),
94 | ]
95 | model5 += [
96 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
97 | ]
98 | model5 += [
99 | nn.ReLU(True),
100 | ]
101 | model5 += [
102 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
103 | ]
104 | model5 += [
105 | nn.ReLU(True),
106 | ]
107 | model5 += [
108 | norm_layer(512),
109 | ]
110 |
111 | model6 = [
112 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
113 | ]
114 | model6 += [
115 | nn.ReLU(True),
116 | ]
117 | model6 += [
118 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
119 | ]
120 | model6 += [
121 | nn.ReLU(True),
122 | ]
123 | model6 += [
124 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
125 | ]
126 | model6 += [
127 | nn.ReLU(True),
128 | ]
129 | model6 += [
130 | norm_layer(512),
131 | ]
132 |
133 | model7 = [
134 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
135 | ]
136 | model7 += [
137 | nn.ReLU(True),
138 | ]
139 | model7 += [
140 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
141 | ]
142 | model7 += [
143 | nn.ReLU(True),
144 | ]
145 | model7 += [
146 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
147 | ]
148 | model7 += [
149 | nn.ReLU(True),
150 | ]
151 | model7 += [
152 | norm_layer(512),
153 | ]
154 |
155 | model8 = [
156 | nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),
157 | ]
158 | model8 += [
159 | nn.ReLU(True),
160 | ]
161 | model8 += [
162 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
163 | ]
164 | model8 += [
165 | nn.ReLU(True),
166 | ]
167 | model8 += [
168 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
169 | ]
170 | model8 += [
171 | nn.ReLU(True),
172 | ]
173 |
174 | model8 += [
175 | nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),
176 | ]
177 |
178 | self.model1 = nn.Sequential(*model1)
179 | self.model2 = nn.Sequential(*model2)
180 | self.model3 = nn.Sequential(*model3)
181 | self.model4 = nn.Sequential(*model4)
182 | self.model5 = nn.Sequential(*model5)
183 | self.model6 = nn.Sequential(*model6)
184 | self.model7 = nn.Sequential(*model7)
185 | self.model8 = nn.Sequential(*model8)
186 |
187 | self.softmax = nn.Softmax(dim=1)
188 | self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
189 | self.upsample4 = nn.Upsample(scale_factor=4, mode="bilinear")
190 |
191 | def forward(self, input_l):
192 | conv1_2 = self.model1(self.normalize_l(input_l))
193 | conv2_2 = self.model2(conv1_2)
194 | conv3_3 = self.model3(conv2_2)
195 | conv4_3 = self.model4(conv3_3)
196 | conv5_3 = self.model5(conv4_3)
197 | conv6_3 = self.model6(conv5_3)
198 | conv7_3 = self.model7(conv6_3)
199 | conv8_3 = self.model8(conv7_3)
200 | out_reg = self.model_out(self.softmax(conv8_3))
201 |
202 | return self.unnormalize_ab(self.upsample4(out_reg))
203 |
204 |
205 | def eccv16(pretrained=True):
206 | model = ECCVGenerator()
207 | if pretrained:
208 | import torch.utils.model_zoo as model_zoo
209 |
210 | model.load_state_dict(
211 | model_zoo.load_url(
212 | "https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth",
213 | map_location="cpu",
214 | check_hash=True,
215 | )
216 | )
217 | return model
218 |
--------------------------------------------------------------------------------
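
`ECCVGenerator` downsamples the L channel through strided convolutions, predicts a distribution over 313 quantized ab color bins at a quarter of the input resolution (the 256-to-313 1x1 convolution at the end of `model8`), collapses it to two ab channels with `model_out`, and bilinearly upsamples the result. A shape-check sketch (`pretrained=False` skips the weight download, which is enough for checking tensor shapes):

```python
import torch
from models.deep_colorization.colorizers import eccv16

model = eccv16(pretrained=False).eval()  # random weights; fine for a shape check
with torch.no_grad():
    dummy_l = torch.zeros(1, 1, 256, 256)  # a batch of one 256x256 L channel
    ab = model(dummy_l)
print(ab.shape)  # torch.Size([1, 2, 256, 256]) - ab prediction at the input resolution
```
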
/models/deep_colorization/colorizers/siggraph17.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .base_color import *
5 |
6 |
7 | class SIGGRAPHGenerator(BaseColor):
8 | def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
9 | super(SIGGRAPHGenerator, self).__init__()
10 |
11 | # Conv1
12 | model1 = [
13 | nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),
14 | ]
15 | model1 += [
16 | nn.ReLU(True),
17 | ]
18 | model1 += [
19 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),
20 | ]
21 | model1 += [
22 | nn.ReLU(True),
23 | ]
24 | model1 += [
25 | norm_layer(64),
26 | ]
27 | # add a subsampling operation
28 |
29 | # Conv2
30 | model2 = [
31 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
32 | ]
33 | model2 += [
34 | nn.ReLU(True),
35 | ]
36 | model2 += [
37 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),
38 | ]
39 | model2 += [
40 | nn.ReLU(True),
41 | ]
42 | model2 += [
43 | norm_layer(128),
44 | ]
45 | # add a subsampling layer operation
46 |
47 | # Conv3
48 | model3 = [
49 | nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),
50 | ]
51 | model3 += [
52 | nn.ReLU(True),
53 | ]
54 | model3 += [
55 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
56 | ]
57 | model3 += [
58 | nn.ReLU(True),
59 | ]
60 | model3 += [
61 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
62 | ]
63 | model3 += [
64 | nn.ReLU(True),
65 | ]
66 | model3 += [
67 | norm_layer(256),
68 | ]
69 | # add a subsampling layer operation
70 |
71 | # Conv4
72 | model4 = [
73 | nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),
74 | ]
75 | model4 += [
76 | nn.ReLU(True),
77 | ]
78 | model4 += [
79 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
80 | ]
81 | model4 += [
82 | nn.ReLU(True),
83 | ]
84 | model4 += [
85 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
86 | ]
87 | model4 += [
88 | nn.ReLU(True),
89 | ]
90 | model4 += [
91 | norm_layer(512),
92 | ]
93 |
94 | # Conv5
95 | model5 = [
96 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
97 | ]
98 | model5 += [
99 | nn.ReLU(True),
100 | ]
101 | model5 += [
102 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
103 | ]
104 | model5 += [
105 | nn.ReLU(True),
106 | ]
107 | model5 += [
108 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
109 | ]
110 | model5 += [
111 | nn.ReLU(True),
112 | ]
113 | model5 += [
114 | norm_layer(512),
115 | ]
116 |
117 | # Conv6
118 | model6 = [
119 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
120 | ]
121 | model6 += [
122 | nn.ReLU(True),
123 | ]
124 | model6 += [
125 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
126 | ]
127 | model6 += [
128 | nn.ReLU(True),
129 | ]
130 | model6 += [
131 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),
132 | ]
133 | model6 += [
134 | nn.ReLU(True),
135 | ]
136 | model6 += [
137 | norm_layer(512),
138 | ]
139 |
140 | # Conv7
141 | model7 = [
142 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
143 | ]
144 | model7 += [
145 | nn.ReLU(True),
146 | ]
147 | model7 += [
148 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
149 | ]
150 | model7 += [
151 | nn.ReLU(True),
152 | ]
153 | model7 += [
154 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),
155 | ]
156 | model7 += [
157 | nn.ReLU(True),
158 | ]
159 | model7 += [
160 | norm_layer(512),
161 | ]
162 |
163 | # Conv7
164 | model8up = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
165 | model3short8 = [
166 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
167 | ]
168 |
169 | model8 = [
170 | nn.ReLU(True),
171 | ]
172 | model8 += [
173 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
174 | ]
175 | model8 += [
176 | nn.ReLU(True),
177 | ]
178 | model8 += [
179 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),
180 | ]
181 | model8 += [
182 | nn.ReLU(True),
183 | ]
184 | model8 += [
185 | norm_layer(256),
186 | ]
187 |
188 | # Conv9
189 | model9up = [
190 | nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),
191 | ]
192 | model2short9 = [
193 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),
194 | ]
195 | # add the two feature maps above
196 |
197 | model9 = [
198 | nn.ReLU(True),
199 | ]
200 | model9 += [
201 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),
202 | ]
203 | model9 += [
204 | nn.ReLU(True),
205 | ]
206 | model9 += [
207 | norm_layer(128),
208 | ]
209 |
210 | # Conv10
211 | model10up = [
212 | nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),
213 | ]
214 | model1short10 = [
215 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),
216 | ]
217 | # add the two feature maps above
218 |
219 | model10 = [
220 | nn.ReLU(True),
221 | ]
222 | model10 += [
223 | nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),
224 | ]
225 | model10 += [
226 | nn.LeakyReLU(negative_slope=0.2),
227 | ]
228 |
229 | # classification output
230 | model_class = [
231 | nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),
232 | ]
233 |
234 | # regression output
235 | model_out = [
236 | nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),
237 | ]
238 | model_out += [nn.Tanh()]
239 |
240 | self.model1 = nn.Sequential(*model1)
241 | self.model2 = nn.Sequential(*model2)
242 | self.model3 = nn.Sequential(*model3)
243 | self.model4 = nn.Sequential(*model4)
244 | self.model5 = nn.Sequential(*model5)
245 | self.model6 = nn.Sequential(*model6)
246 | self.model7 = nn.Sequential(*model7)
247 | self.model8up = nn.Sequential(*model8up)
248 | self.model8 = nn.Sequential(*model8)
249 | self.model9up = nn.Sequential(*model9up)
250 | self.model9 = nn.Sequential(*model9)
251 | self.model10up = nn.Sequential(*model10up)
252 | self.model10 = nn.Sequential(*model10)
253 | self.model3short8 = nn.Sequential(*model3short8)
254 | self.model2short9 = nn.Sequential(*model2short9)
255 | self.model1short10 = nn.Sequential(*model1short10)
256 |
257 | self.model_class = nn.Sequential(*model_class)
258 | self.model_out = nn.Sequential(*model_out)
259 |
260 | self.upsample4 = nn.Sequential(
261 | *[
262 | nn.Upsample(scale_factor=4, mode="bilinear"),
263 | ]
264 | )
265 | self.softmax = nn.Sequential(
266 | *[
267 | nn.Softmax(dim=1),
268 | ]
269 | )
270 |
271 | def forward(self, input_a, input_b=None, mask_b=None):
272 | if input_b is None:
273 | input_b = torch.cat((input_a * 0, input_a * 0), dim=1)
274 | if mask_b is None:
275 | mask_b = input_a * 0
276 |
277 | conv1_2 = self.model1(torch.cat((self.normalize_l(input_a), self.normalize_ab(input_b), mask_b), dim=1))
278 | conv2_2 = self.model2(conv1_2[:, :, ::2, ::2])
279 | conv3_3 = self.model3(conv2_2[:, :, ::2, ::2])
280 | conv4_3 = self.model4(conv3_3[:, :, ::2, ::2])
281 | conv5_3 = self.model5(conv4_3)
282 | conv6_3 = self.model6(conv5_3)
283 | conv7_3 = self.model7(conv6_3)
284 |
285 | conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
286 | conv8_3 = self.model8(conv8_up)
287 | conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
288 | conv9_3 = self.model9(conv9_up)
289 | conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
290 | conv10_2 = self.model10(conv10_up)
291 | out_reg = self.model_out(conv10_2)
292 |
293 | return self.unnormalize_ab(out_reg)
294 |
295 |
296 | def siggraph17(pretrained=True):
297 | model = SIGGRAPHGenerator()
298 | if pretrained:
299 | import torch.utils.model_zoo as model_zoo
300 |
301 | model.load_state_dict(
302 | model_zoo.load_url(
303 | "https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth",
304 | map_location="cpu",
305 | check_hash=True,
306 | )
307 | )
308 | return model
309 |
--------------------------------------------------------------------------------
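
Unlike `ECCVGenerator`, `SIGGRAPHGenerator.forward` accepts optional user guidance: `input_b` carries ab color hints and `mask_b` marks where those hints apply; both default to zeros, which reproduces fully automatic colorization. A sketch of the calling convention with dummy tensors (`pretrained=False` again avoids the weight download):

```python
import torch
from models.deep_colorization.colorizers import siggraph17

model = siggraph17(pretrained=False).eval()

l_channel = torch.zeros(1, 1, 256, 256)  # grayscale input (L channel)
ab_hints = torch.zeros(1, 2, 256, 256)   # user-provided ab values (all zeros = no hints)
hint_mask = torch.zeros(1, 1, 256, 256)  # nonzero where a hint should be trusted

with torch.no_grad():
    ab = model(l_channel, ab_hints, hint_mask)  # equivalent to model(l_channel) here
print(ab.shape)  # torch.Size([1, 2, 256, 256])
```
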
/models/deep_colorization/colorizers/util.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | from skimage import color
4 | import torch
5 | import torch.nn.functional as F
6 | from IPython import embed
7 |
8 |
9 | def load_img(img_path):
10 | out_np = np.asarray(Image.open(img_path))
11 | if out_np.ndim == 2:
12 | out_np = np.tile(out_np[:, :, None], 3)
13 | return out_np
14 |
15 |
16 | def resize_img(img, HW=(256, 256), resample=3):
17 | return np.asarray(Image.fromarray(img).resize((HW[1], HW[0]), resample=resample))
18 |
19 |
20 | def preprocess_img(img_rgb_orig, HW=(256, 256), resample=3):
21 | # return original size L and resized L as torch Tensors
22 | img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
23 |
24 | img_lab_orig = color.rgb2lab(img_rgb_orig)
25 | img_lab_rs = color.rgb2lab(img_rgb_rs)
26 |
27 | img_l_orig = img_lab_orig[:, :, 0]
28 | img_l_rs = img_lab_rs[:, :, 0]
29 |
30 | tens_orig_l = torch.Tensor(img_l_orig)[None, None, :, :]
31 | tens_rs_l = torch.Tensor(img_l_rs)[None, None, :, :]
32 |
33 | return tens_orig_l, tens_rs_l
34 |
35 |
36 | def postprocess_tens(tens_orig_l, out_ab, mode="bilinear"):
37 | # tens_orig_l 1 x 1 x H_orig x W_orig
38 | # out_ab 1 x 2 x H x W
39 |
40 | HW_orig = tens_orig_l.shape[2:]
41 | HW = out_ab.shape[2:]
42 |
43 | # call resize function if needed
44 | if HW_orig[0] != HW[0] or HW_orig[1] != HW[1]:
45 | out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode=mode)
46 | else:
47 | out_ab_orig = out_ab
48 |
49 | out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
50 | return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0, ...].transpose((1, 2, 0)))
51 |
--------------------------------------------------------------------------------
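
`preprocess_img` returns the L channel twice, once at the original resolution (kept for reassembly) and once resized for the network, while `postprocess_tens` interpolates the predicted ab map back to the original size before converting Lab to RGB. A small round trip with a fabricated image and a dummy ab prediction illustrates the shapes involved:

```python
import numpy as np
import torch

from models.deep_colorization.colorizers import preprocess_img, postprocess_tens

# A fabricated 240x320 gray gradient standing in for a real photo
img = np.tile(np.linspace(0, 255, 320, dtype=np.uint8)[None, :, None], (240, 1, 3))

tens_orig_l, tens_rs_l = preprocess_img(img, HW=(256, 256))
print(tens_orig_l.shape, tens_rs_l.shape)  # [1, 1, 240, 320] and [1, 1, 256, 256]

dummy_ab = torch.zeros(1, 2, 256, 256)  # zero chroma -> grayscale output
out = postprocess_tens(tens_orig_l, dummy_ab)
print(out.shape)  # (240, 320, 3), float RGB in [0, 1]
```
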
/pages/02_🎥_Input_Youtube_Link.py:
--------------------------------------------------------------------------------
1 | import time
2 |
3 | import cv2
4 | import moviepy.editor as mp
5 | import numpy as np
6 | import streamlit as st
7 | from pytube import YouTube
8 | from tqdm import tqdm
9 |
10 |
11 | from utils import format_time, colorize_frame, change_model, load_model, setup_columns, set_page_config
12 |
13 | set_page_config()
14 | loaded_model = load_model()
15 | col2 = setup_columns()
16 | current_model = None
17 |
18 | with col2:
19 | st.write(
20 | """
21 | ## B&W Videos Colorizer
22 | ##### Input a YouTube black and white video link and get a colorized version of it.
23 | ###### ➠ This space is using CPU Basic, so it might take a while to colorize a video.
24 | ###### ➠ If you want more models and a GPU available, please support this space by donating."""
25 | )
26 |
27 |
28 | @st.cache_data()
29 | def download_video(link):
30 | """
31 | Download video from YouTube
32 | """
33 | yt = YouTube(link)
34 | video = (
35 | yt.streams.filter(progressive=True, file_extension="mp4")
36 | .order_by("resolution")
37 | .desc()
38 | .first()
39 | .download(filename="video.mp4")
40 | )
41 | return video
42 |
43 |
44 | def main():
45 | """
46 | Main function
47 | """
48 | model = st.selectbox(
49 | "Select Model (Both models have their pros and cons,"
50 | "I recommend trying both and keeping the best for you task)",
51 | ["ECCV16", "SIGGRAPH17"],
52 | index=0,
53 | )
54 |
55 | loaded_model = change_model(current_model, model)
56 | st.write(f"Model is now {model}")
57 |
58 | link = st.text_input("YouTube Link (The longer the video, the longer the processing time)")
59 | if st.button("Colorize"):
60 | yt_video = download_video(link)
61 | print(yt_video)
62 | col1, col2 = st.columns([0.5, 0.5])
63 | with col1:
64 | st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
65 | st.video(yt_video)
66 | with col2:
67 | st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
68 | with st.spinner("Colorizing frames..."):
69 | # Colorize video frames and store in a list
70 | output_frames = []
71 |
72 | audio = mp.AudioFileClip("video.mp4")
73 | video = cv2.VideoCapture("video.mp4")
74 |
75 | total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
76 | fps = video.get(cv2.CAP_PROP_FPS)
77 |
78 | progress_bar = st.progress(0) # Create a progress bar
79 | start_time = time.time()
80 | time_text = st.text("Time Remaining: ") # Initialize text value
81 |
82 | for _ in tqdm(range(total_frames), unit="frame", desc="Progress"):
83 | ret, frame = video.read()
84 | if not ret:
85 | break
86 |
87 | colorized_frame = colorize_frame(frame, loaded_model)
88 | output_frames.append((colorized_frame * 255).astype(np.uint8))
89 |
90 | elapsed_time = time.time() - start_time
91 | frames_completed = len(output_frames)
92 | frames_remaining = total_frames - frames_completed
93 | time_remaining = (frames_remaining / frames_completed) * elapsed_time
94 |
95 | progress_bar.progress(frames_completed / total_frames) # Update progress bar
96 |
97 | if frames_completed < total_frames:
98 | time_text.text(f"Time Remaining: {format_time(time_remaining)}") # Update text value
99 | else:
100 | time_text.empty() # Remove text value
101 | progress_bar.empty()
102 |
103 | with st.spinner("Merging frames to video..."):
104 | frame_size = output_frames[0].shape[:2]
105 | output_filename = "output.mp4"
106 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Codec for MP4 video
107 | out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_size[1], frame_size[0]))
108 |
109 | # Write the colorized frames to the output video
110 | for frame in output_frames:
111 | frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
112 |
113 | out.write(frame_bgr)
114 |
115 | out.release()
116 |
117 | # Convert the output video to a format compatible with Streamlit
118 | converted_filename = "converted_output.mp4"
119 | clip = mp.VideoFileClip(output_filename)
120 | clip = clip.set_audio(audio)
121 |
122 | clip.write_videofile(converted_filename, codec="libx264")
123 |
124 | # Display the converted video using st.video()
125 | st.video(converted_filename)
126 | st.balloons()
127 |
128 | # Add a download button for the colorized video
129 | st.download_button(
130 | label="Download Colorized Video",
131 | data=open(converted_filename, "rb").read(),
132 | file_name="colorized_video.mp4",
133 | )
134 |
135 | # Release the video capture after processing
136 | video.release()
137 |
138 |
139 | if __name__ == "__main__":
140 | main()
141 | st.markdown(
142 | "###### Made with :heart: by [Clément Delteil](https://www.linkedin.com/in/clementdelteil/) [](https://www.buymeacoffee.com/clementdelteil)"
144 | )
145 |
--------------------------------------------------------------------------------
/pages/03_🖼️_Input_Images.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 |
4 | import streamlit as st
5 | from PIL import Image
6 |
7 | from utils import change_model, load_model, setup_columns, set_page_config, colorize_image
8 |
9 | set_page_config()
10 | loaded_model = load_model()
11 | col2 = setup_columns()
12 | current_model = None
13 |
14 | with col2:
15 | st.write(
16 | """
17 | ## B&W Images Colorizer
18 | ##### Input a black and white image and get a colorized version of it.
19 | ###### ➠ If you want to colorize multiple images, just upload them all at once.
20 | ###### ➠ Uploading already colored images won't raise errors, but the results won't look good.
21 | ###### ➠ I recommend starting with the first model and then experimenting with the second one."""
22 | )
23 |
24 |
25 | def main():
26 | """
27 | Main function
28 | """
29 | model = st.selectbox(
30 | "Select Model (Both models have their pros and cons, "
31 | "I recommend trying both and keeping the best for your task)",
32 | ["ECCV16", "SIGGRAPH17"],
33 | index=0,
34 | )
35 |
36 | # Load the model selected by the user
37 | loaded_model = change_model(current_model, model)
38 | st.write(f"Model is now {model}")
39 |
40 | # Ask the user if they want to see the colorization in real time
41 | display_results = st.checkbox("Display results in real time", value=True)
42 |
43 | # Input for the user to upload images
44 | uploaded_file = st.file_uploader(
45 | "Upload your images here...", type=["jpg", "png", "jpeg"], accept_multiple_files=True
46 | )
47 |
48 | # If the user clicks on the button
49 | if st.button("Colorize"):
50 | # If the user uploaded images
51 | if uploaded_file is not None:
52 | if display_results:
53 | col1, col2 = st.columns([0.5, 0.5])
54 | with col1:
55 | st.markdown('<p style="text-align: center;">Before</p>', unsafe_allow_html=True)
56 | with col2:
57 | st.markdown('<p style="text-align: center;">After</p>', unsafe_allow_html=True)
58 | else:
59 | col1, col2, _ = st.columns(3)
60 |
61 | for i, file in enumerate(uploaded_file):
62 | file_extension = os.path.splitext(file.name)[1].lower()
63 | if file_extension in [".jpg", ".png", ".jpeg"]:
64 | image = Image.open(file)
65 | if display_results:
66 | with col1:
67 | st.image(image, use_column_width="always")
68 | with col2:
69 | with st.spinner("Colorizing image..."):
70 | out_img, new_img = colorize_image(file, loaded_model)
71 | new_img.save("IMG_" + str(i + 1) + ".jpg")
72 | st.image(out_img, use_column_width="always")
73 |
74 | else:
75 | out_img, new_img = colorize_image(file, loaded_model)
76 | new_img.save("IMG_" + str(i + 1) + ".jpg")
77 |
78 | if len(uploaded_file) > 1:
79 | # Create a zip file
80 | zip_filename = "colorized_images.zip"
81 | with zipfile.ZipFile(zip_filename, "w") as zip_file:
82 | # Add colorized images to the zip file, keeping the 1-based naming
83 | for i in range(len(uploaded_file)):
84 | zip_file.write("IMG_" + str(i + 1) + ".jpg", "IMG_" + str(i + 1) + ".jpg")
85 | with col2:
86 | # Provide the zip file data for download
87 | st.download_button(
88 | label="Download Colorized Images" if len(uploaded_file) > 1 else "Download Colorized Image",
89 | data=open(zip_filename, "rb").read(),
90 | file_name=zip_filename,
91 | )
92 | else:
93 | with col2:
94 | st.download_button(
95 | label="Download Colorized Image",
96 | data=open("IMG_1.jpg", "rb").read(),
97 | file_name="IMG_1.jpg",
98 | )
99 |
100 | else:
101 | st.warning("Upload a file", icon="⚠️")
102 |
103 |
104 | if __name__ == "__main__":
105 | main()
106 | st.markdown(
107 | "###### Made with :heart: by [Clément Delteil](https://www.linkedin.com/in/clementdelteil/) [](https://www.buymeacoffee.com/clementdelteil)"
109 | )
110 |
--------------------------------------------------------------------------------
/pages/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Wazzabeee/image-video-colorization/069f937320c1cbcd04579724048228600596788a/pages/__init__.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipython==8.10.0
2 | moviepy==1.0.3
3 | numpy==1.23.2
4 | opencv_python==4.8.1.78
5 | Pillow==10.3.0
6 | scikit-image==0.20.0
7 | streamlit==1.30.0
8 | torch==1.13.1
9 | streamlit_lottie==0.0.5
10 | requests==2.31.0
11 | tqdm==4.64.1
12 | pytube==15.0.0
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [pylint]
2 | disable=C0103,C0114,R0913,R0914,C0200,C0301,E1101
3 |
4 | [pep8]
5 | max-line-length=120
6 | ignore=E121,E123,E126,E226,E24,E704,E203,W503
7 | exclude=venv,test_env,test_venv
8 |
9 | [flake8]
10 | max-line-length=120
11 | ignore=E121,E123,E126,E226,E24,E704,E203,W503
12 | exclude=venv,test_env,test_venv
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import requests
3 | import streamlit as st
4 | from PIL import Image
5 | from streamlit_lottie import st_lottie
6 |
7 | from models.deep_colorization import eccv16
8 | from models.deep_colorization import siggraph17
9 | from models.deep_colorization import postprocess_tens, preprocess_img, load_img
10 |
11 |
12 | class SameModelException(ValueError):
13 | """Exception raised when the same model is attempted to be reloaded."""
14 |
15 |
16 | def set_page_config():
17 | """
18 | Sets up the page config.
19 | """
20 | st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide")
21 |
22 |
23 | def load_model():
24 | """
25 | Loads the default model.
26 | """
27 | return eccv16(pretrained=True).eval()
28 |
29 |
30 | def setup_columns():
31 | """
32 | Sets up the columns.
33 | """
34 | col1, col2 = st.columns([1, 3])
35 | lottie = load_lottieurl("https://assets5.lottiefiles.com/packages/lf20_RHdEuzVfEL.json")
36 | with col1:
37 | st_lottie(lottie)
38 | return col2
39 |
40 |
41 | # Define a function that we can use to load lottie files from a link.
42 | @st.cache_data()
43 | def load_lottieurl(url: str):
44 | """
45 | Load lottieurl image
46 | """
47 | try:
48 | r = requests.get(url, timeout=10) # Timeout set to 10 seconds
49 | r.raise_for_status() # This will raise an exception for HTTP errors
50 | return r.json()
51 | except requests.RequestException as e:
52 | print(f"Request failed: {e}")
53 | return None
54 |
55 |
56 | @st.cache_resource()
57 | def change_model(current_model, model):
58 | """
59 | Change model
60 | """
61 | loaded_model = None
62 |
63 | if current_model != model:
64 | if model == "ECCV16":
65 | loaded_model = eccv16(pretrained=True).eval()
66 | elif model == "SIGGRAPH17":
67 | loaded_model = siggraph17(pretrained=True).eval()
68 | return loaded_model
69 |
70 | raise SameModelException("Model is the same as the current one.")
71 |
72 |
73 | def format_time(seconds: float) -> str:
74 | """Formats time in seconds to a human-readable format"""
75 | if seconds < 60:
76 | return f"{int(seconds)} seconds"
77 | if seconds < 3600:
78 | minutes = int(seconds // 60)
79 | seconds %= 60
80 | return f"{minutes} minutes and {int(seconds)} seconds"
81 | if seconds < 86400:
82 | hours = int(seconds // 3600)
83 | minutes = int((seconds % 3600) // 60)
84 | seconds %= 60
85 | return f"{hours} hours, {minutes} minutes, and {int(seconds)} seconds"
86 |
87 | days = int(seconds // 86400)
88 | hours = int((seconds % 86400) // 3600)
89 | minutes = int((seconds % 3600) // 60)
90 | seconds %= 60
91 | return f"{days} days, {hours} hours, {minutes} minutes, and {int(seconds)} seconds"
92 |
93 |
94 | # Function to colorize video frames
95 | def colorize_frame(frame, colorizer) -> np.ndarray:
96 | """
97 | Colorize frame
98 | """
99 | tens_l_orig, tens_l_rs = preprocess_img(frame, HW=(256, 256))
100 | return postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu())
101 |
102 |
103 | def colorize_image(file, loaded_model):
104 | """
105 | Colorize image
106 | """
107 | img = load_img(file)
108 | # If the user uploaded a colored image with 4 channels (RGBA), discard the alpha channel
109 | if img.shape[2] == 4:
110 | img = img[:, :, :3]
111 |
112 | tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))
113 | out_img = postprocess_tens(tens_l_orig, loaded_model(tens_l_rs).cpu())
114 | new_img = Image.fromarray((out_img * 255).astype(np.uint8))
115 |
116 | return out_img, new_img
117 |
--------------------------------------------------------------------------------
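
For reference, the helpers in `utils.py` can also be exercised outside Streamlit. A hypothetical standalone check of `colorize_frame` on a single frame (`frame.jpg` is a placeholder path; `load_model` downloads the ECCV16 weights on first run, and the BGR frame is fed straight through, exactly as the video pages do):

```python
import cv2
import numpy as np

from utils import colorize_frame, load_model

model = load_model()                      # default ECCV16 model in eval mode
frame = cv2.imread("frame.jpg")           # uint8 BGR, as cv2.VideoCapture would yield
colorized = colorize_frame(frame, model)  # float RGB in [0, 1], at the frame's resolution
cv2.imwrite("frame_colorized.jpg", cv2.cvtColor((colorized * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
```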