├── .gitattributes ├── .github └── workflows │ └── ci.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .streamlit └── config.toml ├── 01_📼_Upload_Video_File.py ├── README.md ├── __init__.py ├── interface.gif ├── models ├── __init__.py └── deep_colorization │ ├── __init__.py │ └── colorizers │ ├── __init__.py │ ├── base_color.py │ ├── eccv16.py │ ├── siggraph17.py │ └── util.py ├── pages ├── 02_🎥_Input_Youtube_Link.py ├── 03_🖼️_Input_Images.py └── __init__.py ├── requirements.txt ├── setup.cfg └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: CI 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | pre-commit: 7 | runs-on: ubuntu-latest 8 | name: Does the code respect Python standards? 9 | steps: 10 | - uses: actions/checkout@v2 11 | - name: Set up Python 12 | uses: actions/setup-python@v2 13 | with: 14 | python-version: '3.10' 15 | - name: Install pre-commit & requirements 16 | run: | 17 | pip install pre-commit pylint 18 | pip install -r requirements.txt 19 | - name: Run pre-commit 20 | run: pre-commit run --all-files -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | .idea/ 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies.
97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | # pytype static type analyzer 144 | .pytype/ 145 | 146 | # Cython debug symbols 147 | cython_debug/ 148 | 149 | # PyCharm 150 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 151 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 152 | # and can be added to the global gitignore or merged into this file. For a more nuclear 153 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 154 | #.idea/ 155 | .idea/vcs.xml 156 | '.idea' 157 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 24.4.0 4 | hooks: 5 | - id: black 6 | args: ['--line-length=120', '--verbose'] 7 | exclude: '^models/' 8 | 9 | - repo: https://github.com/pycqa/flake8 10 | rev: '7.0.0' 11 | hooks: 12 | - id: flake8 13 | exclude: '^models/' 14 | 15 | - repo: https://github.com/pre-commit/mirrors-pylint 16 | rev: v3.0.0a5 17 | hooks: 18 | - id: pylint 19 | name: pylint 20 | entry: pylint 21 | language: system 22 | args: ['.', '--rcfile=setup.cfg', '--fail-under=8'] 23 | exclude: '^models/' 24 | types: [python] -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | primaryColor="#F63366" 3 | backgroundColor="#FFFFFF" 4 | secondaryBackgroundColor="#F0F2F6" 5 | textColor="#262730" 6 | font="sans serif" 7 | [server] 8 | maxUploadSize=1028 -------------------------------------------------------------------------------- /01_📼_Upload_Video_File.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import time 4 | 5 | import cv2 6 | import moviepy.editor as mp 7 | import numpy as np 8 | import streamlit as st 9 | 10 | from tqdm import tqdm 11 | 12 | from utils import format_time, colorize_frame, change_model, load_model, setup_columns, set_page_config 13 | 14 | set_page_config() 15 | loaded_model = load_model() 16 | col2 = setup_columns() 17 | current_model = None 18 | 19 | with col2: 20 | st.write( 21 | """ 22 | ## B&W Videos Colorizer 23 | ##### Upload a black and white video and get a colorized version of it.
###### ➠ This space is using CPU Basic so it might take a while to colorize a video. 25 | ###### ➠ If you want more models and a GPU available, please support this space by donating.""" 26 | ) 27 | 28 | 29 | def main(): 30 | """ 31 | Main function to run this page 32 | """ 33 | model = st.selectbox( 34 | "Select Model (Both models have their pros and cons, I recommend trying both and keeping the best for your " 35 | "task)", 36 | ["ECCV16", "SIGGRAPH17"], 37 | index=0, 38 | ) 39 | 40 | loaded_model = change_model(current_model, model) 41 | st.write(f"Model is now {model}") 42 | 43 | uploaded_file = st.file_uploader("Upload your video here...", type=["mp4", "mov", "avi", "mkv"]) 44 | 45 | if st.button("Colorize"): 46 | if uploaded_file is not None: 47 | file_extension = os.path.splitext(uploaded_file.name)[1].lower() 48 | if file_extension in [".mp4", ".avi", ".mov", ".mkv"]: 49 | # Save the video file to a temporary location 50 | temp_file = tempfile.NamedTemporaryFile(delete=False) 51 | temp_file.write(uploaded_file.read()) 52 | temp_file.flush()  # make sure all bytes are on disk before moviepy/cv2 read the file back 53 | audio = mp.AudioFileClip(temp_file.name) 54 | 55 | # Open the video using cv2.VideoCapture 56 | video = cv2.VideoCapture(temp_file.name) 57 | 58 | # Get video information 59 | fps = video.get(cv2.CAP_PROP_FPS) 60 | 61 | col1, col2 = st.columns([0.5, 0.5]) 62 | with col1: 63 | st.markdown('
<h3 style="text-align: center">Before</h3>', unsafe_allow_html=True) 64 | st.video(temp_file.name) 65 | 66 | with col2: 67 | st.markdown('<h3 style="text-align: center">After</h3>
', unsafe_allow_html=True) 68 | 69 | with st.spinner("Colorizing frames..."): 70 | # Colorize video frames and store in a list 71 | output_frames = [] 72 | total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 73 | progress_bar = st.progress(0) # Create a progress bar 74 | 75 | start_time = time.time() 76 | time_text = st.text("Time Remaining: ") # Initialize text value 77 | 78 | for _ in tqdm(range(total_frames), unit="frame", desc="Progress"): 79 | ret, frame = video.read() 80 | if not ret: 81 | break 82 | 83 | colorized_frame = colorize_frame(frame, loaded_model) 84 | output_frames.append((colorized_frame * 255).astype(np.uint8)) 85 | 86 | elapsed_time = time.time() - start_time 87 | frames_completed = len(output_frames) 88 | frames_remaining = total_frames - frames_completed 89 | time_remaining = (frames_remaining / frames_completed) * elapsed_time 90 | 91 | progress_bar.progress(frames_completed / total_frames) # Update progress bar 92 | 93 | if frames_completed < total_frames: 94 | time_text.text(f"Time Remaining: {format_time(time_remaining)}") # Update text value 95 | else: 96 | time_text.empty() # Remove text value 97 | progress_bar.empty() 98 | 99 | with st.spinner("Merging frames to video..."): 100 | frame_size = output_frames[0].shape[:2] 101 | output_filename = "output.mp4" 102 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Codec for MP4 video 103 | out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_size[1], frame_size[0])) 104 | 105 | # Write the colorized frames (RGB -> BGR) to the output video 106 | for frame in output_frames: 107 | frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 108 | 109 | out.write(frame_bgr) 110 | 111 | out.release() 112 | 113 | # Convert the output video to a format compatible with Streamlit 114 | converted_filename = "converted_output.mp4" 115 | clip = mp.VideoFileClip(output_filename) 116 | clip = clip.set_audio(audio) 117 | 118 | clip.write_videofile(converted_filename, codec="libx264") 119 | 120 | # Display the converted video using st.video() 121 | st.video(converted_filename) 122 | st.balloons() 123 | 124 | # Add a download button for the colorized video 125 | st.download_button( 126 | label="Download Colorized Video", 127 | data=open(converted_filename, "rb").read(), 128 | file_name="colorized_video.mp4", 129 | ) 130 | 131 | # Release the video and close the temporary file after processing 132 | video.release() 133 | temp_file.close() 134 | 135 | 136 | if __name__ == "__main__": 137 | main() 138 | st.markdown( 139 | "###### Made with :heart: by [Clément Delteil](https://www.linkedin.com/in/clementdelteil/) [![this is an " 140 | "image link](https://i.imgur.com/thJhzOO.png)](https://www.buymeacoffee.com/clementdelteil)" 141 | ) 142 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image Video Colorization 2 | This project was conceived as part of the "Sujets Spéciaux" course at UQAC under the supervision of [Kevin Bouchard](http://www.kevin-bouchard.ca/). 3 | 4 | ## Objective 5 | The aim was to deploy an image and video colorization application. To this end, I built on the GitHub project of [Richard Zhang and his co-authors](https://github.com/richzhang/colorization), whose models and papers were presented at ECCV16 and SIGGRAPH17. Their solutions are based on convolutional neural networks.
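For reference, here is a minimal sketch of how these colorizers are driven (these are the same calls `utils.py` makes; the input path below is a placeholder):

```python
# Minimal usage sketch (not shipped with the app); "old_photo.jpg" is a placeholder path.
import numpy as np
from PIL import Image

from models.deep_colorization import eccv16, load_img, preprocess_img, postprocess_tens

colorizer = eccv16(pretrained=True).eval()  # ECCV16 weights are downloaded on first use

img = load_img("old_photo.jpg")  # H x W x 3 RGB array (grayscale inputs are tiled to 3 channels)
tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256))  # L channel at original and network size
out_rgb = postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu())  # predicted ab + original L, as RGB in [0, 1]

Image.fromarray((out_rgb * 255).astype(np.uint8)).save("old_photo_colorized.jpg")
```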
6 | 7 | The interface for the models was designed on Streamlit and deployed in a [Space on Hugging Face](https://huggingface.co/spaces/Wazzzabeee/image-video-colorization). 8 | 9 | The tutorial with all the design steps is available on my [Medium blog](https://medium.com/geekculture/creating-a-web-app-to-colorize-images-and-youtube-videos-80f5be2d0f68). 10 | 11 | The following features are available: 12 | - Colorization of a batch of images. 13 | - Colorization of MP4, MKV and AVI video files. 14 | - Colorization of YouTube videos. 15 | 16 | ## Interface 17 | ![interface](interface.gif) 18 | 19 | ## Running Locally 20 | If you want to process longer videos and are limited by the Hugging Face Space's memory limits, you can run this app locally. 21 | 22 | `ffmpeg` is needed to run this app. You can install it with `brew install ffmpeg` on macOS (or your platform's package manager) and update the `IMAGEIO_FFMPEG_EXE` environment variable accordingly. 23 | 24 | ```bash 25 | git clone https://github.com/Wazzabeee/image-video-colorization 26 | cd image-video-colorization 27 | pip install -r requirements.txt 28 | streamlit run 01_📼_Upload_Video_File.py 29 | ``` 30 | 31 | ## Todos 32 | Other models based on GANs will probably be implemented in the future if my application for a community grant to gain access to a GPU on Hugging Face is successful. 33 | 34 | ## References 35 | 1. Richard Zhang, Phillip Isola, and Alexei A. Efros. "Colorful Image Colorization". In: ECCV. 2016. 36 | 2. Richard Zhang et al. "Real-Time User-Guided Image Colorization with Learned Deep Priors". In: ACM 37 | Transactions on Graphics (TOG) 9.4 (2017). 38 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wazzabeee/image-video-colorization/069f937320c1cbcd04579724048228600596788a/__init__.py -------------------------------------------------------------------------------- /interface.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wazzabeee/image-video-colorization/069f937320c1cbcd04579724048228600596788a/interface.gif -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wazzabeee/image-video-colorization/069f937320c1cbcd04579724048228600596788a/models/__init__.py -------------------------------------------------------------------------------- /models/deep_colorization/__init__.py: -------------------------------------------------------------------------------- 1 | # inside models/deep_colorization/__init__.py 2 | from .colorizers import eccv16, siggraph17, load_img, preprocess_img, postprocess_tens 3 | -------------------------------------------------------------------------------- /models/deep_colorization/colorizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_color import BaseColor 2 | from .eccv16 import ECCVGenerator, eccv16 3 | from .siggraph17 import SIGGRAPHGenerator, siggraph17 4 | from .util import load_img, resize_img, preprocess_img, postprocess_tens 5 | -------------------------------------------------------------------------------- /models/deep_colorization/colorizers/base_color.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class
BaseColor(nn.Module): 5 | def __init__(self): 6 | super(BaseColor, self).__init__() 7 | 8 | self.l_cent = 50.0 9 | self.l_norm = 100.0 10 | self.ab_norm = 110.0 11 | 12 | def normalize_l(self, in_l): 13 | return (in_l - self.l_cent) / self.l_norm 14 | 15 | def unnormalize_l(self, in_l): 16 | return in_l * self.l_norm + self.l_cent 17 | 18 | def normalize_ab(self, in_ab): 19 | return in_ab / self.ab_norm 20 | 21 | def unnormalize_ab(self, in_ab): 22 | return in_ab * self.ab_norm 23 | -------------------------------------------------------------------------------- /models/deep_colorization/colorizers/eccv16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from IPython import embed 5 | 6 | from .base_color import * 7 | 8 | 9 | class ECCVGenerator(BaseColor): 10 | def __init__(self, norm_layer=nn.BatchNorm2d): 11 | super(ECCVGenerator, self).__init__() 12 | 13 | model1 = [ 14 | nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True), 15 | ] 16 | model1 += [ 17 | nn.ReLU(True), 18 | ] 19 | model1 += [ 20 | nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True), 21 | ] 22 | model1 += [ 23 | nn.ReLU(True), 24 | ] 25 | model1 += [ 26 | norm_layer(64), 27 | ] 28 | 29 | model2 = [ 30 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True), 31 | ] 32 | model2 += [ 33 | nn.ReLU(True), 34 | ] 35 | model2 += [ 36 | nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True), 37 | ] 38 | model2 += [ 39 | nn.ReLU(True), 40 | ] 41 | model2 += [ 42 | norm_layer(128), 43 | ] 44 | 45 | model3 = [ 46 | nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True), 47 | ] 48 | model3 += [ 49 | nn.ReLU(True), 50 | ] 51 | model3 += [ 52 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 53 | ] 54 | model3 += [ 55 | nn.ReLU(True), 56 | ] 57 | model3 += [ 58 | nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True), 59 | ] 60 | model3 += [ 61 | nn.ReLU(True), 62 | ] 63 | model3 += [ 64 | norm_layer(256), 65 | ] 66 | 67 | model4 = [ 68 | nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True), 69 | ] 70 | model4 += [ 71 | nn.ReLU(True), 72 | ] 73 | model4 += [ 74 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 75 | ] 76 | model4 += [ 77 | nn.ReLU(True), 78 | ] 79 | model4 += [ 80 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 81 | ] 82 | model4 += [ 83 | nn.ReLU(True), 84 | ] 85 | model4 += [ 86 | norm_layer(512), 87 | ] 88 | 89 | model5 = [ 90 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 91 | ] 92 | model5 += [ 93 | nn.ReLU(True), 94 | ] 95 | model5 += [ 96 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 97 | ] 98 | model5 += [ 99 | nn.ReLU(True), 100 | ] 101 | model5 += [ 102 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 103 | ] 104 | model5 += [ 105 | nn.ReLU(True), 106 | ] 107 | model5 += [ 108 | norm_layer(512), 109 | ] 110 | 111 | model6 = [ 112 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 113 | ] 114 | model6 += [ 115 | nn.ReLU(True), 116 | ] 117 | model6 += [ 118 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 119 | ] 120 | model6 += [ 121 | nn.ReLU(True), 122 | ] 123 | model6 += [ 124 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 125 | ] 126 | model6 += [ 127 | 
nn.ReLU(True), 128 | ] 129 | model6 += [ 130 | norm_layer(512), 131 | ] 132 | 133 | model7 = [ 134 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 135 | ] 136 | model7 += [ 137 | nn.ReLU(True), 138 | ] 139 | model7 += [ 140 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 141 | ] 142 | model7 += [ 143 | nn.ReLU(True), 144 | ] 145 | model7 += [ 146 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 147 | ] 148 | model7 += [ 149 | nn.ReLU(True), 150 | ] 151 | model7 += [ 152 | norm_layer(512), 153 | ] 154 | 155 | model8 = [ 156 | nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True), 157 | ] 158 | model8 += [ 159 | nn.ReLU(True), 160 | ] 161 | model8 += [ 162 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 163 | ] 164 | model8 += [ 165 | nn.ReLU(True), 166 | ] 167 | model8 += [ 168 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 169 | ] 170 | model8 += [ 171 | nn.ReLU(True), 172 | ] 173 | 174 | model8 += [ 175 | nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True), 176 | ] 177 | 178 | self.model1 = nn.Sequential(*model1) 179 | self.model2 = nn.Sequential(*model2) 180 | self.model3 = nn.Sequential(*model3) 181 | self.model4 = nn.Sequential(*model4) 182 | self.model5 = nn.Sequential(*model5) 183 | self.model6 = nn.Sequential(*model6) 184 | self.model7 = nn.Sequential(*model7) 185 | self.model8 = nn.Sequential(*model8) 186 | 187 | self.softmax = nn.Softmax(dim=1) 188 | self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False) 189 | self.upsample4 = nn.Upsample(scale_factor=4, mode="bilinear") 190 | 191 | def forward(self, input_l): 192 | conv1_2 = self.model1(self.normalize_l(input_l)) 193 | conv2_2 = self.model2(conv1_2) 194 | conv3_3 = self.model3(conv2_2) 195 | conv4_3 = self.model4(conv3_3) 196 | conv5_3 = self.model5(conv4_3) 197 | conv6_3 = self.model6(conv5_3) 198 | conv7_3 = self.model7(conv6_3) 199 | conv8_3 = self.model8(conv7_3) 200 | out_reg = self.model_out(self.softmax(conv8_3)) 201 | 202 | return self.unnormalize_ab(self.upsample4(out_reg)) 203 | 204 | 205 | def eccv16(pretrained=True): 206 | model = ECCVGenerator() 207 | if pretrained: 208 | import torch.utils.model_zoo as model_zoo 209 | 210 | model.load_state_dict( 211 | model_zoo.load_url( 212 | "https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth", 213 | map_location="cpu", 214 | check_hash=True, 215 | ) 216 | ) 217 | return model 218 | -------------------------------------------------------------------------------- /models/deep_colorization/colorizers/siggraph17.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base_color import * 5 | 6 | 7 | class SIGGRAPHGenerator(BaseColor): 8 | def __init__(self, norm_layer=nn.BatchNorm2d, classes=529): 9 | super(SIGGRAPHGenerator, self).__init__() 10 | 11 | # Conv1 12 | model1 = [ 13 | nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True), 14 | ] 15 | model1 += [ 16 | nn.ReLU(True), 17 | ] 18 | model1 += [ 19 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True), 20 | ] 21 | model1 += [ 22 | nn.ReLU(True), 23 | ] 24 | model1 += [ 25 | norm_layer(64), 26 | ] 27 | # add a subsampling operation 28 | 29 | # Conv2 30 | model2 = [ 31 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True), 32 | ] 33 | model2 += [ 34 | nn.ReLU(True), 35 | ] 36 | model2 += [ 37 | 
nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True), 38 | ] 39 | model2 += [ 40 | nn.ReLU(True), 41 | ] 42 | model2 += [ 43 | norm_layer(128), 44 | ] 45 | # add a subsampling layer operation 46 | 47 | # Conv3 48 | model3 = [ 49 | nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True), 50 | ] 51 | model3 += [ 52 | nn.ReLU(True), 53 | ] 54 | model3 += [ 55 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 56 | ] 57 | model3 += [ 58 | nn.ReLU(True), 59 | ] 60 | model3 += [ 61 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 62 | ] 63 | model3 += [ 64 | nn.ReLU(True), 65 | ] 66 | model3 += [ 67 | norm_layer(256), 68 | ] 69 | # add a subsampling layer operation 70 | 71 | # Conv4 72 | model4 = [ 73 | nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True), 74 | ] 75 | model4 += [ 76 | nn.ReLU(True), 77 | ] 78 | model4 += [ 79 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 80 | ] 81 | model4 += [ 82 | nn.ReLU(True), 83 | ] 84 | model4 += [ 85 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 86 | ] 87 | model4 += [ 88 | nn.ReLU(True), 89 | ] 90 | model4 += [ 91 | norm_layer(512), 92 | ] 93 | 94 | # Conv5 95 | model5 = [ 96 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 97 | ] 98 | model5 += [ 99 | nn.ReLU(True), 100 | ] 101 | model5 += [ 102 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 103 | ] 104 | model5 += [ 105 | nn.ReLU(True), 106 | ] 107 | model5 += [ 108 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 109 | ] 110 | model5 += [ 111 | nn.ReLU(True), 112 | ] 113 | model5 += [ 114 | norm_layer(512), 115 | ] 116 | 117 | # Conv6 118 | model6 = [ 119 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 120 | ] 121 | model6 += [ 122 | nn.ReLU(True), 123 | ] 124 | model6 += [ 125 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 126 | ] 127 | model6 += [ 128 | nn.ReLU(True), 129 | ] 130 | model6 += [ 131 | nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True), 132 | ] 133 | model6 += [ 134 | nn.ReLU(True), 135 | ] 136 | model6 += [ 137 | norm_layer(512), 138 | ] 139 | 140 | # Conv7 141 | model7 = [ 142 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 143 | ] 144 | model7 += [ 145 | nn.ReLU(True), 146 | ] 147 | model7 += [ 148 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 149 | ] 150 | model7 += [ 151 | nn.ReLU(True), 152 | ] 153 | model7 += [ 154 | nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True), 155 | ] 156 | model7 += [ 157 | nn.ReLU(True), 158 | ] 159 | model7 += [ 160 | norm_layer(512), 161 | ] 162 | 163 | # Conv7 164 | model8up = [nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)] 165 | model3short8 = [ 166 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 167 | ] 168 | 169 | model8 = [ 170 | nn.ReLU(True), 171 | ] 172 | model8 += [ 173 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 174 | ] 175 | model8 += [ 176 | nn.ReLU(True), 177 | ] 178 | model8 += [ 179 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True), 180 | ] 181 | model8 += [ 182 | nn.ReLU(True), 183 | ] 184 | model8 += [ 185 | norm_layer(256), 186 | ] 187 | 188 | # Conv9 189 | model9up = [ 190 | nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True), 191 | 
] 192 | model2short9 = [ 193 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True), 194 | ] 195 | # add the two feature maps above 196 | 197 | model9 = [ 198 | nn.ReLU(True), 199 | ] 200 | model9 += [ 201 | nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True), 202 | ] 203 | model9 += [ 204 | nn.ReLU(True), 205 | ] 206 | model9 += [ 207 | norm_layer(128), 208 | ] 209 | 210 | # Conv10 211 | model10up = [ 212 | nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True), 213 | ] 214 | model1short10 = [ 215 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True), 216 | ] 217 | # add the two feature maps above 218 | 219 | model10 = [ 220 | nn.ReLU(True), 221 | ] 222 | model10 += [ 223 | nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True), 224 | ] 225 | model10 += [ 226 | nn.LeakyReLU(negative_slope=0.2), 227 | ] 228 | 229 | # classification output 230 | model_class = [ 231 | nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True), 232 | ] 233 | 234 | # regression output 235 | model_out = [ 236 | nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True), 237 | ] 238 | model_out += [nn.Tanh()] 239 | 240 | self.model1 = nn.Sequential(*model1) 241 | self.model2 = nn.Sequential(*model2) 242 | self.model3 = nn.Sequential(*model3) 243 | self.model4 = nn.Sequential(*model4) 244 | self.model5 = nn.Sequential(*model5) 245 | self.model6 = nn.Sequential(*model6) 246 | self.model7 = nn.Sequential(*model7) 247 | self.model8up = nn.Sequential(*model8up) 248 | self.model8 = nn.Sequential(*model8) 249 | self.model9up = nn.Sequential(*model9up) 250 | self.model9 = nn.Sequential(*model9) 251 | self.model10up = nn.Sequential(*model10up) 252 | self.model10 = nn.Sequential(*model10) 253 | self.model3short8 = nn.Sequential(*model3short8) 254 | self.model2short9 = nn.Sequential(*model2short9) 255 | self.model1short10 = nn.Sequential(*model1short10) 256 | 257 | self.model_class = nn.Sequential(*model_class) 258 | self.model_out = nn.Sequential(*model_out) 259 | 260 | self.upsample4 = nn.Sequential( 261 | *[ 262 | nn.Upsample(scale_factor=4, mode="bilinear"), 263 | ] 264 | ) 265 | self.softmax = nn.Sequential( 266 | *[ 267 | nn.Softmax(dim=1), 268 | ] 269 | ) 270 | 271 | def forward(self, input_a, input_b=None, mask_b=None): 272 | if input_b is None: 273 | input_b = torch.cat((input_a * 0, input_a * 0), dim=1) 274 | if mask_b is None: 275 | mask_b = input_a * 0 276 | 277 | conv1_2 = self.model1(torch.cat((self.normalize_l(input_a), self.normalize_ab(input_b), mask_b), dim=1)) 278 | conv2_2 = self.model2(conv1_2[:, :, ::2, ::2]) 279 | conv3_3 = self.model3(conv2_2[:, :, ::2, ::2]) 280 | conv4_3 = self.model4(conv3_3[:, :, ::2, ::2]) 281 | conv5_3 = self.model5(conv4_3) 282 | conv6_3 = self.model6(conv5_3) 283 | conv7_3 = self.model7(conv6_3) 284 | 285 | conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3) 286 | conv8_3 = self.model8(conv8_up) 287 | conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2) 288 | conv9_3 = self.model9(conv9_up) 289 | conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2) 290 | conv10_2 = self.model10(conv10_up) 291 | out_reg = self.model_out(conv10_2) 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | return
self.unnormalize_ab(out_reg) 300 | 301 | 302 | def siggraph17(pretrained=True): 303 | model = SIGGRAPHGenerator() 304 | if pretrained: 305 | import torch.utils.model_zoo as model_zoo 306 | 307 | model.load_state_dict( 308 | model_zoo.load_url( 309 | "https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth", 310 | map_location="cpu", 311 | check_hash=True, 312 | ) 313 | ) 314 | return model 315 | -------------------------------------------------------------------------------- /models/deep_colorization/colorizers/util.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | from skimage import color 4 | import torch 5 | import torch.nn.functional as F 6 | from IPython import embed 7 | 8 | 9 | def load_img(img_path): 10 | out_np = np.asarray(Image.open(img_path)) 11 | if out_np.ndim == 2: 12 | out_np = np.tile(out_np[:, :, None], 3) 13 | return out_np 14 | 15 | 16 | def resize_img(img, HW=(256, 256), resample=3): 17 | return np.asarray(Image.fromarray(img).resize((HW[1], HW[0]), resample=resample)) 18 | 19 | 20 | def preprocess_img(img_rgb_orig, HW=(256, 256), resample=3): 21 | # return original size L and resized L as torch Tensors 22 | img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample) 23 | 24 | img_lab_orig = color.rgb2lab(img_rgb_orig) 25 | img_lab_rs = color.rgb2lab(img_rgb_rs) 26 | 27 | img_l_orig = img_lab_orig[:, :, 0] 28 | img_l_rs = img_lab_rs[:, :, 0] 29 | 30 | tens_orig_l = torch.Tensor(img_l_orig)[None, None, :, :] 31 | tens_rs_l = torch.Tensor(img_l_rs)[None, None, :, :] 32 | 33 | return tens_orig_l, tens_rs_l 34 | 35 | 36 | def postprocess_tens(tens_orig_l, out_ab, mode="bilinear"): 37 | # tens_orig_l 1 x 1 x H_orig x W_orig 38 | # out_ab 1 x 2 x H x W 39 | 40 | HW_orig = tens_orig_l.shape[2:] 41 | HW = out_ab.shape[2:] 42 | 43 | # call resize function if needed, honoring the caller-supplied interpolation mode 44 | if HW_orig[0] != HW[0] or HW_orig[1] != HW[1]: 45 | out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode=mode) 46 | else: 47 | out_ab_orig = out_ab 48 | 49 | out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1) 50 | return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0, ...].transpose((1, 2, 0))) 51 | -------------------------------------------------------------------------------- /pages/02_🎥_Input_Youtube_Link.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import cv2 4 | import moviepy.editor as mp 5 | import numpy as np 6 | import streamlit as st 7 | from pytube import YouTube 8 | from tqdm import tqdm 9 | 10 | 11 | from utils import format_time, colorize_frame, change_model, load_model, setup_columns, set_page_config 12 | 13 | set_page_config() 14 | loaded_model = load_model() 15 | col2 = setup_columns() 16 | current_model = None 17 | 18 | with col2: 19 | st.write( 20 | """ 21 | ## B&W Videos Colorizer 22 | ##### Input a YouTube black and white video link and get a colorized version of it. 23 | ###### ➠ This space is using CPU Basic so it might take a while to colorize a video.
24 | ###### ➠ If you want more models and a GPU available, please support this space by donating.""" 25 | ) 26 | 27 | 28 | @st.cache_data() 29 | def download_video(link): 30 | """ 31 | Download video from YouTube 32 | """ 33 | yt = YouTube(link) 34 | video = ( 35 | yt.streams.filter(progressive=True, file_extension="mp4") 36 | .order_by("resolution") 37 | .desc() 38 | .first() 39 | .download(filename="video.mp4") 40 | ) 41 | return video 42 | 43 | 44 | def main(): 45 | """ 46 | Main function 47 | """ 48 | model = st.selectbox( 49 | "Select Model (Both models have their pros and cons, " 50 | "I recommend trying both and keeping the best for your task)", 51 | ["ECCV16", "SIGGRAPH17"], 52 | index=0, 53 | ) 54 | 55 | loaded_model = change_model(current_model, model) 56 | st.write(f"Model is now {model}") 57 | 58 | link = st.text_input("YouTube Link (The longer the video, the longer the processing time)") 59 | if st.button("Colorize"): 60 | yt_video = download_video(link) 61 | print(yt_video) 62 | col1, col2 = st.columns([0.5, 0.5]) 63 | with col1: 64 | st.markdown('
<h3 style="text-align: center">Before</h3>', unsafe_allow_html=True) 65 | st.video(yt_video) 66 | with col2: 67 | st.markdown('<h3 style="text-align: center">After</h3>
', unsafe_allow_html=True) 68 | with st.spinner("Colorizing frames..."): 69 | # Colorize video frames and store in a list 70 | output_frames = [] 71 | 72 | audio = mp.AudioFileClip("video.mp4") 73 | video = cv2.VideoCapture("video.mp4") 74 | 75 | total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 76 | fps = video.get(cv2.CAP_PROP_FPS) 77 | 78 | progress_bar = st.progress(0) # Create a progress bar 79 | start_time = time.time() 80 | time_text = st.text("Time Remaining: ") # Initialize text value 81 | 82 | for _ in tqdm(range(total_frames), unit="frame", desc="Progress"): 83 | ret, frame = video.read() 84 | if not ret: 85 | break 86 | 87 | colorized_frame = colorize_frame(frame, loaded_model) 88 | output_frames.append((colorized_frame * 255).astype(np.uint8)) 89 | 90 | elapsed_time = time.time() - start_time 91 | frames_completed = len(output_frames) 92 | frames_remaining = total_frames - frames_completed 93 | time_remaining = (frames_remaining / frames_completed) * elapsed_time 94 | 95 | progress_bar.progress(frames_completed / total_frames) # Update progress bar 96 | 97 | if frames_completed < total_frames: 98 | time_text.text(f"Time Remaining: {format_time(time_remaining)}") # Update text value 99 | else: 100 | time_text.empty() # Remove text value 101 | progress_bar.empty() 102 | 103 | with st.spinner("Merging frames to video..."): 104 | frame_size = output_frames[0].shape[:2] 105 | output_filename = "output.mp4" 106 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") # Codec for MP4 video 107 | out = cv2.VideoWriter(output_filename, fourcc, fps, (frame_size[1], frame_size[0])) 108 | 109 | # Write the colorized frames (RGB -> BGR) to the output video 110 | for frame in output_frames: 111 | frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR) 112 | 113 | out.write(frame_bgr) 114 | 115 | out.release() 116 | 117 | # Convert the output video to a format compatible with Streamlit 118 | converted_filename = "converted_output.mp4" 119 | clip = mp.VideoFileClip(output_filename) 120 | clip = clip.set_audio(audio) 121 | 122 | clip.write_videofile(converted_filename, codec="libx264") 123 | 124 | # Display the converted video using st.video() 125 | st.video(converted_filename) 126 | st.balloons() 127 | 128 | # Add a download button for the colorized video 129 | st.download_button( 130 | label="Download Colorized Video", 131 | data=open(converted_filename, "rb").read(), 132 | file_name="colorized_video.mp4", 133 | ) 134 | 135 | # Release the video capture after processing 136 | video.release() 137 | 138 | 139 | if __name__ == "__main__": 140 | main() 141 | st.markdown( 142 | "###### Made with :heart: by [Clément Delteil](https://www.linkedin.com/in/clementdelteil/) [![this is an " 143 | "image link](https://i.imgur.com/thJhzOO.png)](https://www.buymeacoffee.com/clementdelteil)" 144 | ) 145 | -------------------------------------------------------------------------------- /pages/03_🖼️_Input_Images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | 4 | import streamlit as st 5 | from PIL import Image 6 | 7 | from utils import change_model, load_model, setup_columns, set_page_config, colorize_image 8 | 9 | set_page_config() 10 | loaded_model = load_model() 11 | col2 = setup_columns() 12 | current_model = None 13 | 14 | with col2: 15 | st.write( 16 | """ 17 | ## B&W Images Colorizer 18 | ##### Input a black and white image and get a colorized version of it. 19 | ###### ➠ If you want to colorize multiple images, just upload them all at once.
20 | ###### ➠ Uploading already colored images won't raise errors, but the results won't look good. 21 | ###### ➠ I recommend starting with the first model and then experimenting with the second one.""" 22 | ) 23 | 24 | 25 | def main(): 26 | """ 27 | Main function 28 | """ 29 | model = st.selectbox( 30 | "Select Model (Both models have their pros and cons, " 31 | "I recommend trying both and keeping the best for your task)", 32 | ["ECCV16", "SIGGRAPH17"], 33 | index=0, 34 | ) 35 | 36 | # Make the user select a model 37 | loaded_model = change_model(current_model, model) 38 | st.write(f"Model is now {model}") 39 | 40 | # Ask the user if they want to see the colorization in real time 41 | display_results = st.checkbox("Display results in real time", value=True) 42 | 43 | # Input for the user to upload images 44 | uploaded_file = st.file_uploader( 45 | "Upload your images here...", type=["jpg", "png", "jpeg"], accept_multiple_files=True 46 | ) 47 | 48 | # If the user clicks on the button 49 | if st.button("Colorize"): 50 | # If the user uploaded images 51 | if uploaded_file is not None: 52 | if display_results: 53 | col1, col2 = st.columns([0.5, 0.5]) 54 | with col1: 55 | st.markdown('
<h3 style="text-align: center">Before</h3>', unsafe_allow_html=True) 56 | with col2: 57 | st.markdown('<h3 style="text-align: center">After</h3>
', unsafe_allow_html=True) 58 | else: 59 | col1, col2, _ = st.columns(3) 60 | 61 | for i, file in enumerate(uploaded_file): 62 | file_extension = os.path.splitext(file.name)[1].lower() 63 | if file_extension in [".jpg", ".png", ".jpeg"]: 64 | image = Image.open(file) 65 | if display_results: 66 | with col1: 67 | st.image(image, use_column_width="always") 68 | with col2: 69 | with st.spinner("Colorizing image..."): 70 | out_img, new_img = colorize_image(file, loaded_model) 71 | new_img.save("IMG_" + str(i + 1) + ".jpg") 72 | st.image(out_img, use_column_width="always") 73 | 74 | else: 75 | out_img, new_img = colorize_image(file, loaded_model) 76 | new_img.save("IMG_" + str(i + 1) + ".jpg") 77 | 78 | if len(uploaded_file) > 1: 79 | # Create a zip file 80 | zip_filename = "colorized_images.zip" 81 | with zipfile.ZipFile(zip_filename, "w") as zip_file: 82 | # Add colorized images to the zip file, keeping the same names as on disk 83 | for i in range(len(uploaded_file)): 84 | zip_file.write("IMG_" + str(i + 1) + ".jpg", "IMG_" + str(i + 1) + ".jpg") 85 | with col2: 86 | # Provide the zip file data for download 87 | st.download_button( 88 | label="Download Colorized Images", 89 | data=open(zip_filename, "rb").read(), 90 | file_name=zip_filename, 91 | ) 92 | else: 93 | with col2: 94 | st.download_button( 95 | label="Download Colorized Image", 96 | data=open("IMG_1.jpg", "rb").read(), 97 | file_name="IMG_1.jpg", 98 | ) 99 | 100 | else: 101 | st.warning("Upload a file", icon="⚠️") 102 | 103 | 104 | if __name__ == "__main__": 105 | main() 106 | st.markdown( 107 | "###### Made with :heart: by [Clément Delteil](https://www.linkedin.com/in/clementdelteil/) [![this is an " 108 | "image link](https://i.imgur.com/thJhzOO.png)](https://www.buymeacoffee.com/clementdelteil)" 109 | ) 110 | -------------------------------------------------------------------------------- /pages/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Wazzabeee/image-video-colorization/069f937320c1cbcd04579724048228600596788a/pages/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ipython==8.10.0 2 | moviepy==1.0.3 3 | numpy==1.23.2 4 | opencv_python==4.8.1.78 5 | Pillow==10.3.0 6 | scikit-image==0.20.0 7 | streamlit==1.30.0 8 | torch==1.13.1 9 | streamlit_lottie==0.0.5 10 | requests==2.31.0 11 | tqdm==4.64.1 12 | pytube==15.0.0 13 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pylint] 2 | disable=C0103,C0114,R0913,R0914,C0200,C0301,E1101 3 | 4 | [pep8] 5 | max-line-length=120 6 | ignore=E121,E123,E126,E226,E24,E704,E203,W503 7 | exclude=venv,test_env,test_venv 8 | 9 | [flake8] 10 | max-line-length=120 11 | ignore=E121,E123,E126,E226,E24,E704,E203,W503 12 | exclude=venv,test_env,test_venv -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import requests 3 | import streamlit as st 4 | from PIL import Image 5 | from streamlit_lottie import st_lottie 6 | 7 | from models.deep_colorization import eccv16 8 | from models.deep_colorization import siggraph17 9 | from models.deep_colorization import postprocess_tens,
preprocess_img, load_img 10 | 11 | 12 | class SameModelException(ValueError): 13 | """Exception raised when the same model is attempted to be reloaded.""" 14 | 15 | 16 | def set_page_config(): 17 | """ 18 | Sets up the page config. 19 | """ 20 | st.set_page_config(page_title="Image & Video Colorizer", page_icon="🎨", layout="wide") 21 | 22 | 23 | def load_model(): 24 | """ 25 | Loads the default model. 26 | """ 27 | return eccv16(pretrained=True).eval() 28 | 29 | 30 | def setup_columns(): 31 | """ 32 | Sets up the columns. 33 | """ 34 | col1, col2 = st.columns([1, 3]) 35 | lottie = load_lottieurl("https://assets5.lottiefiles.com/packages/lf20_RHdEuzVfEL.json") 36 | with col1: 37 | st_lottie(lottie) 38 | return col2 39 | 40 | 41 | # Define a function that we can use to load lottie files from a link. 42 | @st.cache_data() 43 | def load_lottieurl(url: str): 44 | """ 45 | Load a Lottie animation from a URL 46 | """ 47 | try: 48 | r = requests.get(url, timeout=10) # Timeout set to 10 seconds 49 | r.raise_for_status() # This will raise an exception for HTTP errors 50 | return r.json() 51 | except requests.RequestException as e: 52 | print(f"Request failed: {e}") 53 | return None 54 | 55 | 56 | @st.cache_resource() 57 | def change_model(current_model, model): 58 | """ 59 | Change model 60 | """ 61 | loaded_model = "None" 62 | 63 | if current_model != model: 64 | if model == "ECCV16": 65 | loaded_model = eccv16(pretrained=True).eval() 66 | elif model == "SIGGRAPH17": 67 | loaded_model = siggraph17(pretrained=True).eval() 68 | return loaded_model 69 | 70 | raise SameModelException("Model is the same as the current one.") 71 | 72 | 73 | def format_time(seconds: float) -> str: 74 | """Formats time in seconds to a human-readable format""" 75 | if seconds < 60: 76 | return f"{int(seconds)} seconds" 77 | if seconds < 3600: 78 | minutes = int(seconds // 60) 79 | seconds %= 60 80 | return f"{minutes} minutes and {int(seconds)} seconds" 81 | if seconds < 86400: 82 | hours = int(seconds // 3600) 83 | minutes = int((seconds % 3600) // 60) 84 | seconds %= 60 85 | return f"{hours} hours, {minutes} minutes, and {int(seconds)} seconds" 86 | 87 | days = int(seconds // 86400) 88 | hours = int((seconds % 86400) // 3600) 89 | minutes = int((seconds % 3600) // 60) 90 | seconds %= 60 91 | return f"{days} days, {hours} hours, {minutes} minutes, and {int(seconds)} seconds" 92 | 93 | 94 | # Function to colorize video frames 95 | def colorize_frame(frame, colorizer) -> np.ndarray: 96 | """ 97 | Colorize frame 98 | """ 99 | tens_l_orig, tens_l_rs = preprocess_img(frame, HW=(256, 256)) 100 | return postprocess_tens(tens_l_orig, colorizer(tens_l_rs).cpu()) 101 | 102 | 103 | def colorize_image(file, loaded_model): 104 | """ 105 | Colorize image 106 | """ 107 | img = load_img(file) 108 | # If the user inputs a colored image with 4 channels (RGBA), discard the alpha channel 109 | if img.shape[2] == 4: 110 | img = img[:, :, :3] 111 | 112 | tens_l_orig, tens_l_rs = preprocess_img(img, HW=(256, 256)) 113 | out_img = postprocess_tens(tens_l_orig, loaded_model(tens_l_rs).cpu()) 114 | new_img = Image.fromarray((out_img * 255).astype(np.uint8)) 115 | 116 | return out_img, new_img 117 | --------------------------------------------------------------------------------
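For reference, a quick sanity check of `format_time` (illustrative only, not part of the repo; the expected outputs assume the `int` casts above, which keep float inputs from rendering as "2.0 minutes"):

```python
# Hypothetical usage of utils.format_time; one call per branch of the function.
from utils import format_time

print(format_time(45))     # "45 seconds"
print(format_time(125.0))  # "2 minutes and 5 seconds"
print(format_time(3725))   # "1 hours, 2 minutes, and 5 seconds" (the function does not singularize units)
print(format_time(90061))  # "1 days, 1 hours, 1 minutes, and 1 seconds"
```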