├── .gitignore ├── IO.py ├── LICENSE ├── README.md ├── bubble_artifacts.jpg ├── custom_transforms.py ├── data └── download_data.py ├── dataset.py ├── ffmpeg_tools.py ├── frn.py ├── lib.py ├── model.py ├── requirements.txt ├── style_video.py ├── styles ├── mosaic.jpg ├── mosaic_2.jpg └── starry_night.jpg ├── test_image.jpeg ├── train.py ├── utils.py └── vgg.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ -------------------------------------------------------------------------------- /IO.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import numpy as np 4 | 5 | 6 | def readPFM(file): 7 | file = open(file, 'rb') 8 | 9 | color = None 10 | width = None 11 | height = None 12 | scale = None 13 | endian = None 14 | 15 | header = file.readline().rstrip() 16 | if header.decode("ascii") == 'PF': 17 | color = True 18 | elif header.decode("ascii") == 'Pf': 19 | color = False 20 | else: 21 | raise Exception('Not a PFM file.') 22 | 23 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) 24 | if dim_match: 25 | width, height = list(map(int, dim_match.groups())) 26 | else: 27 | raise Exception('Malformed PFM header.') 28 | 29 | scale = float(file.readline().decode("ascii").rstrip()) 30 | if scale < 0: # little-endian 31 | endian = '<' 32 | scale = -scale 33 | else: 34 | endian = '>' # big-endian 35 | 36 | data = np.fromfile(file, endian + 'f') 37 | shape = (height, width, 3) if color else (height, width) 38 | 39 | data = np.reshape(data, shape) 40 | data = np.flipud(data) 41 | return data, scale 42 | 43 | 44 | def readFlow(name): 45 | if name.endswith('.pfm') or name.endswith('.PFM'): 46 | return readPFM(name)[0][:, :, 0:2] 47 | 48 | f = open(name, 'rb') 49 | 50 | header = f.read(4) 51 | if header.decode("utf-8") != 'PIEH': 52 | raise Exception('Flow file header does not contain PIEH') 53 | 54 | width = np.fromfile(f, np.int32, 1).squeeze() 55 | height = np.fromfile(f, np.int32, 1).squeeze() 56 | 57 | flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2)) 58 | 59 | return flow.astype(np.float32) 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Nikita Gryaznov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-reconet 2 | This is a PyTorch implementation of the 3 | "[ReCoNet: Real-time Coherent Video Style Transfer Network](https://arxiv.org/abs/1807.01197)" paper. 4 | 5 | The model performs style transfer on videos in real time while preserving temporal consistency between frames. 6 | 7 | ## Training 8 | To train a model: 9 | 10 | 1. Run `python ./data/download_data.py` to download the data. 11 | This may take about a day, and you will need over 1 TB of free disk space. 12 | You will also need [aria2](https://aria2.github.io/) installed 13 | 2. Install Python dependencies via `pip install -r requirements.txt` 14 | 3. Run `python train.py style_image.jpg` to train a model with the style from `style_image.jpg`. 15 | This script supports several additional arguments that you can list with `python train.py -h` 16 | 17 | ## Inference 18 | 19 | There are two options for inference: 20 | 21 | 1. There is a programming interface in the `lib.py` file. 22 | It contains the `ReCoNetModel` class, whose `run` method 23 | accepts a batch of images as a 4-D uint8 NHWC RGB numpy tensor and stylizes it 24 | 2. There is a `style_video.py` script for styling videos. Run it as 25 | `python style_video.py input.mp4 output.mp4 model.pth`. It also supports some additional arguments. 26 | Note that `ffmpeg` must be installed on your machine to run this script 27 | 28 | A model pre-trained on `./styles/mosaic_2.jpg` can be downloaded here: 29 | https://drive.google.com/open?id=1MUPb7qf3QWEixZ6daGGI4lVFGmQl0qna 30 | 31 | Example video produced with this model: 32 | https://youtu.be/rEJrNL_2Lfs 33 | 34 | ## Bubble Artifacts 35 | 36 | Training the model as described in the paper leads to bubble artifacts 37 | 38 | ![Bubble artifacts](https://github.com/EmptySamurai/pytorch-reconet/blob/master/bubble_artifacts.jpg?raw=true) 39 | 40 | This issue was addressed in the [StyleGAN2 paper](https://arxiv.org/abs/1912.04958) by the NVIDIA team. 41 | They discovered that the artifacts appear because of Instance Normalization. 42 | They also proposed a novel normalization method, but unfortunately it doesn't work well with the ReCoNet architecture: 43 | either the style and content losses failed to converge, or blurry artifacts appeared. 44 | 45 | Instead, this implementation can use [Filter Response Normalization with Thresholded Linear Unit](https://arxiv.org/abs/1911.09737). 46 | It acts similarly to Instance Normalization but, in a sense, preserves mean values. 47 | This normalization gives the same results as the original architecture, but without the bubble artifacts. 48 | Every script and class supports an `frn` argument that enables Filter Response Normalization instead of Instance Normalization and also replaces ReLU with TLU. 49 | 50 | A model with FRN pre-trained on `./styles/mosaic_2.jpg` can be downloaded here: 51 | https://drive.google.com/open?id=1T7P5w_V5cMumeEoXs3WFituiiVGhGb3H 52 | 53 | ## Notes 54 | 55 | 1. In this implementation the loss weights differ from the ones in the paper, 56 | since the weights from the paper didn't work. This is probably 57 | due to the different image scale and loss normalization constants 58 | 2. 
Testing using MPI Sintel Dataset is not implemented 59 | -------------------------------------------------------------------------------- /bubble_artifacts.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmptySamurai/pytorch-reconet/a8e7e377f93b508f50a035983bdc8f83689ac097/bubble_artifacts.jpg -------------------------------------------------------------------------------- /custom_transforms.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from torchvision import transforms 4 | import torch 5 | import cv2 6 | import numpy as np 7 | from PIL import Image 8 | 9 | 10 | class ToTensor: 11 | def __call__(self, sample): 12 | return { 13 | "frame": self.image_to_tensor(sample["frame"]), 14 | "previous_frame": self.image_to_tensor(sample["previous_frame"]), 15 | "optical_flow": torch.from_numpy(sample["optical_flow"]), 16 | "reverse_optical_flow": torch.from_numpy(sample["reverse_optical_flow"]), 17 | "motion_boundaries": torch.from_numpy(np.array(sample["motion_boundaries"]).astype(np.bool)), 18 | "index": sample["index"] 19 | } 20 | 21 | @staticmethod 22 | def image_to_tensor(image): 23 | return transforms.ToTensor()(image) 24 | 25 | 26 | class Resize: 27 | 28 | def __init__(self, new_width, new_height): 29 | self.new_width = new_width 30 | self.new_height = new_height 31 | 32 | def resize_image(self, image): 33 | return image.resize((self.new_width, self.new_height)) 34 | 35 | def resize_optical_flow(self, optical_flow): 36 | orig_height, orig_width = optical_flow.shape[:2] 37 | optical_flow_resized = cv2.resize(optical_flow, (self.new_width, self.new_height)) 38 | h_scale, w_scale = self.new_height / orig_height, self.new_width / orig_width 39 | optical_flow_resized[..., 0] *= w_scale 40 | optical_flow_resized[..., 1] *= h_scale 41 | return optical_flow_resized 42 | 43 | def __call__(self, sample): 44 | return { 45 | "frame": self.resize_image(sample["frame"]), 46 | "previous_frame": self.resize_image(sample["previous_frame"]), 47 | "optical_flow": self.resize_optical_flow(sample["optical_flow"]), 48 | "reverse_optical_flow": self.resize_optical_flow(sample["reverse_optical_flow"]), 49 | "motion_boundaries": self.resize_image(sample["motion_boundaries"]), 50 | "index": sample["index"] 51 | } 52 | 53 | 54 | class RandomHorizontalFlip: 55 | 56 | def __init__(self, p=0.5): 57 | self.p = p 58 | 59 | @staticmethod 60 | def flip_image(image): 61 | return image.transpose(Image.FLIP_LEFT_RIGHT) 62 | 63 | @staticmethod 64 | def flip_optical_flow(optical_flow): 65 | optical_flow = np.flip(optical_flow, axis=1).copy() 66 | optical_flow[..., 0] *= -1 67 | return optical_flow 68 | 69 | def __call__(self, sample): 70 | if random.random() < self.p: 71 | return { 72 | "frame": self.flip_image(sample["frame"]), 73 | "previous_frame": self.flip_image(sample["previous_frame"]), 74 | "optical_flow": self.flip_optical_flow(sample["optical_flow"]), 75 | "reverse_optical_flow": self.flip_optical_flow(sample["reverse_optical_flow"]), 76 | "motion_boundaries": self.flip_image(sample["motion_boundaries"]), 77 | "index": sample["index"] 78 | } 79 | else: 80 | return sample 81 | -------------------------------------------------------------------------------- /data/download_data.py: -------------------------------------------------------------------------------- 1 | from subprocess import check_call 2 | from os import remove 3 | from shutil import unpack_archive 4 | import os 5 | 6 | 7 | 
def download(url, directory): 8 | check_call(["aria2c", 9 | "--check-certificate=false", 10 | "--allow-overwrite=true", 11 | "--seed-time=0", 12 | "--follow-torrent=mem", 13 | "-d", directory, 14 | url]) 15 | 16 | 17 | def download_unpack_delete(url, directory, filename=None): 18 | if filename is None: 19 | filename = url.split("/")[-1] 20 | print(f"Downloading '{filename}'") 21 | download(url, directory) 22 | print(f"'{filename}' is downloaded") 23 | archive_path = os.path.join(directory, filename) 24 | unpack_archive(archive_path, extract_dir=directory) 25 | print(f"'{filename}' is unpacked") 26 | remove(archive_path) 27 | print(f"'{filename}' is cleaned") 28 | 29 | 30 | if __name__ == "__main__": 31 | download_unpack_delete( 32 | "http://academictorrents.com/download/48e5e770aa8469c0826ae322209cdc0ac115a385.torrent", 33 | "flyingthings3d", 34 | "flyingthings3d__frames_finalpass.tar" 35 | ) 36 | 37 | download_unpack_delete( 38 | "http://academictorrents.com/download/93a54256fe2f56dea2c7d247af11d9affa06a06d.torrent", 39 | "flyingthings3d", 40 | "flyingthings3d__optical_flow.tar.bz2" 41 | ) 42 | 43 | download_unpack_delete( 44 | "https://lmb.informatik.uni-freiburg.de/data/SceneFlowDatasets_CVPR16/Release_april16/data/FlyingThings3D/derived_data/flyingthings3d__motion_boundaries.tar.bz2", 45 | "flyingthings3d", 46 | ) 47 | 48 | download_unpack_delete( 49 | "https://lmb.informatik.uni-freiburg.de/data/SceneFlowDatasets_CVPR16/Release_april16/data/Monkaa/raw_data/monkaa__frames_finalpass.tar", 50 | "monkaa", 51 | ) 52 | 53 | download_unpack_delete( 54 | "https://lmb.informatik.uni-freiburg.de/data/SceneFlowDatasets_CVPR16/Release_april16/data/Monkaa/derived_data/monkaa__optical_flow.tar.bz2", 55 | "monkaa", 56 | ) 57 | 58 | download_unpack_delete( 59 | "https://lmb.informatik.uni-freiburg.de/data/SceneFlowDatasets_CVPR16/Release_april16/data/Monkaa/derived_data/monkaa__motion_boundaries.tar.bz2", 60 | "monkaa", 61 | ) 62 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import namedtuple 3 | from abc import abstractmethod 4 | 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | 8 | from IO import readFlow 9 | 10 | SceneFlowEntry = namedtuple("SceneFlowEntry", 11 | ("frame", "previous_frame", "optical_flow", "reverse_optical_flow", "motion_boundaries")) 12 | 13 | 14 | class SceneFlowDataset(Dataset): 15 | def __init__(self, root_dir, transform, use_both_sides=False): 16 | self.entries = list(self.iterate_entries(root_dir, use_both_sides)) 17 | self.transform = transform 18 | 19 | @abstractmethod 20 | def iterate_entries(self, root_dir, use_both_sides): 21 | pass 22 | 23 | def __getitem__(self, index): 24 | entry = self.entries[index] 25 | sample = { 26 | "frame": Image.open(entry.frame).convert("RGB"), 27 | "previous_frame": Image.open(entry.previous_frame).convert("RGB"), 28 | "optical_flow": readFlow(entry.optical_flow).copy(), 29 | "reverse_optical_flow": readFlow(entry.reverse_optical_flow).copy(), 30 | "motion_boundaries": Image.open(entry.motion_boundaries), 31 | "index": index 32 | } 33 | 34 | sample = self.transform(sample) 35 | return sample 36 | 37 | def __len__(self): 38 | return len(self.entries) 39 | 40 | @staticmethod 41 | def frame_number(filename): 42 | return os.path.splitext(filename)[0] 43 | 44 | @staticmethod 45 | def side_letter(side): 46 | return side[0].upper() 47 | 48 | 49 | class 
MonkaaDataset(SceneFlowDataset): 50 | def iterate_entries(self, root_dir, use_both_sides): 51 | for dirpath, dirnames, filenames in os.walk(os.path.join(root_dir, "frames_finalpass")): 52 | if len(filenames) == 0: 53 | continue 54 | 55 | scene, side = dirpath.split(os.sep)[-2:] 56 | 57 | if not use_both_sides and side != "left": 58 | continue 59 | 60 | filenames.sort() 61 | filenames = [filename for filename in filenames if filename.endswith(".png")] 62 | 63 | for i in range(1, len(filenames)): 64 | yield SceneFlowEntry( 65 | os.path.join(dirpath, filenames[i]), 66 | os.path.join(dirpath, filenames[i - 1]), 67 | self.forward_optical_flow_path(root_dir, scene, side, self.frame_number(filenames[i - 1])), 68 | self.reverse_optical_flow_path(root_dir, scene, side, self.frame_number(filenames[i])), 69 | self.motion_boundaries_path(root_dir, scene, side, self.frame_number(filenames[i])) 70 | ) 71 | 72 | def forward_optical_flow_path(self, root, scene, side, frame_number): 73 | return os.path.join(root, "optical_flow", scene, "into_future", side, 74 | f"OpticalFlowIntoFuture_{frame_number}_{self.side_letter(side)}.pfm") 75 | 76 | def reverse_optical_flow_path(self, root, scene, side, frame_number): 77 | return os.path.join(root, "optical_flow", scene, "into_past", side, 78 | f"OpticalFlowIntoPast_{frame_number}_{self.side_letter(side)}.pfm") 79 | 80 | def motion_boundaries_path(self, root, scene, side, frame_number): 81 | return os.path.join(root, "motion_boundaries", scene, "into_past", side, 82 | f"{frame_number}.pgm") 83 | 84 | 85 | class FlyingThings3DDataset(SceneFlowDataset): 86 | def iterate_entries(self, root_dir, use_both_sides): 87 | for dirpath, dirnames, filenames in os.walk(os.path.join(root_dir, "frames_finalpass")): 88 | if len(filenames) == 0: 89 | continue 90 | 91 | part, subset, scene, side = dirpath.split(os.sep)[-4:] 92 | 93 | if not use_both_sides and side != "left": 94 | continue 95 | 96 | filenames.sort() 97 | filenames = [filename for filename in filenames if filename.endswith(".png")] 98 | 99 | for i in range(1, len(filenames)): 100 | yield SceneFlowEntry( 101 | os.path.join(dirpath, filenames[i]), 102 | os.path.join(dirpath, filenames[i - 1]), 103 | self.forward_optical_flow_path(root_dir, part, subset, scene, side, 104 | self.frame_number(filenames[i - 1])), 105 | self.reverse_optical_flow_path(root_dir, part, subset, scene, side, 106 | self.frame_number(filenames[i])), 107 | self.motion_boundaries_path(root_dir, part, subset, scene, side, self.frame_number(filenames[i])) 108 | ) 109 | 110 | def forward_optical_flow_path(self, root, part, subset, scene, side, frame_number): 111 | return os.path.join(root, "optical_flow", part, subset, scene, "into_future", side, 112 | f"OpticalFlowIntoFuture_{frame_number}_{self.side_letter(side)}.pfm") 113 | 114 | def reverse_optical_flow_path(self, root, part, subset, scene, side, frame_number): 115 | return os.path.join(root, "optical_flow", part, subset, scene, "into_past", side, 116 | f"OpticalFlowIntoPast_{frame_number}_{self.side_letter(side)}.pfm") 117 | 118 | def motion_boundaries_path(self, root, part, subset, scene, side, frame_number): 119 | return os.path.join(root, "motion_boundaries", part, subset, scene, "into_past", side, 120 | f"{frame_number}.pgm") 121 | -------------------------------------------------------------------------------- /ffmpeg_tools.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | import shlex 3 | import json 4 | from subprocess import 
Popen, DEVNULL, PIPE, check_output 5 | 6 | import numpy as np 7 | 8 | 9 | def _check_wait(p): 10 | status = p.wait() 11 | if status != 0: 12 | raise Exception("{} returned non-zero status {}".format(p.args, status)) 13 | 14 | 15 | def _default_param(value, default_value): 16 | return default_value if value is None else value 17 | 18 | 19 | def _ffprobe(file, cmd): 20 | cmd = "{cmd} -loglevel fatal -print_format json -show_format -show_streams {file}".format(cmd=cmd, file=file) 21 | output = check_output(shlex.split(cmd)) 22 | return json.loads(output) 23 | 24 | 25 | def fraction(s): 26 | if s is None: 27 | return None 28 | num, den = s.split("/") 29 | return int(num) / int(den) 30 | 31 | 32 | class _VideoIterator: 33 | 34 | def __init__(self, reader): 35 | self._reader = reader 36 | self._closed = False 37 | 38 | cmd = [] 39 | cmd.append("{cmd} -loglevel error -y -nostdin".format(cmd=reader.ffmpeg_cmd)) 40 | cmd.append("-i {file}".format(file=reader.filepath)) 41 | if self._reader.fps is not None: 42 | cmd.append("-filter fps=fps={fps}:round=up".format(fps=self._reader.fps)) 43 | cmd.append("-f rawvideo -pix_fmt {pix_fmt} pipe:".format(pix_fmt=reader.format + '24')) 44 | cmd = " ".join(cmd) 45 | 46 | self._ffmpeg_output = Popen(shlex.split(cmd), stdout=PIPE, stdin=DEVNULL) 47 | 48 | def __next__(self): 49 | frame_size = self._reader.width * self._reader.height * 3 50 | in_bytes = self._ffmpeg_output.stdout.read(frame_size) 51 | 52 | assert len(in_bytes) == 0 or len(in_bytes) == frame_size 53 | 54 | if len(in_bytes) == 0: 55 | self._close() 56 | raise StopIteration() 57 | 58 | return np.frombuffer(in_bytes, np.uint8).reshape([self._reader.height, self._reader.width, 3]) 59 | 60 | def __iter__(self): 61 | return self 62 | 63 | def __del__(self): 64 | self._close() 65 | 66 | def _close(self): 67 | if self._closed: 68 | return 69 | else: 70 | self._closed = True 71 | self._ffmpeg_output.kill() 72 | 73 | 74 | class VideoReader: 75 | 76 | def __init__(self, filepath, fps=None, format='rgb', ffmpeg_cmd="ffmpeg", ffprobe_cmd="ffprobe"): 77 | probe = self.probe = _ffprobe(filepath, cmd=ffprobe_cmd) 78 | stream = next((stream for stream in probe["streams"] if stream['codec_type'] == "video")) 79 | 80 | self.width = int(stream["width"]) 81 | self.height = int(stream["height"]) 82 | # FPS from ffprobe can be sometimes be incorrect, so it's better to specify FPS manually 83 | self.fps = fps or fraction(stream.get("r_frame_rate")) 84 | self.duration = float(stream["duration"]) 85 | # self.frames_count = int(ceil(self.duration * fps)) 86 | 87 | self.filepath = filepath 88 | self.format = format 89 | self.ffmpeg_cmd = ffmpeg_cmd 90 | self.ffprobe_cmd = ffprobe_cmd 91 | 92 | def __iter__(self): 93 | return _VideoIterator(self) 94 | 95 | 96 | class VideoWriter: 97 | 98 | def __init__(self, 99 | filepath, 100 | input_width, 101 | input_height, 102 | input_fps, 103 | input_format="rgb", 104 | output_width=None, 105 | output_height=None, 106 | output_format="yuv420p", 107 | ffmpeg_cmd="ffmpeg"): 108 | self.filepath = filepath 109 | self.input_width = input_width 110 | self.input_height = input_height 111 | self.input_fps = input_fps 112 | self.input_format = input_format 113 | self.output_width = output_width 114 | self.output_height = output_height 115 | self.output_format = output_format 116 | self.ffmpeg_cmd = ffmpeg_cmd 117 | 118 | def __enter__(self): 119 | cmd = [] 120 | cmd.append("{cmd} -y -loglevel error".format(cmd=self.ffmpeg_cmd)) 121 | cmd.append("-f rawvideo -pix_fmt {pix_fmt} -video_size 
{width}x{height} -framerate {fps} -i pipe:".format( 122 | pix_fmt=self.input_format + '24', 123 | width=self.input_width, 124 | height=self.input_height, 125 | fps=self.input_fps 126 | )) 127 | cmd.append("-pix_fmt {pix_fmt}".format(pix_fmt=self.output_format)) 128 | if self.output_width is not None and self.output_height is not None: 129 | cmd.append("-s {width}x{height}".format(width=self.output_width, height=self.output_height)) 130 | 131 | cmd.append(self.filepath) 132 | cmd = " ".join(cmd) 133 | 134 | self._ffmpeg_output = Popen(shlex.split(cmd), stdin=PIPE) 135 | return self 136 | 137 | def write(self, frame): 138 | assert frame.dtype == np.uint8 and frame.ndim == 3 and frame.shape == (self.input_height, self.input_width, 3) 139 | self._ffmpeg_output.stdin.write(frame.tobytes()) 140 | 141 | def __exit__(self, exc_type, exc_val, exc_tb): 142 | self._ffmpeg_output.stdin.close() 143 | _check_wait(self._ffmpeg_output) 144 | -------------------------------------------------------------------------------- /frn.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/yukkyo/PyTorch-FilterResponseNormalizationLayer 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class TLU(nn.Module): 8 | def __init__(self, num_features): 9 | """max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau""" 10 | super(TLU, self).__init__() 11 | self.num_features = num_features 12 | self.tau = nn.parameter.Parameter( 13 | torch.Tensor(1, num_features, 1, 1), requires_grad=True) 14 | self.reset_parameters() 15 | 16 | def reset_parameters(self): 17 | nn.init.zeros_(self.tau) 18 | 19 | def extra_repr(self): 20 | return 'num_features={num_features}'.format(**self.__dict__) 21 | 22 | def forward(self, x): 23 | return torch.max(x, self.tau) 24 | 25 | 26 | class FRN(nn.Module): 27 | def __init__(self, num_features, eps=1e-6, is_eps_leanable=False): 28 | """ 29 | weight = gamma, bias = beta 30 | beta, gamma: 31 | Variables of shape [1, 1, 1, C]. if TensorFlow 32 | Variables of shape [1, C, 1, 1]. if PyTorch 33 | eps: A scalar constant or learnable variable. 34 | """ 35 | super(FRN, self).__init__() 36 | 37 | self.num_features = num_features 38 | self.init_eps = eps 39 | self.is_eps_leanable = is_eps_leanable 40 | 41 | self.weight = nn.parameter.Parameter( 42 | torch.Tensor(1, num_features, 1, 1), requires_grad=True) 43 | self.bias = nn.parameter.Parameter( 44 | torch.Tensor(1, num_features, 1, 1), requires_grad=True) 45 | if is_eps_leanable: 46 | self.eps = nn.parameter.Parameter(torch.Tensor(1), requires_grad=True) 47 | else: 48 | self.register_buffer('eps', torch.Tensor([eps])) 49 | self.reset_parameters() 50 | 51 | def reset_parameters(self): 52 | nn.init.ones_(self.weight) 53 | nn.init.zeros_(self.bias) 54 | if self.is_eps_leanable: 55 | nn.init.constant_(self.eps, self.init_eps) 56 | 57 | def extra_repr(self): 58 | return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__) 59 | 60 | def forward(self, x): 61 | """ 62 | 0, 1, 2, 3 -> (B, H, W, C) in TensorFlow 63 | 0, 1, 2, 3 -> (B, C, H, W) in PyTorch 64 | TensorFlow code 65 | nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True) 66 | x = x * tf.rsqrt(nu2 + tf.abs(eps)) 67 | # This Code include TLU function max(y, tau) 68 | return tf.maximum(gamma * x + beta, tau) 69 | """ 70 | # Compute the mean norm of activations per channel. 71 | nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True) 72 | 73 | # Perform FRN. 
74 | x = x * torch.rsqrt(nu2 + self.eps.abs()) 75 | 76 | # Scale and Bias 77 | x = self.weight * x + self.bias 78 | return x 79 | -------------------------------------------------------------------------------- /lib.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | current_dir = os.path.dirname(__file__) 4 | sys.path.insert(0, current_dir) 5 | 6 | import numpy as np 7 | import torch 8 | 9 | from model import ReCoNet 10 | from utils import preprocess_for_reconet, postprocess_reconet, Dummy, nhwc_to_nchw, nchw_to_nhwc 11 | 12 | sys.path.remove(current_dir) 13 | 14 | 15 | class ReCoNetModel: 16 | 17 | def __init__(self, state_dict_path, use_gpu=True, gpu_device=None, frn=False): 18 | self.use_gpu = use_gpu 19 | self.gpu_device = gpu_device 20 | 21 | with self.device(): 22 | self.model = ReCoNet(frn=frn) 23 | self.model.load_state_dict(torch.load(state_dict_path)) 24 | self.model = self.to_device(self.model) 25 | self.model.eval() 26 | 27 | def run(self, images): 28 | assert images.dtype == np.uint8 29 | assert 3 <= images.ndim <= 4 30 | 31 | orig_ndim = images.ndim 32 | if images.ndim == 3: 33 | images = images[None, ...] 34 | 35 | images = torch.from_numpy(images) 36 | images = nhwc_to_nchw(images) 37 | images = images.to(torch.float32) / 255 38 | 39 | with self.device(): 40 | with torch.no_grad(): 41 | images = self.to_device(images) 42 | images = preprocess_for_reconet(images) 43 | styled_images = self.model(images) 44 | styled_images = postprocess_reconet(styled_images) 45 | styled_images = styled_images.cpu() 46 | styled_images = torch.clamp(styled_images * 255, 0, 255).to(torch.uint8) 47 | styled_images = nchw_to_nhwc(styled_images) 48 | styled_images = styled_images.numpy() 49 | if orig_ndim == 3: 50 | styled_images = styled_images[0] 51 | return styled_images 52 | 53 | def to_device(self, x): 54 | if self.use_gpu: 55 | with self.device(): 56 | return x.cuda() 57 | else: 58 | return x 59 | 60 | def device(self): 61 | if self.use_gpu and self.gpu_device is not None: 62 | return torch.cuda.device(self.gpu_device) 63 | else: 64 | return Dummy() 65 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from frn import FRN, TLU 3 | 4 | 5 | class ConvLayer(nn.Module): 6 | def __init__(self, in_channels, out_channels, kernel_size, stride): 7 | super().__init__() 8 | 9 | self.layers = nn.Sequential( 10 | nn.ReflectionPad2d(kernel_size // 2), 11 | nn.Conv2d(in_channels, out_channels, kernel_size, stride), 12 | ) 13 | 14 | def forward(self, x): 15 | return self.layers(x) 16 | 17 | 18 | class ConvNormLayer(nn.Module): 19 | def __init__(self, in_channels, out_channels, kernel_size, stride, activation=True, frn=False): 20 | super().__init__() 21 | 22 | if frn: 23 | layers = [ 24 | ConvLayer(in_channels, out_channels, kernel_size, stride), 25 | FRN(out_channels), 26 | ] 27 | if activation: 28 | layers.append(TLU(out_channels)) 29 | else: 30 | layers = [ 31 | ConvLayer(in_channels, out_channels, kernel_size, stride), 32 | nn.InstanceNorm2d(out_channels, affine=True), 33 | ] 34 | if activation: 35 | layers.append(nn.ReLU(inplace=True)) 36 | 37 | self.layers = nn.Sequential(*layers) 38 | 39 | def forward(self, x): 40 | return self.layers(x) 41 | 42 | 43 | class ResLayer(nn.Module): 44 | 45 | def __init__(self, in_channels, out_channels, kernel_size, frn=False): 46 | super().__init__() 47 | 
self.branch = nn.Sequential( 48 | ConvNormLayer(in_channels, out_channels, kernel_size, 1, frn=frn), 49 | ConvNormLayer(out_channels, out_channels, kernel_size, 1, activation=False, frn=frn) 50 | ) 51 | 52 | if frn: 53 | self.activation = TLU(out_channels) 54 | else: 55 | self.activation = nn.ReLU(inplace=True) 56 | 57 | def forward(self, x): 58 | x = x + self.branch(x) 59 | x = self.activation(x) 60 | return x 61 | 62 | 63 | class ConvTanhLayer(nn.Module): 64 | def __init__(self, in_channels, out_channels, kernel_size, stride): 65 | super().__init__() 66 | self.layers = nn.Sequential( 67 | ConvLayer(in_channels, out_channels, kernel_size, stride), 68 | nn.Tanh() 69 | ) 70 | 71 | def forward(self, x): 72 | return self.layers(x) 73 | 74 | 75 | class Encoder(nn.Module): 76 | def __init__(self, frn=False): 77 | super().__init__() 78 | self.layers = nn.Sequential( 79 | ConvNormLayer(3, 48, 9, 1, frn=frn), 80 | ConvNormLayer(48, 96, 3, 2, frn=frn), 81 | ConvNormLayer(96, 192, 3, 2, frn=frn), 82 | ResLayer(192, 192, 3, frn=frn), 83 | ResLayer(192, 192, 3, frn=frn), 84 | ResLayer(192, 192, 3, frn=frn), 85 | ResLayer(192, 192, 3, frn=frn) 86 | ) 87 | 88 | def forward(self, x): 89 | return self.layers(x) 90 | 91 | 92 | class Decoder(nn.Module): 93 | def __init__(self, frn=False): 94 | super().__init__() 95 | self.layers = nn.Sequential( 96 | nn.Upsample(scale_factor=2), 97 | ConvNormLayer(192, 96, 3, 1, frn=frn), 98 | nn.Upsample(scale_factor=2), 99 | ConvNormLayer(96, 48, 3, 1, frn=frn), 100 | ConvTanhLayer(48, 3, 9, 1) 101 | ) 102 | 103 | def forward(self, x): 104 | return self.layers(x) 105 | 106 | 107 | class ReCoNet(nn.Module): 108 | def __init__(self, frn=False): 109 | super().__init__() 110 | self.encoder = Encoder(frn=frn) 111 | self.decoder = Decoder(frn=frn) 112 | 113 | def forward(self, x): 114 | x = self.encoder(x) 115 | x = self.decoder(x) 116 | return x 117 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.4.0 2 | torchvision 3 | numpy 4 | opencv-python 5 | Pillow 6 | tensorboard >= 1.14 -------------------------------------------------------------------------------- /style_video.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | 6 | from lib import ReCoNetModel 7 | from ffmpeg_tools import VideoReader, VideoWriter 8 | 9 | 10 | def create_folder_for_file(path): 11 | folder = os.path.dirname(path) 12 | os.makedirs(folder, exist_ok=True) 13 | 14 | 15 | if __name__ == "__main__": 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("input", help="Path to input video file") 19 | parser.add_argument("output", help="Path to output video file") 20 | parser.add_argument("model", help="Path to model file") 21 | parser.add_argument("--use-cpu", action='store_true', help="Use CPU instead of GPU") 22 | parser.add_argument("--gpu-device", type=int, default=None, help="GPU device index") 23 | parser.add_argument("--batch-size", type=int, default=2, help="Batch size") 24 | parser.add_argument("--fps", type=int, default=None, help="FPS of output video") 25 | parser.add_argument("--frn", action='store_true', help="Use Filter Response Normalization and TLU ") 26 | 27 | args = parser.parse_args() 28 | 29 | batch_size = args.batch_size 30 | 31 | model = ReCoNetModel(args.model, use_gpu=not args.use_cpu, gpu_device=args.gpu_device, frn=args.frn) 
32 | 33 | reader = VideoReader(args.input, fps=args.fps) 34 | 35 | create_folder_for_file(args.output) 36 | writer = VideoWriter(args.output, reader.width, reader.height, reader.fps) 37 | 38 | with writer: 39 | batch = [] 40 | 41 | for frame in reader: 42 | batch.append(frame) 43 | 44 | if len(batch) == batch_size: 45 | batch = np.array(batch) 46 | for styled_frame in model.run(batch): 47 | writer.write(styled_frame) 48 | 49 | batch = [] 50 | 51 | if len(batch) != 0: 52 | batch = np.array(batch) 53 | for styled_frame in model.run(batch): 54 | writer.write(styled_frame) 55 | -------------------------------------------------------------------------------- /styles/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmptySamurai/pytorch-reconet/a8e7e377f93b508f50a035983bdc8f83689ac097/styles/mosaic.jpg -------------------------------------------------------------------------------- /styles/mosaic_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmptySamurai/pytorch-reconet/a8e7e377f93b508f50a035983bdc8f83689ac097/styles/mosaic_2.jpg -------------------------------------------------------------------------------- /styles/starry_night.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmptySamurai/pytorch-reconet/a8e7e377f93b508f50a035983bdc8f83689ac097/styles/starry_night.jpg -------------------------------------------------------------------------------- /test_image.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EmptySamurai/pytorch-reconet/a8e7e377f93b508f50a035983bdc8f83689ac097/test_image.jpeg -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from datetime import datetime 4 | 5 | import torch 6 | import torch.utils.data 7 | import torchvision 8 | from torchvision import transforms 9 | from PIL import Image 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | from model import ReCoNet 13 | from dataset import MonkaaDataset, FlyingThings3DDataset 14 | import custom_transforms 15 | from vgg import Vgg16 16 | from utils import \ 17 | warp_optical_flow, \ 18 | rgb_to_luminance, \ 19 | l2_squared, \ 20 | tensors_sum, \ 21 | resize_optical_flow, \ 22 | occlusion_mask_from_flow, \ 23 | gram_matrix, \ 24 | preprocess_for_reconet, \ 25 | preprocess_for_vgg, \ 26 | postprocess_reconet, \ 27 | RunningLossesContainer, \ 28 | Dummy 29 | 30 | 31 | def output_temporal_loss( 32 | input_frame, 33 | previous_input_frame, 34 | output_frame, 35 | previous_output_frame, 36 | reverse_optical_flow, 37 | occlusion_mask): 38 | input_diff = input_frame - warp_optical_flow(previous_input_frame, reverse_optical_flow) 39 | output_diff = output_frame - warp_optical_flow(previous_output_frame, reverse_optical_flow) 40 | luminance_input_diff = rgb_to_luminance(input_diff).unsqueeze_(1) 41 | 42 | n, c, h, w = input_frame.shape 43 | loss = l2_squared(occlusion_mask * (output_diff - luminance_input_diff)) / (h * w) 44 | return loss 45 | 46 | 47 | def feature_temporal_loss( 48 | feature_maps, 49 | previous_feature_maps, 50 | reverse_optical_flow, 51 | occlusion_mask): 52 | n, c, h, w = feature_maps.shape 53 | 54 | reverse_optical_flow_resized = 
resize_optical_flow(reverse_optical_flow, h, w) 55 | occlusion_mask_resized = torch.nn.functional.interpolate(occlusion_mask, size=(h, w), mode='nearest') 56 | 57 | feature_maps_diff = feature_maps - warp_optical_flow(previous_feature_maps, reverse_optical_flow_resized) 58 | loss = l2_squared(occlusion_mask_resized * feature_maps_diff) / (c * h * w) 59 | 60 | return loss 61 | 62 | 63 | def content_loss( 64 | content_feature_maps, 65 | style_feature_maps): 66 | n, c, h, w = content_feature_maps.shape 67 | 68 | return l2_squared(content_feature_maps - style_feature_maps) / (c * h * w) 69 | 70 | 71 | def style_loss( 72 | content_feature_maps, 73 | style_gram_matrices): 74 | loss = 0 75 | for content_fm, style_gm in zip(content_feature_maps, style_gram_matrices): 76 | loss += l2_squared(gram_matrix(content_fm) - style_gm) 77 | return loss 78 | 79 | 80 | def total_variation(y): 81 | return torch.sum(torch.abs(y[:, :, :, :-1] - y[:, :, :, 1:])) + \ 82 | torch.sum(torch.abs(y[:, :, :-1, :] - y[:, :, 1:, :])) 83 | 84 | 85 | def stylize_image(image, model): 86 | if isinstance(image, Image.Image): 87 | image = transforms.ToTensor()(image) 88 | image = image.cuda().unsqueeze_(0) 89 | image = preprocess_for_reconet(image) 90 | styled_image = model(image).squeeze() 91 | styled_image = postprocess_reconet(styled_image) 92 | return styled_image 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("style", help="Path to style image") 98 | parser.add_argument("--data-dir", default="./data", help="Path to data root dir") 99 | parser.add_argument("--gpu-device", type=int, default=0, help="GPU device index") 100 | parser.add_argument("--alpha", type=float, default=1e4, help="Weight of content loss") 101 | parser.add_argument("--beta", type=float, default=1e5, help="Weight of style loss") 102 | parser.add_argument("--gamma", type=float, default=1e-5, help="Weight of total variation loss") 103 | parser.add_argument("--lambda-f", type=float, default=1e5, help="Weight of feature temporal loss") 104 | parser.add_argument("--lambda-o", type=float, default=2e5, help="Weight of output temporal loss") 105 | parser.add_argument("--epochs", type=int, default=2, help="Number of epochs") 106 | parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate") 107 | parser.add_argument("--output-file", default="./model.pth", help="Output model file path") 108 | parser.add_argument("--frn", action='store_true', help="Use Filter Response Normalization and TLU") 109 | 110 | args = parser.parse_args() 111 | 112 | running_losses = RunningLossesContainer() 113 | global_step = 0 114 | 115 | with torch.cuda.device(args.gpu_device): 116 | transform = transforms.Compose([ 117 | custom_transforms.Resize(640, 360), 118 | custom_transforms.RandomHorizontalFlip(), 119 | custom_transforms.ToTensor() 120 | ]) 121 | monkaa = MonkaaDataset(os.path.join(args.data_dir, "monkaa"), transform) 122 | flyingthings3d = FlyingThings3DDataset(os.path.join(args.data_dir, "flyingthings3d"), transform) 123 | dataset = monkaa + flyingthings3d 124 | batch_size = 2 125 | traindata = torch.utils.data.DataLoader(dataset, 126 | batch_size=batch_size, 127 | shuffle=True, 128 | num_workers=3, 129 | pin_memory=True, 130 | drop_last=True) 131 | 132 | model = ReCoNet(frn=args.frn).cuda() 133 | vgg = Vgg16().cuda() 134 | 135 | with torch.no_grad(): 136 | style = Image.open(args.style) 137 | style = transforms.ToTensor()(style).cuda() 138 | style = style.unsqueeze_(0) 139 | style_vgg_features = 
vgg(preprocess_for_vgg(style)) 140 | style_gram_matrices = [gram_matrix(x) for x in style_vgg_features] 141 | del style, style_vgg_features 142 | 143 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 144 | writer = SummaryWriter() 145 | 146 | n_epochs = args.epochs 147 | for epoch in range(n_epochs): 148 | for sample in traindata: 149 | optimizer.zero_grad() 150 | 151 | sample = {name: tensor.cuda() for name, tensor in sample.items()} 152 | 153 | occlusion_mask = occlusion_mask_from_flow( 154 | sample["optical_flow"], 155 | sample["reverse_optical_flow"], 156 | sample["motion_boundaries"]) 157 | 158 | # Compute ReCoNet features and output 159 | 160 | reconet_input = preprocess_for_reconet(sample["frame"]) 161 | feature_maps = model.encoder(reconet_input) 162 | output_frame = model.decoder(feature_maps) 163 | 164 | previous_reconet_input = preprocess_for_reconet(sample["previous_frame"]) 165 | previous_feature_maps = model.encoder(previous_reconet_input) 166 | previous_output_frame = model.decoder(previous_feature_maps) 167 | 168 | # Compute VGG features 169 | 170 | vgg_input_frame = preprocess_for_vgg(sample["frame"]) 171 | vgg_output_frame = preprocess_for_vgg(postprocess_reconet(output_frame)) 172 | input_vgg_features = vgg(vgg_input_frame) 173 | output_vgg_features = vgg(vgg_output_frame) 174 | 175 | vgg_previous_input_frame = preprocess_for_vgg(sample["previous_frame"]) 176 | vgg_previous_output_frame = preprocess_for_vgg(postprocess_reconet(previous_output_frame)) 177 | previous_input_vgg_features = vgg(vgg_previous_input_frame) 178 | previous_output_vgg_features = vgg(vgg_previous_output_frame) 179 | 180 | # Compute losses 181 | 182 | alpha = args.alpha 183 | beta = args.beta 184 | gamma = args.gamma 185 | lambda_f = args.lambda_f 186 | lambda_o = args.lambda_o 187 | 188 | losses = { 189 | "content loss": tensors_sum([ 190 | alpha * content_loss(output_vgg_features[2], input_vgg_features[2]), 191 | alpha * content_loss(previous_output_vgg_features[2], previous_input_vgg_features[2]), 192 | ]), 193 | "style loss": tensors_sum([ 194 | beta * style_loss(output_vgg_features, style_gram_matrices), 195 | beta * style_loss(previous_output_vgg_features, style_gram_matrices), 196 | ]), 197 | "total variation": tensors_sum([ 198 | gamma * total_variation(output_frame), 199 | gamma * total_variation(previous_output_frame), 200 | ]), 201 | "feature temporal loss": lambda_f * feature_temporal_loss(feature_maps, previous_feature_maps, 202 | sample["reverse_optical_flow"], 203 | occlusion_mask), 204 | "output temporal loss": lambda_o * output_temporal_loss(reconet_input, previous_reconet_input, 205 | output_frame, previous_output_frame, 206 | sample["reverse_optical_flow"], 207 | occlusion_mask) 208 | } 209 | 210 | training_loss = tensors_sum(list(losses.values())) 211 | losses["training loss"] = training_loss 212 | 213 | training_loss.backward() 214 | optimizer.step() 215 | 216 | with torch.no_grad(): 217 | running_losses.update(losses) 218 | 219 | last_iteration = global_step == len(dataset) // batch_size * n_epochs - 1 220 | if global_step % 25 == 0 or last_iteration: 221 | average_losses = running_losses.get() 222 | for key, value in average_losses.items(): 223 | writer.add_scalar(key, value, global_step) 224 | 225 | running_losses.reset() 226 | 227 | if global_step % 100 == 0 or last_iteration: 228 | styled_test_image = stylize_image(Image.open("test_image.jpeg"), model) 229 | writer.add_image('test image', styled_test_image, global_step) 230 | 231 | for i in range(0, 
len(dataset), len(dataset) // 4): 232 | sample = dataset[i] 233 | styled_train_image_1 = stylize_image(sample["frame"], model) 234 | styled_train_image_2 = stylize_image(sample["previous_frame"], model) 235 | grid = torchvision.utils.make_grid([styled_train_image_1, styled_train_image_2]) 236 | writer.add_image(f'train images {i}', grid, global_step) 237 | 238 | global_step += 1 239 | 240 | torch.save(model.state_dict(), args.output_file) 241 | writer.close() 242 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | 5 | 6 | def tensors_sum(tensors): 7 | result = 0 8 | for tensor in tensors: 9 | result += tensor 10 | return result 11 | 12 | 13 | def magnitude_squared(x): 14 | return x.pow(2).sum(-1) 15 | 16 | 17 | def nhwc_to_nchw(x): 18 | return x.permute(0, 3, 1, 2) 19 | 20 | 21 | def nchw_to_nhwc(x): 22 | return x.permute(0, 2, 3, 1) 23 | 24 | 25 | def warp_optical_flow(source, reverse_flow): 26 | n, h, w, _ = reverse_flow.shape 27 | 28 | reverse_flow = reverse_flow.clone() 29 | reverse_flow[..., 0] += torch.arange(w).view(1, 1, w).cuda() 30 | reverse_flow[..., 0] *= 2 / w 31 | reverse_flow[..., 0] -= 1 32 | reverse_flow[..., 1] += torch.arange(h).view(1, h, 1).cuda() 33 | reverse_flow[..., 1] *= 2 / h 34 | reverse_flow[..., 1] -= 1 35 | 36 | return torch.nn.functional.grid_sample(source, reverse_flow, padding_mode='border') 37 | 38 | 39 | def occlusion_mask_from_flow(optical_flow, reverse_optical_flow, motion_boundaries): 40 | # "Dense Point Trajectories by GPU-accelerated Large Displacement Optical Flow" 41 | # Page 7 42 | 43 | optical_flow = nhwc_to_nchw(optical_flow) 44 | optical_flow = warp_optical_flow(optical_flow, reverse_optical_flow) 45 | optical_flow = nchw_to_nhwc(optical_flow) 46 | 47 | forward_magnitude = magnitude_squared(optical_flow) 48 | reverse_magnitude = magnitude_squared(reverse_optical_flow) 49 | sum_magnitude = magnitude_squared(optical_flow + reverse_optical_flow) 50 | 51 | occlusion_mask = sum_magnitude < (0.01 * (forward_magnitude + reverse_magnitude) + 0.5) 52 | occlusion_mask &= ~motion_boundaries 53 | return occlusion_mask.to(torch.float32).unsqueeze_(1) 54 | 55 | 56 | def rgb_to_luminance(x): 57 | return x[:, 0, ...] * 0.2126 + x[:, 1, ...] * 0.7152 + x[:, 2, ...] 
* 0.0722 58 | 59 | 60 | def l2_squared(x): 61 | return x.pow(2).sum() 62 | 63 | 64 | def mean_l2_squared(x): 65 | return x.pow(2).mean() 66 | 67 | 68 | def resize_optical_flow(optical_flow, h, w): 69 | optical_flow_nchw = nhwc_to_nchw(optical_flow) 70 | optical_flow_resized_nchw = torch.nn.functional.interpolate(optical_flow_nchw, size=(h, w), mode='bilinear') 71 | optical_flow_resized = nchw_to_nhwc(optical_flow_resized_nchw) 72 | 73 | old_h, old_w = optical_flow_nchw.shape[-2:] 74 | h_scale, w_scale = h / old_h, w / old_w 75 | optical_flow_resized[..., 0] *= w_scale 76 | optical_flow_resized[..., 1] *= h_scale 77 | return optical_flow_resized 78 | 79 | 80 | def gram_matrix(feature_map): 81 | n, c, h, w = feature_map.shape 82 | feature_map = feature_map.reshape((n, c, h * w)) 83 | return feature_map.bmm(feature_map.transpose(1, 2)) / (c * h * w) 84 | 85 | 86 | def normalize_batch(batch, mean, std): 87 | dtype = batch.dtype 88 | mean = torch.as_tensor(mean, dtype=dtype, device=batch.device) 89 | std = torch.as_tensor(std, dtype=dtype, device=batch.device) 90 | return (batch - mean[None, :, None, None]) / std[None, :, None, None] 91 | 92 | 93 | def preprocess_for_vgg(images_batch): 94 | return normalize_batch(images_batch, 95 | mean=[0.485, 0.456, 0.406], 96 | std=[0.229, 0.224, 0.225]) 97 | 98 | 99 | def preprocess_for_reconet(images_batch): 100 | images_batch = images_batch.clone() 101 | return images_batch * 2 - 1 102 | 103 | 104 | def postprocess_reconet(images_batch): 105 | images_batch = images_batch.clone() 106 | return (images_batch + 1) / 2 107 | 108 | 109 | class RunningLossesContainer: 110 | 111 | def __init__(self): 112 | self.values = defaultdict(lambda: 0) 113 | self.counters = defaultdict(lambda: 0) 114 | 115 | def update(self, losses): 116 | for key, value in losses.items(): 117 | self.values[key] += value.item() 118 | self.counters[key] += 1 119 | 120 | def get(self): 121 | return {key: self.values[key] / self.counters[key] for key in self.values} 122 | 123 | def reset(self): 124 | self.values.clear() 125 | self.counters.clear() 126 | 127 | 128 | class Dummy: 129 | 130 | def __init__(self, *args, **kwargs): 131 | pass 132 | 133 | def __call__(self, *args, **kwargs): 134 | return self 135 | 136 | def __getattribute__(self, item): 137 | return self 138 | 139 | def __enter__(self): 140 | return self 141 | 142 | def __exit__(self, exc_type, exc_val, exc_tb): 143 | return 144 | -------------------------------------------------------------------------------- /vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models import vgg16 4 | 5 | 6 | class Vgg16(torch.nn.Module): 7 | def __init__(self): 8 | super(Vgg16, self).__init__() 9 | features = list(vgg16(pretrained=True).features)[:23] 10 | self.layers = nn.ModuleList(features).eval() 11 | for param in self.parameters(): 12 | param.requires_grad = False 13 | 14 | def forward(self, x): 15 | results = [] 16 | layers_of_interest = {3, 8, 15, 22} 17 | 18 | for i, layer in enumerate(self.layers): 19 | x = layer(x) 20 | if i in layers_of_interest: 21 | results.append(x) 22 | 23 | return results 24 | --------------------------------------------------------------------------------
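A minimal usage sketch of the programming interface described in the README: stylizing a single image with `ReCoNetModel` from `lib.py`. The constructor and `run` signatures follow `lib.py` above; the checkpoint name `model.pth` (the default output of `train.py`) and the output path `styled.jpg` are placeholder assumptions about your local setup, while `test_image.jpeg` is the test image shipped with the repository. A CUDA-capable GPU is assumed; pass `use_gpu=False` to run on the CPU, and `frn=True` if the checkpoint was trained with `--frn`.

```python
# Minimal sketch: stylize a single image with ReCoNetModel from lib.py.
# "model.pth" and "styled.jpg" are placeholder paths for this example.
import numpy as np
from PIL import Image

from lib import ReCoNetModel

model = ReCoNetModel("model.pth", use_gpu=True, frn=False)

# run() accepts a uint8 RGB array, either HWC (a single image) or NHWC (a batch),
# and returns the stylized result as a uint8 array with the same layout.
frame = np.array(Image.open("test_image.jpeg").convert("RGB"), dtype=np.uint8)
styled = model.run(frame)

Image.fromarray(styled).save("styled.jpg")
```

For video, `style_video.py` wraps the same `run` call around the `VideoReader`/`VideoWriter` pair from `ffmpeg_tools.py`, collecting frames into batches of `--batch-size` before each call.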