├── README.md
├── lib
│   └── utils.py
├── models
│   ├── __init__.py
│   ├── dvc_model.py
│   ├── interpolation
│   │   ├── IFNet.py
│   │   ├── IFNet_m.py
│   │   ├── RIFE.py
│   │   ├── refine.py
│   │   └── warplayer.py
│   ├── interpolation_net.py
│   ├── loss.py
│   └── ops.py
├── test.py
└── train.py
/README.md:
--------------------------------------------------------------------------------
1 | # Distributed DVC
2 | 
3 | The official PyTorch implementation of our **ICME 2023** paper:
4 | 
5 | **Low-complexity Deep Video Compression with A Distributed Coding Architecture**
6 | 
7 | [Xinjie Zhang](https://xinjie-q.github.io/), [Jiawei Shao](https://shaojiawei07.github.io/), [Jun Zhang](https://eejzhang.people.ust.hk/)
8 | 
9 | [[ArXiv Preprint](https://arxiv.org/abs/2303.11599v2)]
10 | 
11 | ### :bookmark: Brief Introduction
12 | 
13 | Prevalent predictive-coding-based video compression methods rely on a heavy encoder to reduce temporal redundancy, which makes them challenging to deploy on resource-constrained devices. Since the 1970s, distributed source coding theory has indicated that independent encoding and joint decoding with side information (SI) can achieve highly efficient compression of correlated sources. This has inspired a *distributed coding* architecture that aims to reduce encoding complexity. However, traditional distributed coding methods suffer from a substantial performance gap relative to predictive coding. Inspired by the great success of learning-based compression, we propose the first end-to-end distributed deep video compression framework to improve the rate-distortion performance. A key ingredient is an effective SI generation module at the decoder, which exploits inter-frame correlations without computation-intensive encoder-side motion estimation and compensation. Experiments show that our method significantly outperforms conventional distributed video coding and H.264. Meanwhile, it achieves a $6\sim7\times$ encoding speedup over DVC with comparable compression performance.
14 | 
15 | ## Acknowledgement
16 | 
17 | :heart::heart::heart: Our implementation builds on the following projects. We really appreciate their wonderful open-source work!
18 | 
19 | - [CompressAI](https://github.com/InterDigitalInc/CompressAI) [[related paper](https://arxiv.org/abs/2011.03029)]
20 | - [RIFE](https://github.com/megvii-research/ECCV2022-RIFE) [[related paper](https://arxiv.org/abs/2011.06294)]
21 | 
22 | ## Citation
23 | 
24 | If any part of our paper and code helps your research, please consider citing us and giving a star to our repository.
25 | 26 | ``` 27 | @inproceedings{zhang2023low, 28 | title={Low-complexity Deep Video Compression with A Distributed Coding Architecture}, 29 | author={Zhang, Xinjie and Shao, Jiawei and Zhang, Jun}, 30 | booktitle={IEEE International Conference on Multimedia and Expo}, 31 | year={2023}, 32 | } 33 | 34 | @article{zhang2023low, 35 | title={Low-complexity Deep Video Compression with A Distributed Coding Architecture}, 36 | author={Zhang, Xinjie and Shao, Jiawei and Zhang, Jun}, 37 | journal={arXiv preprint arXiv:2303.11599}, 38 | year={2023} 39 | } 40 | ``` 41 | 42 | -------------------------------------------------------------------------------- /lib/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from PIL import Image 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision import transforms 6 | import torch.nn as nn 7 | import torch 8 | import torchvision.transforms.functional as tf 9 | from typing import Dict 10 | from torch import Tensor 11 | import numpy as np 12 | import glob 13 | import cv2 14 | import json 15 | from PIL import Image 16 | import random 17 | 18 | 19 | class Vimeo(Dataset): 20 | def __init__(self, data_root, is_training, crop_size): 21 | self.data_root = data_root # .\vimeo_septuplet 22 | self.image_root = os.path.join(self.data_root, 'sequences') # .\vimeo_septuplet\sequences 23 | self.training = is_training 24 | self.crop_size = crop_size 25 | if self.training: 26 | train_fn = os.path.join(self.data_root, 'sep_trainlist.txt') # 64612 27 | with open(train_fn, 'r') as f: 28 | self.trainlist = f.read().splitlines() # ['00001/0001','00001/0002',...] 29 | else: 30 | test_fn = os.path.join(self.data_root, 'sep_testlist.txt') 31 | with open(test_fn, 'r') as f: 32 | self.testlist = f.read().splitlines() 33 | #self.testlist = self.testlist[:len(self.testlist)//2] 34 | 35 | self.transforms = transforms.Compose([transforms.ToTensor()]) 36 | 37 | 38 | def train_transform(self, frame_list): 39 | # Random cropping augmentation 40 | h_offset = random.choice(range(256 - self.crop_size[0] + 1)) 41 | w_offset = random.choice(range(448 - self.crop_size[1]+ 1)) 42 | 43 | choice = [random.randint(0, 1), random.randint(0, 1), random.randint(0, 1), random.randint(0, 1)] 44 | flip_code = random.randint(-1,1) # 0 : Top-bottom | 1: Right-left | -1: both 45 | 46 | frame_list_ = [] 47 | for frame in frame_list: 48 | frame = frame[h_offset:h_offset + self.crop_size[0], w_offset: w_offset + self.crop_size[1], :] 49 | 50 | # Rotation augmentation 51 | if self.crop_size[0] == self.crop_size[1]: 52 | if choice[0]: 53 | frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE) 54 | elif choice[1]: 55 | frame = cv2.rotate(frame, cv2.ROTATE_180) 56 | elif choice[2]: 57 | frame = cv2.rotate(frame, cv2.ROTATE_90_COUNTERCLOCKWISE) 58 | 59 | # Flip augmentation 60 | if choice[3]: 61 | frame = cv2.flip(frame, flip_code) 62 | 63 | frame = tf.to_tensor(frame) #将numpy数组或PIL.Image读的图片转换成(C,H, W)的Tensor格式且/255归一化到[0,1.0]之间 64 | frame_list_.append(frame) 65 | 66 | return frame_list_ 67 | # return map(TF.to_tensor, (frame1, frame2, frame3, flow, frame_fw, frame_bw)) 68 | #return map(tf.to_tensor, (frame1, frame2, frame3)) 69 | 70 | 71 | def test_transform(self, frame_list): 72 | frame_list_ = [self.transforms(frame) for frame in frame_list] 73 | return frame_list_ 74 | 75 | def __getitem__(self, index): 76 | if self.training: 77 | imgpath = os.path.join(self.image_root, self.trainlist[index]) 78 | else: 79 | imgpath 
= os.path.join(self.image_root, self.testlist[index]) 80 | imgpaths = [imgpath + f'/im{i}.png' for i in range(1, 8)] 81 | images = [cv2.imread(pth) for pth in imgpaths] 82 | if self.training: 83 | images = self.train_transform(images) 84 | # Random Temporal Flip 85 | if random.random() >= 0.5: 86 | images = images[::-1] 87 | else: 88 | images = self.test_transform(images) 89 | 90 | return images 91 | 92 | ''' 93 | if random.randint(0,1): 94 | First_fn = os.path.join(self.sequence_list[index], 'im1.png') 95 | Third_fn = os.path.join(self.sequence_list[index], 'im3.png') 96 | else: 97 | First_fn = os.path.join(self.sequence_list[index], 'im3.png') 98 | Third_fn = os.path.join(self.sequence_list[index], 'im1.png') 99 | 100 | Second_fn = os.path.join(self.sequence_list[index], 'im2.png') 101 | 102 | frame1 = imread(First_fn) 103 | frame2 = imread(Second_fn) 104 | frame3 = imread(Third_fn) 105 | 106 | frame1, frame2, frame3 = self.transform(frame1, frame2, frame3) 107 | 108 | Input = torch.cat((frame1, frame3), dim=0) 109 | 110 | return Input, frame2 111 | ''' 112 | 113 | def __len__(self): 114 | if self.training: 115 | return len(self.trainlist) 116 | else: 117 | return len(self.testlist) 118 | 119 | def TransformImgsWithSameCrop(images, cropsize): 120 | (i, j, h, w) = transforms.RandomCrop.get_params(images[0], (cropsize[0], cropsize[1])) 121 | images_ = [] 122 | for image in images: 123 | image = image.crop((j, i, j+w, i+h)) # top left corner (j,i); bottom right corner (j+w, i+h) 124 | image = tf.to_tensor(image) 125 | images_.append(image) 126 | 127 | images = images_ 128 | 129 | return images 130 | 131 | 132 | 133 | def save_checkpoint(state, is_best=False, log_dir=None, filename="ckpt.pth.tar"): 134 | save_file = os.path.join(log_dir, filename) 135 | print("save model in:", save_file) 136 | torch.save(state, save_file) 137 | if is_best: 138 | torch.save(state, os.path.join(log_dir, filename.replace(".pth.tar", ".best.pth.tar"))) 139 | 140 | 141 | class CustomDataParallel(nn.DataParallel): 142 | """Custom DataParallel to access the module methods.""" 143 | 144 | def __getattr__(self, key): 145 | try: 146 | return super().__getattr__(key) 147 | except AttributeError: 148 | return getattr(self.module, key) 149 | 150 | 151 | class AverageMeter(object): 152 | """Computes and stores the average and current value""" 153 | def __init__(self, name, fmt=':f'): 154 | self.name = name 155 | self.fmt = fmt 156 | self.reset() 157 | 158 | def reset(self): 159 | self.val = 0 160 | self.avg = 0 161 | self.sum = 0 162 | self.count = 0 163 | 164 | def update(self, val, n=1): 165 | self.val = val 166 | self.sum += val * n 167 | self.count += n 168 | self.avg = self.sum / self.count 169 | 170 | def __str__(self): 171 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 172 | return fmtstr.format(**self.__dict__) 173 | 174 | 175 | def get_output_folder(parent_dir, env_name, output_current_folder=False): 176 | """Return save folder. 177 | Assumes folders in the parent_dir have suffix -run{run 178 | number}. Finds the highest run number and sets the output folder 179 | to that number + 1. This is just convenient so that if you run the 180 | same script multiple times tensorboard can plot all of the results 181 | on the same plots with different names. 182 | Parameters 183 | ---------- 184 | parent_dir: str 185 | Path of the directory containing all experiment runs. 186 | Returns 187 | ------- 188 | parent_dir/run_dir 189 | Path to this run's save directory. 
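Examples
--------
Illustrative only (the folder names below are hypothetical): if
parent_dir already contains "dvc-run1" and "dvc-run2", then
get_output_folder("./logs", "dvc") creates and returns "./logs/dvc-run3";
with output_current_folder=True it would instead reuse "./logs/dvc-run2".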
190 | """ 191 | os.makedirs(parent_dir, exist_ok=True) 192 | experiment_id = 0 193 | for folder_name in os.listdir(parent_dir): 194 | if not os.path.isdir(os.path.join(parent_dir, folder_name)): 195 | continue 196 | try: 197 | folder_name = int(folder_name.split('-run')[-1]) 198 | if folder_name > experiment_id: 199 | experiment_id = folder_name 200 | except: 201 | pass 202 | if not output_current_folder: 203 | experiment_id += 1 204 | 205 | parent_dir = os.path.join(parent_dir, env_name) 206 | parent_dir = parent_dir + '-run{}'.format(experiment_id) 207 | os.makedirs(parent_dir, exist_ok=True) 208 | return parent_dir 209 | 210 | 211 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import * 2 | from compressai.models import * 3 | from .interpolation_net import VideoInterpolationNet 4 | from .dvc_model import * 5 | 6 | models_arch = { 7 | "bmshj2018-factorized": FactorizedPrior, 8 | "bmshj2018-hyperprior": ScaleHyperprior, 9 | "mbt2018-mean": MeanScaleHyperprior, 10 | "mbt2018": JointAutoregressiveHierarchicalPriors, 11 | "cheng2020-anchor": Cheng2020Anchor, 12 | "cheng2020-attn": Cheng2020Attention, 13 | "DVC-ChannelARpriors": DVC_ChannelARpriors, 14 | } 15 | 16 | 17 | cfgs = { 18 | "DVC-ChannelARpriors":{ 19 | 1: (192, 192), 20 | 2: (192, 192), 21 | 3: (192, 192), 22 | 4: (192, 192), 23 | 5: (192, 192), 24 | }, 25 | "bmshj2018-factorized": { 26 | 1: (128, 192), 27 | 2: (128, 192), 28 | 3: (128, 192), 29 | 4: (128, 192), 30 | 5: (128, 192), 31 | 6: (192, 320), 32 | 7: (192, 320), 33 | 8: (192, 320), 34 | }, 35 | "bmshj2018-hyperprior": { 36 | 1: (128, 192), 37 | 2: (128, 192), 38 | 3: (128, 192), 39 | 4: (128, 192), 40 | 5: (128, 192), 41 | 6: (192, 320), 42 | 7: (192, 320), 43 | 8: (192, 320), 44 | }, 45 | "mbt2018-mean": { 46 | 1: (128, 192), 47 | 2: (128, 192), 48 | 3: (128, 192), 49 | 4: (128, 192), 50 | 5: (192, 320), 51 | 6: (192, 320), 52 | 7: (192, 320), 53 | 8: (192, 320), 54 | }, 55 | "mbt2018": { 56 | 1: (192, 192), 57 | 2: (192, 192), 58 | 3: (192, 192), 59 | 4: (192, 192), 60 | 5: (192, 320), 61 | 6: (192, 320), 62 | 7: (192, 320), 63 | 8: (192, 320), 64 | }, 65 | "cheng2020-anchor": { 66 | 1: (128,), 67 | 2: (128,), 68 | 3: (128,), 69 | 4: (192,), 70 | 5: (192,), 71 | 6: (192,), 72 | }, 73 | "cheng2020-attn": { 74 | 1: (128,), 75 | 2: (128,), 76 | 3: (128,), 77 | 4: (192,), 78 | 5: (192,), 79 | 6: (192,), 80 | }, 81 | } 82 | 83 | 84 | 85 | root_url = "https://compressai.s3.amazonaws.com/models/v1" 86 | model_urls = { 87 | "bmshj2018-hyperprior": { 88 | "mse": { 89 | 1: f"{root_url}/bmshj2018-hyperprior-1-7eb97409.pth.tar", 90 | 2: f"{root_url}/bmshj2018-hyperprior-2-93677231.pth.tar", 91 | 3: f"{root_url}/bmshj2018-hyperprior-3-6d87be32.pth.tar", 92 | 4: f"{root_url}/bmshj2018-hyperprior-4-de1b779c.pth.tar", 93 | 5: f"{root_url}/bmshj2018-hyperprior-5-f8b614e1.pth.tar", 94 | 6: f"{root_url}/bmshj2018-hyperprior-6-1ab9c41e.pth.tar", 95 | 7: f"{root_url}/bmshj2018-hyperprior-7-3804dcbd.pth.tar", 96 | 8: f"{root_url}/bmshj2018-hyperprior-8-a583f0cf.pth.tar", 97 | }, 98 | "ms-ssim": { 99 | 1: f"{root_url}/bmshj2018-hyperprior-ms-ssim-1-5cf249be.pth.tar", 100 | 2: f"{root_url}/bmshj2018-hyperprior-ms-ssim-2-1ff60d1f.pth.tar", 101 | 3: f"{root_url}/bmshj2018-hyperprior-ms-ssim-3-92dd7878.pth.tar", 102 | 4: f"{root_url}/bmshj2018-hyperprior-ms-ssim-4-4377354e.pth.tar", 103 | 5: 
f"{root_url}/bmshj2018-hyperprior-ms-ssim-5-c34afc8d.pth.tar", 104 | 6: f"{root_url}/bmshj2018-hyperprior-ms-ssim-6-3a6d8229.pth.tar", 105 | 7: f"{root_url}/bmshj2018-hyperprior-ms-ssim-7-8747d3bc.pth.tar", 106 | 8: f"{root_url}/bmshj2018-hyperprior-ms-ssim-8-cc15b5f3.pth.tar", 107 | }, 108 | }, 109 | "mbt2018-mean": { 110 | "mse": { 111 | 1: f"{root_url}/mbt2018-mean-1-e522738d.pth.tar", 112 | 2: f"{root_url}/mbt2018-mean-2-e54a039d.pth.tar", 113 | 3: f"{root_url}/mbt2018-mean-3-723404a8.pth.tar", 114 | 4: f"{root_url}/mbt2018-mean-4-6dba02a3.pth.tar", 115 | 5: f"{root_url}/mbt2018-mean-5-d504e8eb.pth.tar", 116 | 6: f"{root_url}/mbt2018-mean-6-a19628ab.pth.tar", 117 | 7: f"{root_url}/mbt2018-mean-7-d5d441d1.pth.tar", 118 | 8: f"{root_url}/mbt2018-mean-8-8089ae3e.pth.tar", 119 | }, 120 | "ms-ssim": { 121 | 1: f"{root_url}/mbt2018-mean-ms-ssim-1-5bf9c0b6.pth.tar", 122 | 2: f"{root_url}/mbt2018-mean-ms-ssim-2-e2a1bf3f.pth.tar", 123 | 3: f"{root_url}/mbt2018-mean-ms-ssim-3-640ce819.pth.tar", 124 | 4: f"{root_url}/mbt2018-mean-ms-ssim-4-12626c13.pth.tar", 125 | 5: f"{root_url}/mbt2018-mean-ms-ssim-5-1be7f059.pth.tar", 126 | 6: f"{root_url}/mbt2018-mean-ms-ssim-6-b83bf379.pth.tar", 127 | 7: f"{root_url}/mbt2018-mean-ms-ssim-7-ddf9644c.pth.tar", 128 | 8: f"{root_url}/mbt2018-mean-ms-ssim-8-0cc7b94f.pth.tar", 129 | }, 130 | }, 131 | "mbt2018": { 132 | "mse": { 133 | 1: f"{root_url}/mbt2018-1-3f36cd77.pth.tar", 134 | 2: f"{root_url}/mbt2018-2-43b70cdd.pth.tar", 135 | 3: f"{root_url}/mbt2018-3-22901978.pth.tar", 136 | 4: f"{root_url}/mbt2018-4-456e2af9.pth.tar", 137 | 5: f"{root_url}/mbt2018-5-b4a046dd.pth.tar", 138 | 6: f"{root_url}/mbt2018-6-7052e5ea.pth.tar", 139 | 7: f"{root_url}/mbt2018-7-8ba2bf82.pth.tar", 140 | 8: f"{root_url}/mbt2018-8-dd0097aa.pth.tar", 141 | }, 142 | "ms-ssim": { 143 | 1: f"{root_url}/mbt2018-ms-ssim-1-2878436b.pth.tar", 144 | 2: f"{root_url}/mbt2018-ms-ssim-2-c41cb208.pth.tar", 145 | 3: f"{root_url}/mbt2018-ms-ssim-3-d0dd64e8.pth.tar", 146 | 4: f"{root_url}/mbt2018-ms-ssim-4-a120e037.pth.tar", 147 | 5: f"{root_url}/mbt2018-ms-ssim-5-9b30e3b7.pth.tar", 148 | 6: f"{root_url}/mbt2018-ms-ssim-6-f8b3626f.pth.tar", 149 | 7: f"{root_url}/mbt2018-ms-ssim-7-16e6ff50.pth.tar", 150 | 8: f"{root_url}/mbt2018-ms-ssim-8-0cb49d43.pth.tar", 151 | }, 152 | }, 153 | "cheng2020-anchor": { 154 | "mse": { 155 | 1: f"{root_url}/cheng2020-anchor-1-dad2ebff.pth.tar", 156 | 2: f"{root_url}/cheng2020-anchor-2-a29008eb.pth.tar", 157 | 3: f"{root_url}/cheng2020-anchor-3-e49be189.pth.tar", 158 | 4: f"{root_url}/cheng2020-anchor-4-98b0b468.pth.tar", 159 | 5: f"{root_url}/cheng2020-anchor-5-23852949.pth.tar", 160 | 6: f"{root_url}/cheng2020-anchor-6-4c052b1a.pth.tar", 161 | }, 162 | "ms-ssim": { 163 | 1: f"{root_url}/cheng2020_anchor-ms-ssim-1-20f521db.pth.tar", 164 | 2: f"{root_url}/cheng2020_anchor-ms-ssim-2-c7ff5812.pth.tar", 165 | 3: f"{root_url}/cheng2020_anchor-ms-ssim-3-c23e22d5.pth.tar", 166 | 4: f"{root_url}/cheng2020_anchor-ms-ssim-4-0e658304.pth.tar", 167 | 5: f"{root_url}/cheng2020_anchor-ms-ssim-5-c0a95e77.pth.tar", 168 | 6: f"{root_url}/cheng2020_anchor-ms-ssim-6-f2dc1913.pth.tar", 169 | }, 170 | }, 171 | "cheng2020-attn": { 172 | "mse": { 173 | 1: f"{root_url}/cheng2020_attn-mse-1-465f2b64.pth.tar", 174 | 2: f"{root_url}/cheng2020_attn-mse-2-e0805385.pth.tar", 175 | 3: f"{root_url}/cheng2020_attn-mse-3-2d07bbdf.pth.tar", 176 | 4: f"{root_url}/cheng2020_attn-mse-4-f7b0ccf2.pth.tar", 177 | 5: f"{root_url}/cheng2020_attn-mse-5-26c8920e.pth.tar", 178 | 6: 
f"{root_url}/cheng2020_attn-mse-6-730501f2.pth.tar", 179 | }, 180 | "ms-ssim": { 181 | 1: f"{root_url}/cheng2020_attn-ms-ssim-1-c5381d91.pth.tar", 182 | 2: f"{root_url}/cheng2020_attn-ms-ssim-2-5dad201d.pth.tar", 183 | 3: f"{root_url}/cheng2020_attn-ms-ssim-3-5c9be841.pth.tar", 184 | 4: f"{root_url}/cheng2020_attn-ms-ssim-4-8b2f647e.pth.tar", 185 | 5: f"{root_url}/cheng2020_attn-ms-ssim-5-5ca1f34c.pth.tar", 186 | 6: f"{root_url}/cheng2020_attn-ms-ssim-6-216423ec.pth.tar", 187 | }, 188 | }, 189 | } 190 | -------------------------------------------------------------------------------- /models/dvc_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from torch import nn 4 | from compressai.models.utils import update_registered_buffers, conv, deconv 5 | from compressai.ops import LowerBound 6 | # From Balle's tensorflow compression examples 7 | SCALES_MIN = 0.11 8 | SCALES_MAX = 256 9 | SCALES_LEVELS = 64 10 | 11 | from compressai.entropy_models import GaussianConditional, EntropyBottleneck 12 | from compressai.layers import GDN 13 | from compressai.ans import BufferedRansEncoder, RansDecoder 14 | import time 15 | from .ops import ste_round 16 | 17 | def get_scale_table(min=SCALES_MIN, max=SCALES_MAX, levels=SCALES_LEVELS): 18 | return torch.exp(torch.linspace(math.log(min), math.log(max), levels)) 19 | 20 | 21 | class DVC_ChannelARpriors(nn.Module): 22 | def __init__(self, N=192, M=192, side_input_channels=3, num_slices=8): 23 | super().__init__() 24 | self.num_slices = num_slices #each slices == 24 25 | 26 | self.encode_xa = nn.Sequential( 27 | conv(3, N, kernel_size=5, stride=2), 28 | GDN(N), 29 | conv(N, N, kernel_size=5, stride=2), 30 | GDN(N), 31 | conv(N, N, kernel_size=5, stride=2), 32 | GDN(N), 33 | conv(N, M, kernel_size=5, stride=2), 34 | ) 35 | 36 | self.encode_side = nn.Sequential( 37 | conv(side_input_channels, N, kernel_size=5, stride=2), 38 | GDN(N), 39 | conv(N, N, kernel_size=5, stride=2), 40 | GDN(N), 41 | conv(N, N, kernel_size=5, stride=2), 42 | GDN(N), 43 | conv(N, M, kernel_size=5, stride=2), 44 | ) 45 | 46 | self.decode_x = nn.Sequential( 47 | deconv(2 * M, N, kernel_size=5, stride=2), 48 | GDN(N, inverse=True), 49 | deconv(N, N, kernel_size=5, stride=2), 50 | GDN(N, inverse=True), 51 | deconv(N, N, kernel_size=5, stride=2), 52 | GDN(N, inverse=True), 53 | deconv(N, 3, kernel_size=5, stride=2), 54 | ) 55 | 56 | self.hyper_xa = nn.Sequential( 57 | conv(M, N, stride=1, kernel_size=3), 58 | nn.LeakyReLU(inplace=True), 59 | conv(N, N, stride=2, kernel_size=5), 60 | nn.LeakyReLU(inplace=True), 61 | conv(N, N, stride=2, kernel_size=5), 62 | ) 63 | 64 | self.hyper_xs = nn.Sequential( 65 | deconv(N, M, stride=2, kernel_size=5), 66 | nn.LeakyReLU(inplace=True), 67 | deconv(M, M * 3 // 2, stride=2, kernel_size=5), 68 | nn.LeakyReLU(inplace=True), 69 | conv(M * 3 // 2, M * 2, stride=1, kernel_size=3), 70 | ) 71 | 72 | self.cc_mean_transforms = nn.ModuleList( 73 | nn.Sequential( 74 | conv(M + M // self.num_slices * i, M//3 + M//3//self.num_slices* i, stride=1, kernel_size=3), 75 | nn.LeakyReLU(inplace=True), 76 | conv(M//3 + M//3//self.num_slices* i, M//6 + M//6//self.num_slices* i, stride=1, kernel_size=3), 77 | nn.LeakyReLU(inplace=True), 78 | conv(M//6 + M//6//self.num_slices* i, M // self.num_slices, stride=1, kernel_size=3), 79 | ) for i in range(self.num_slices) 80 | ) 81 | self.cc_scale_transforms = nn.ModuleList( 82 | nn.Sequential( 83 | conv(M + M // self.num_slices * i, M//3 + 
M//3//self.num_slices* i, stride=1, kernel_size=3), 84 | nn.LeakyReLU(inplace=True), 85 | conv(M//3 + M//3//self.num_slices* i, M//6 + M//6//self.num_slices* i, stride=1, kernel_size=3), 86 | nn.LeakyReLU(inplace=True), 87 | conv(M//6 + M//6//self.num_slices* i, M // self.num_slices, stride=1, kernel_size=3), 88 | ) for i in range(self.num_slices) 89 | ) 90 | self.lrp_transforms = nn.ModuleList( 91 | nn.Sequential( 92 | conv(M + M // self.num_slices * (i+1), M//3 + M//3//self.num_slices * (i+1), stride=1, kernel_size=3), 93 | nn.LeakyReLU(inplace=True), 94 | conv(M//3 + M//3//self.num_slices* (i+1), M//6 + M//6//self.num_slices * (i+1), stride=1, kernel_size=3), 95 | nn.LeakyReLU(inplace=True), 96 | conv(M//6 + M//6//self.num_slices* (i+1), M // self.num_slices, stride=1, kernel_size=3), 97 | ) for i in range(self.num_slices) 98 | ) 99 | 100 | self.N = int(N) 101 | self.M = int(M) 102 | 103 | self.entropy_bottleneck = EntropyBottleneck(N) 104 | self.gaussian_conditional = GaussianConditional(None) 105 | 106 | def aux_loss(self): 107 | aux_loss = sum( 108 | m.loss() for m in self.modules() if isinstance(m, EntropyBottleneck) 109 | ) 110 | return aux_loss 111 | 112 | 113 | 114 | def forward(self, x, side, replace_idx=None): 115 | #x, side = input[:,:3, :,:], input[:,3:,:,:] 116 | y = self.encode_xa(x) 117 | y_shape = y.shape[2:] 118 | z = self.hyper_xa(y) 119 | _, z_likelihoods = self.entropy_bottleneck(z) 120 | 121 | z_offset = self.entropy_bottleneck._get_medians() 122 | z_tmp = z - z_offset 123 | z_hat = ste_round(z_tmp) + z_offset 124 | 125 | latent = self.hyper_xs(z_hat) 126 | latent_means, latent_scales = latent.chunk(2, 1) 127 | 128 | y_slices = y.chunk(self.num_slices, 1) 129 | y_hat_slices = [] 130 | y_likelihood = [] 131 | 132 | for slice_index, y_slice in enumerate(y_slices): 133 | support_slices = (y_hat_slices[:slice_index]) 134 | mean_support = torch.cat([latent_means] + support_slices, dim=1) 135 | mu = self.cc_mean_transforms[slice_index](mean_support) 136 | #print("slice_index:", slice_index, y_slice.size(), mean_support.size(), mu.size()) 137 | mu = mu[:, :, :y_shape[0], :y_shape[1]] 138 | #print("mu:", mu.size()) 139 | 140 | scale_support = torch.cat([latent_scales] + support_slices, dim=1) 141 | scale = self.cc_scale_transforms[slice_index](scale_support) 142 | scale = scale[:, :, :y_shape[0], :y_shape[1]] 143 | 144 | _, y_slice_likelihood = self.gaussian_conditional(y_slice, scale, mu) 145 | y_likelihood.append(y_slice_likelihood) 146 | y_hat_slice = ste_round(y_slice - mu) + mu 147 | 148 | lrp_support = torch.cat([mean_support, y_hat_slice], dim=1) 149 | lrp = self.lrp_transforms[slice_index](lrp_support) 150 | lrp = 0.5 * torch.tanh(lrp) 151 | y_hat_slice += lrp 152 | #print("lrp:", lrp_support.size(), lrp.size()) 153 | 154 | y_hat_slices.append(y_hat_slice) 155 | #input() 156 | 157 | y_hat = torch.cat(y_hat_slices, dim=1) 158 | y_likelihoods = torch.cat(y_likelihood, dim=1) 159 | 160 | fea_side = self.encode_side(side) 161 | fea_side_likelihoods = self.entropy(fea_side) 162 | x_hat = self.decode_x(torch.cat((y_hat, fea_side), 1)) 163 | 164 | return { 165 | "x_hat": x_hat, 166 | "likelihoods": {"y": y_likelihoods, "z": z_likelihoods, }, 167 | } 168 | 169 | def entropy(self, y): 170 | y_shape = y.shape[2:] 171 | z = self.hyper_xa(y) 172 | _, z_likelihoods = self.entropy_bottleneck(z) 173 | 174 | z_offset = self.entropy_bottleneck._get_medians() 175 | z_tmp = z - z_offset 176 | z_hat = ste_round(z_tmp) + z_offset 177 | 178 | latent = self.hyper_xs(z_hat) 179 | 
latent_means, latent_scales = latent.chunk(2, 1) 180 | 181 | y_slices = y.chunk(self.num_slices, 1) 182 | y_hat_slices = [] 183 | y_likelihood = [] 184 | 185 | for slice_index, y_slice in enumerate(y_slices): 186 | support_slices = (y_hat_slices[:slice_index]) 187 | mean_support = torch.cat([latent_means] + support_slices, dim=1) 188 | mu = self.cc_mean_transforms[slice_index](mean_support) 189 | mu = mu[:, :, :y_shape[0], :y_shape[1]] 190 | 191 | scale_support = torch.cat([latent_scales] + support_slices, dim=1) 192 | scale = self.cc_scale_transforms[slice_index](scale_support) 193 | scale = scale[:, :, :y_shape[0], :y_shape[1]] 194 | 195 | _, y_slice_likelihood = self.gaussian_conditional(y_slice, scale, mu) 196 | y_likelihood.append(y_slice_likelihood) 197 | y_hat_slice = ste_round(y_slice - mu) + mu 198 | 199 | lrp_support = torch.cat([mean_support, y_hat_slice], dim=1) 200 | lrp = self.lrp_transforms[slice_index](lrp_support) 201 | lrp = 0.5 * torch.tanh(lrp) 202 | y_hat_slice += lrp 203 | 204 | y_hat_slices.append(y_hat_slice) 205 | 206 | y_hat = torch.cat(y_hat_slices, dim=1) 207 | y_likelihoods = torch.cat(y_likelihood, dim=1) 208 | return y_likelihoods 209 | 210 | 211 | 212 | def update(self, scale_table=None, force=False): 213 | if scale_table is None: 214 | scale_table = get_scale_table() 215 | updated = self.gaussian_conditional.update_scale_table(scale_table, force=force) 216 | rv = self.entropy_bottleneck.update(force=force) 217 | updated |= rv 218 | return updated 219 | 220 | def load_state_dict(self, state_dict): 221 | update_registered_buffers( 222 | self.gaussian_conditional, 223 | "gaussian_conditional", 224 | ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"], 225 | state_dict, 226 | ) 227 | update_registered_buffers( 228 | self.entropy_bottleneck, 229 | "entropy_bottleneck", 230 | ["_quantized_cdf", "_offset", "_cdf_length"], 231 | state_dict, 232 | ) 233 | super().load_state_dict(state_dict) 234 | 235 | def compress(self, x): 236 | start = time.time() 237 | y = self.encode_xa(x) 238 | middle_0 = time.time() 239 | z = self.hyper_xa(y) 240 | middle_1 = time.time() 241 | z_strings = self.entropy_bottleneck.compress(z) 242 | z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) 243 | middle_2 = time.time() 244 | 245 | latent = self.hyper_xs(z_hat) 246 | middle_3 = time.time() 247 | 248 | latent_means, latent_scales = latent.chunk(2, 1) 249 | 250 | y_slices = y.chunk(self.num_slices, 1) 251 | y_hat_slices = [] 252 | y_scales = [] 253 | y_means = [] 254 | 255 | cdf = self.gaussian_conditional.quantized_cdf.tolist() 256 | cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() 257 | offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() 258 | 259 | encoder = BufferedRansEncoder() 260 | symbols_list = [] 261 | indexes_list = [] 262 | y_strings = [] 263 | y_shape = y.shape[2:] 264 | 265 | for slice_index, y_slice in enumerate(y_slices): 266 | support_slices = (y_hat_slices[:slice_index]) 267 | 268 | mean_support = torch.cat([latent_means] + support_slices, dim=1) 269 | mu = self.cc_mean_transforms[slice_index](mean_support) 270 | mu = mu[:, :, :y_shape[0], :y_shape[1]] 271 | 272 | scale_support = torch.cat([latent_scales] + support_slices, dim=1) 273 | scale = self.cc_scale_transforms[slice_index](scale_support) 274 | scale = scale[:, :, :y_shape[0], :y_shape[1]] 275 | 276 | index = self.gaussian_conditional.build_indexes(scale) 277 | y_q_slice = self.gaussian_conditional.quantize(y_slice, "symbols", mu) 278 | 
y_hat_slice = y_q_slice + mu 279 | 280 | symbols_list.extend(y_q_slice.reshape(-1).tolist()) 281 | indexes_list.extend(index.reshape(-1).tolist()) 282 | lrp_support = torch.cat([mean_support, y_hat_slice], dim=1) 283 | lrp = self.lrp_transforms[slice_index](lrp_support) 284 | lrp = 0.5 * torch.tanh(lrp) 285 | y_hat_slice += lrp 286 | 287 | y_hat_slices.append(y_hat_slice) 288 | y_scales.append(scale) 289 | y_means.append(mu) 290 | 291 | encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets) 292 | y_string = encoder.flush() 293 | y_strings.append(y_string) 294 | 295 | end = time.time() 296 | 297 | WZ_encode_time = torch.tensor(middle_0-start) 298 | hyper_encode_time = torch.tensor(middle_1-middle_0) 299 | z_entropy_time = torch.tensor(middle_2- middle_1) 300 | hyper_decode_time = torch.tensor(middle_3- middle_2) 301 | WZ_entropy_time = torch.tensor(end - middle_3) 302 | 303 | total_time = torch.tensor(end - start) 304 | add_time = WZ_encode_time + hyper_encode_time + z_entropy_time + hyper_decode_time+ WZ_entropy_time 305 | out = {"Encoder_WZ_encode_time": WZ_encode_time, "Encoder_hyper_encode_time": hyper_encode_time, "Encoder_z_entropy_time": z_entropy_time, 306 | "Encoder_hyper_decode_time": hyper_decode_time, "Encoder_WZ_entropy_time": WZ_entropy_time, 307 | "Encoder_add_time": add_time, "Encoder_total_time": total_time} 308 | 309 | return y_hat_slices, {"strings": [y_strings, z_strings], "shape": z.size()[-2:]}, out 310 | 311 | 312 | def decoder_recon(y_hat_slices, side): 313 | y_hat = torch.cat(y_hat_slices, dim=1) 314 | fea_side = self.encode_side(side) 315 | x_hat = self.decode_x(torch.cat((y_hat, fea_side), 1)) 316 | x_hat = x_hat.clamp_(0, 1) 317 | return {"x_hat": x_hat} 318 | 319 | 320 | def decompress(self, strings, shape, side): 321 | start = time.time() 322 | z_hat = self.entropy_bottleneck.decompress(strings[1], shape) 323 | middle_0 = time.time() 324 | latent = self.hyper_xs(z_hat) 325 | middle_1 = time.time() 326 | 327 | latent_means, latent_scales = latent.chunk(2, 1) 328 | y_shape = [z_hat.shape[2] * 4, z_hat.shape[3] * 4] 329 | y_string = strings[0][0] 330 | y_hat_slices = [] 331 | cdf = self.gaussian_conditional.quantized_cdf.tolist() 332 | cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() 333 | offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() 334 | 335 | decoder = RansDecoder() 336 | decoder.set_stream(y_string) 337 | 338 | for slice_index in range(self.num_slices): 339 | support_slices = (y_hat_slices[:slice_index]) 340 | mean_support = torch.cat([latent_means] + support_slices, dim=1) 341 | mu = self.cc_mean_transforms[slice_index](mean_support) 342 | mu = mu[:, :, :y_shape[0], :y_shape[1]] 343 | 344 | scale_support = torch.cat([latent_scales] + support_slices, dim=1) 345 | scale = self.cc_scale_transforms[slice_index](scale_support) 346 | scale = scale[:, :, :y_shape[0], :y_shape[1]] 347 | 348 | index = self.gaussian_conditional.build_indexes(scale) 349 | 350 | rv = decoder.decode_stream(index.reshape(-1).tolist(), cdf, cdf_lengths, offsets) 351 | rv = torch.Tensor(rv).reshape(1, -1, y_shape[0], y_shape[1]) 352 | y_hat_slice = self.gaussian_conditional.dequantize(rv, mu) 353 | 354 | lrp_support = torch.cat([mean_support, y_hat_slice], dim=1) 355 | lrp = self.lrp_transforms[slice_index](lrp_support) 356 | lrp = 0.5 * torch.tanh(lrp) 357 | y_hat_slice += lrp 358 | 359 | y_hat_slices.append(y_hat_slice) 360 | 361 | y_hat = torch.cat(y_hat_slices, dim=1) 362 | middle_2 = time.time() 363 | 364 
| fea_side = self.encode_side(side) 365 | middle_3 = time.time() 366 | 367 | x_hat = self.decode_x(torch.cat((y_hat, fea_side), 1)).clamp_(0, 1) 368 | end = time.time() 369 | 370 | z_entropy_time = torch.tensor(middle_0-start) 371 | hyper_decode_time = torch.tensor(middle_1-middle_0) 372 | WZ_entropy_time = torch.tensor(middle_2- middle_1) 373 | side_time =torch.tensor( middle_3- middle_2) 374 | WZ_decode_time = torch.tensor(end - middle_3) 375 | 376 | total_time = torch.tensor(end - start) 377 | add_time = z_entropy_time + hyper_decode_time+ WZ_entropy_time + side_time + WZ_decode_time 378 | out = {"Decoder_z_entropy_time": z_entropy_time, "Decoder_hyper_decode_time": hyper_decode_time, "Decoder_WZ_entropy_time": WZ_entropy_time, 379 | "Decoder_side_time": side_time, "Decoder_WZ_decode_time": WZ_decode_time, "Decoder_add_time": add_time, "Decoder_total_time": total_time} 380 | 381 | return {"x_hat": x_hat}, out 382 | 383 | 384 | -------------------------------------------------------------------------------- /models/interpolation/IFNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .warplayer import warp 5 | from .refine import * 6 | 7 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): 8 | return nn.Sequential( 9 | torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1), 10 | nn.PReLU(out_planes) 11 | ) 12 | 13 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 14 | return nn.Sequential( 15 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 16 | padding=padding, dilation=dilation, bias=True), 17 | nn.PReLU(out_planes) 18 | ) 19 | 20 | class IFBlock(nn.Module): 21 | def __init__(self, in_planes, c=64): 22 | super(IFBlock, self).__init__() 23 | self.conv0 = nn.Sequential( 24 | conv(in_planes, c//2, 3, 2, 1), 25 | conv(c//2, c, 3, 2, 1), 26 | ) 27 | self.convblock = nn.Sequential( 28 | conv(c, c), 29 | conv(c, c), 30 | conv(c, c), 31 | conv(c, c), 32 | conv(c, c), 33 | conv(c, c), 34 | conv(c, c), 35 | conv(c, c), 36 | ) 37 | self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1) 38 | 39 | def forward(self, x, flow, scale): 40 | if scale != 1: 41 | x = F.interpolate(x, scale_factor = 1. / scale, mode="bilinear", align_corners=False) 42 | if flow != None: 43 | flow = F.interpolate(flow, scale_factor = 1. / scale, mode="bilinear", align_corners=False) * 1. 
/ scale 44 | x = torch.cat((x, flow), 1) 45 | x = self.conv0(x) 46 | x = self.convblock(x) + x 47 | tmp = self.lastconv(x) 48 | tmp = F.interpolate(tmp, scale_factor = scale * 2, mode="bilinear", align_corners=False) 49 | flow = tmp[:, :4] * scale * 2 50 | mask = tmp[:, 4:5] 51 | return flow, mask 52 | 53 | class IFNet(nn.Module): 54 | def __init__(self): 55 | super(IFNet, self).__init__() 56 | self.block0 = IFBlock(6, c=240) 57 | self.block1 = IFBlock(13+4, c=150) 58 | self.block2 = IFBlock(13+4, c=90) 59 | self.block_tea = IFBlock(16+4, c=90) 60 | self.contextnet = Contextnet() 61 | self.unet = Unet() 62 | 63 | def forward(self, x, scale=[4,2,1], timestep=0.5): 64 | img0 = x[:, :3] 65 | img1 = x[:, 3:6] 66 | gt = x[:, 6:] # In inference time, gt is None 67 | flow_list = [] 68 | merged = [] 69 | mask_list = [] 70 | warped_img0 = img0 71 | warped_img1 = img1 72 | flow = None 73 | loss_distill = 0 74 | stu = [self.block0, self.block1, self.block2] 75 | for i in range(3): 76 | if flow != None: 77 | flow_d, mask_d = stu[i](torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow, scale=scale[i]) 78 | flow = flow + flow_d 79 | mask = mask + mask_d 80 | else: 81 | flow, mask = stu[i](torch.cat((img0, img1), 1), None, scale=scale[i]) 82 | mask_list.append(torch.sigmoid(mask)) 83 | flow_list.append(flow) 84 | warped_img0 = warp(img0, flow[:, :2]) 85 | warped_img1 = warp(img1, flow[:, 2:4]) 86 | merged_student = (warped_img0, warped_img1) 87 | merged.append(merged_student) 88 | if gt.shape[1] == 3: 89 | flow_d, mask_d = self.block_tea(torch.cat((img0, img1, warped_img0, warped_img1, mask, gt), 1), flow, scale=1) 90 | flow_teacher = flow + flow_d 91 | warped_img0_teacher = warp(img0, flow_teacher[:, :2]) 92 | warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) 93 | mask_teacher = torch.sigmoid(mask + mask_d) 94 | merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) 95 | else: 96 | flow_teacher = None 97 | merged_teacher = None 98 | for i in range(3): 99 | merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) 100 | if gt.shape[1] == 3: 101 | loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01).float().detach() 102 | loss_distill += ((flow_teacher.detach() - flow_list[i]).abs() * loss_mask).mean() 103 | c0 = self.contextnet(img0, flow[:, :2]) 104 | c1 = self.contextnet(img1, flow[:, 2:4]) 105 | tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) 106 | res = tmp[:, :3] * 2 - 1 107 | merged[2] = torch.clamp(merged[2] + res, 0, 1) 108 | return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill 109 | -------------------------------------------------------------------------------- /models/interpolation/IFNet_m.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .warplayer import warp 5 | from .refine import * 6 | from deepspeed.profiling.flops_profiler import get_model_profile 7 | from ptflops import get_model_complexity_info 8 | 9 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): 10 | return nn.Sequential( 11 | torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1), 12 | nn.PReLU(out_planes) 13 | ) 14 | 15 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 16 | return nn.Sequential( 17 | nn.Conv2d(in_planes, 
out_planes, kernel_size=kernel_size, stride=stride, 18 | padding=padding, dilation=dilation, bias=True), 19 | nn.PReLU(out_planes) 20 | ) 21 | 22 | class IFBlock(nn.Module): 23 | def __init__(self, in_planes, c=64): 24 | super(IFBlock, self).__init__() 25 | self.conv0 = nn.Sequential( 26 | conv(in_planes, c//2, 3, 2, 1), 27 | conv(c//2, c, 3, 2, 1), 28 | ) 29 | self.convblock = nn.Sequential( 30 | conv(c, c), 31 | conv(c, c), 32 | conv(c, c), 33 | conv(c, c), 34 | conv(c, c), 35 | conv(c, c), 36 | conv(c, c), 37 | conv(c, c), 38 | ) 39 | self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1) 40 | 41 | def forward(self, x, flow, scale): 42 | if scale != 1: 43 | x = F.interpolate(x, scale_factor = 1. / scale, mode="bilinear", align_corners=False) 44 | if flow != None: 45 | flow = F.interpolate(flow, scale_factor = 1. / scale, mode="bilinear", align_corners=False) * 1. / scale 46 | x = torch.cat((x, flow), 1) 47 | x = self.conv0(x) 48 | x = self.convblock(x) + x 49 | tmp = self.lastconv(x) 50 | tmp = F.interpolate(tmp, scale_factor = scale * 2, mode="bilinear", align_corners=False) 51 | flow = tmp[:, :4] * scale * 2 52 | mask = tmp[:, 4:5] 53 | return flow, mask 54 | 55 | class IFNet_m(nn.Module): 56 | def __init__(self): 57 | super(IFNet_m, self).__init__() 58 | self.block0 = IFBlock(6+1, c=240) 59 | self.block1 = IFBlock(13+4+1, c=150) 60 | self.block2 = IFBlock(13+4+1, c=90) 61 | self.block_tea = IFBlock(16+4+1, c=90) 62 | self.contextnet = Contextnet() 63 | self.unet = Unet() 64 | 65 | def forward(self, x, scale=[4,2,1], timestep=0.5): 66 | timestep = (x[:, :1].clone() * 0 + 1) * timestep 67 | img0 = x[:, :3] 68 | img1 = x[:, 3:6] 69 | gt = x[:, 6:] # In inference time, gt is None 70 | flow_list = [] 71 | merged = [] 72 | mask_list = [] 73 | warped_img0 = img0 74 | warped_img1 = img1 75 | flow = None 76 | loss_distill = 0 77 | stu = [self.block0, self.block1, self.block2] 78 | for i in range(3): 79 | if flow != None: 80 | flow_d, mask_d = stu[i](torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask), 1), flow, scale=scale[i]) 81 | flow = flow + flow_d 82 | mask = mask + mask_d 83 | else: 84 | flow, mask = stu[i](torch.cat((img0, img1, timestep), 1), None, scale=scale[i]) 85 | mask_list.append(torch.sigmoid(mask)) 86 | flow_list.append(flow) 87 | warped_img0 = warp(img0, flow[:, :2]) 88 | warped_img1 = warp(img1, flow[:, 2:4]) 89 | merged_student = (warped_img0, warped_img1) 90 | merged.append(merged_student) 91 | if gt.shape[1] == 3: 92 | flow_d, mask_d = self.block_tea(torch.cat((img0, img1, timestep, warped_img0, warped_img1, mask, gt), 1), flow, scale=1) 93 | flow_teacher = flow + flow_d 94 | warped_img0_teacher = warp(img0, flow_teacher[:, :2]) 95 | warped_img1_teacher = warp(img1, flow_teacher[:, 2:4]) 96 | mask_teacher = torch.sigmoid(mask + mask_d) 97 | merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher) 98 | else: 99 | flow_teacher = None 100 | merged_teacher = None 101 | for i in range(3): 102 | merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) 103 | if gt.shape[1] == 3: 104 | loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01).float().detach() 105 | loss_distill += ((flow_teacher.detach() - flow_list[i]).abs() * loss_mask).mean() 106 | c0 = self.contextnet(img0, flow[:, :2]) 107 | c1 = self.contextnet(img1, flow[:, 2:4]) 108 | tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) 109 | res = tmp[:, :3] * 2 - 1 110 | merged[2] = 
torch.clamp(merged[2] + res, 0, 1) 111 | return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill 112 | 113 | if __name__ == '__main__': 114 | encoder = IFNet_m() 115 | H, W= 1920, 1024 116 | flops, macs, params = get_model_profile(encoder, input_shape=(1, 6, H, W), print_profile=False, detailed=False, warm_up=10, as_string=True,) 117 | print("Encoder - flops:{}, macs:{}, params:{}".format(flops, macs, params)) 118 | #macs, params = get_model_complexity_info(encoder, (6, H, W), as_strings=True, print_per_layer_stat=False) 119 | #print("Encoder - macs:{}, params:{}".format(macs, params)) -------------------------------------------------------------------------------- /models/interpolation/RIFE.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torch.optim import AdamW 5 | import torch.optim as optim 6 | import itertools 7 | from model.warplayer import warp 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from model.IFNet import * 10 | from model.IFNet_m import * 11 | import torch.nn.functional as F 12 | from model.loss import * 13 | from model.laplacian import * 14 | from model.refine import * 15 | 16 | device = torch.device("cuda") 17 | 18 | class Model: 19 | def __init__(self, local_rank=-1, arbitrary=False): 20 | if arbitrary == True: 21 | self.flownet = IFNet_m() 22 | else: 23 | self.flownet = IFNet() 24 | self.device() 25 | self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-3) # use large weight decay may avoid NaN loss 26 | self.epe = EPE() 27 | self.lap = LapLoss() 28 | self.sobel = SOBEL() 29 | if local_rank != -1: 30 | self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank) 31 | 32 | def train(self): 33 | self.flownet.train() 34 | 35 | def eval(self): 36 | self.flownet.eval() 37 | 38 | def device(self): 39 | self.flownet.to(device) 40 | 41 | def load_model(self, path, rank=0): 42 | def convert(param): 43 | return { 44 | k.replace("module.", ""): v 45 | for k, v in param.items() 46 | if "module." 
in k 47 | } 48 | 49 | if rank <= 0: 50 | self.flownet.load_state_dict(convert(torch.load('{}/flownet.pkl'.format(path)))) 51 | 52 | def save_model(self, path, rank=0): 53 | if rank == 0: 54 | torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path)) 55 | 56 | def inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False, timestep=0.5): 57 | imgs = torch.cat((img0, img1), 1) 58 | flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep) 59 | if TTA == False: 60 | return merged[2] 61 | else: 62 | flow2, mask2, merged2, flow_teacher2, merged_teacher2, loss_distill2 = self.flownet(imgs.flip(2).flip(3), scale_list, timestep=timestep) 63 | return (merged[2] + merged2[2].flip(2).flip(3)) / 2 64 | 65 | def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): 66 | for param_group in self.optimG.param_groups: 67 | param_group['lr'] = learning_rate 68 | img0 = imgs[:, :3] 69 | img1 = imgs[:, 3:] 70 | if training: 71 | self.train() 72 | else: 73 | self.eval() 74 | flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(torch.cat((imgs, gt), 1), scale=[4, 2, 1]) 75 | loss_l1 = (self.lap(merged[2], gt)).mean() 76 | loss_tea = (self.lap(merged_teacher, gt)).mean() 77 | if training: 78 | self.optimG.zero_grad() 79 | loss_G = loss_l1 + loss_tea + loss_distill * 0.01 80 | loss_G.backward() 81 | self.optimG.step() 82 | else: 83 | flow_teacher = flow[2] 84 | return merged[2], { 85 | 'merged_tea': merged_teacher, 86 | 'mask': mask, 87 | 'mask_tea': mask, 88 | 'flow': flow[2][:, :2], 89 | 'flow_tea': flow_teacher, 90 | 'loss_l1': loss_l1, 91 | 'loss_tea': loss_tea, 92 | 'loss_distill': loss_distill, 93 | } 94 | -------------------------------------------------------------------------------- /models/interpolation/refine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.optim as optim 5 | import itertools 6 | from .warplayer import warp 7 | import torch.nn.functional as F 8 | 9 | 10 | def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): 11 | return nn.Sequential( 12 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 13 | padding=padding, dilation=dilation, bias=True), 14 | nn.PReLU(out_planes) 15 | ) 16 | 17 | def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): 18 | return nn.Sequential( 19 | torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True), 20 | nn.PReLU(out_planes) 21 | ) 22 | 23 | class Conv2(nn.Module): 24 | def __init__(self, in_planes, out_planes, stride=2): 25 | super(Conv2, self).__init__() 26 | self.conv1 = conv(in_planes, out_planes, 3, stride, 1) 27 | self.conv2 = conv(out_planes, out_planes, 3, 1, 1) 28 | 29 | def forward(self, x): 30 | x = self.conv1(x) 31 | x = self.conv2(x) 32 | return x 33 | 34 | c = 16 35 | class Contextnet(nn.Module): 36 | def __init__(self): 37 | super(Contextnet, self).__init__() 38 | self.conv1 = Conv2(3, c) 39 | self.conv2 = Conv2(c, 2*c) 40 | self.conv3 = Conv2(2*c, 4*c) 41 | self.conv4 = Conv2(4*c, 8*c) 42 | 43 | def forward(self, x, flow): 44 | x = self.conv1(x) 45 | flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5 46 | f1 = warp(x, flow) 47 | x = self.conv2(x) 48 | flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", 
align_corners=False, recompute_scale_factor=False) * 0.5 49 | f2 = warp(x, flow) 50 | x = self.conv3(x) 51 | flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5 52 | f3 = warp(x, flow) 53 | x = self.conv4(x) 54 | flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5 55 | f4 = warp(x, flow) 56 | return [f1, f2, f3, f4] 57 | 58 | class Unet(nn.Module): 59 | def __init__(self): 60 | super(Unet, self).__init__() 61 | self.down0 = Conv2(17, 2*c) 62 | self.down1 = Conv2(4*c, 4*c) 63 | self.down2 = Conv2(8*c, 8*c) 64 | self.down3 = Conv2(16*c, 16*c) 65 | self.up0 = deconv(32*c, 8*c) 66 | self.up1 = deconv(16*c, 4*c) 67 | self.up2 = deconv(8*c, 2*c) 68 | self.up3 = deconv(4*c, c) 69 | self.conv = nn.Conv2d(c, 3, 3, 1, 1) 70 | 71 | def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1): 72 | s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1)) 73 | s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) 74 | s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) 75 | s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) 76 | x = self.up0(torch.cat((s3, c0[3], c1[3]), 1)) 77 | x = self.up1(torch.cat((x, s2), 1)) 78 | x = self.up2(torch.cat((x, s1), 1)) 79 | x = self.up3(torch.cat((x, s0), 1)) 80 | x = self.conv(x) 81 | return torch.sigmoid(x) 82 | -------------------------------------------------------------------------------- /models/interpolation/warplayer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | #device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 5 | backwarp_tenGrid = {} 6 | 7 | 8 | def warp(tenInput, tenFlow): 9 | device = tenInput.device 10 | k = (str(device), str(tenFlow.size())) 11 | if k not in backwarp_tenGrid: 12 | tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=device).view( 13 | 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1) 14 | tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=device).view( 15 | 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3]) 16 | backwarp_tenGrid[k] = torch.cat( 17 | [tenHorizontal, tenVertical], 1).to(device) 18 | 19 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), 20 | tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) 21 | 22 | g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) 23 | return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True) 24 | -------------------------------------------------------------------------------- /models/interpolation_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import itertools 5 | from torch.nn.parallel import DistributedDataParallel as DDP 6 | import torch.nn.functional as F 7 | from .interpolation.refine import * 8 | from .interpolation.warplayer import warp 9 | from .interpolation.IFNet import * 10 | from .interpolation.IFNet_m import * 11 | import torch.optim as optim 12 | import os 13 | 14 | class VideoInterpolationNet: 15 | def __init__(self, args, local_rank=-1, arbitrary=False, finetune=False): 16 | if arbitrary == True: 17 | self.flownet = IFNet_m() 18 | else: 19 | self.flownet = IFNet() 20 | if local_rank != -1: 21 | self.flownet = 
DDP(self.flownet, device_ids=[local_rank], output_device=local_rank) 22 | 23 | self.finetune = finetune 24 | if self.finetune: 25 | self.optimG = AdamW(self.flownet.parameters(), lr=args.flow_learning_rate, weight_decay=1e-3) 26 | self.lr_schedulerG = optim.lr_schedulerG.ReduceLROnPlateau(self.optimG, "min", patience=10, factor=0.2) 27 | 28 | def train(self): 29 | self.flownet.train() 30 | 31 | def eval(self): 32 | self.flownet.eval() 33 | 34 | def device(self, device): 35 | self.flownet.to(device) 36 | 37 | def load_model(self, path, rank=0, map_location="cpu"): 38 | def convert(param): 39 | return { 40 | k.replace("module.", ""): v 41 | for k, v in param.items() 42 | if "module." in k 43 | } 44 | if 'flownet.pkl' in str(path): 45 | checkpoint = convert(torch.load(path, map_location=map_location)) 46 | else: 47 | checkpoint = torch.load(path, map_location=map_location) 48 | self.flownet.load_state_dict(checkpoint) 49 | 50 | def save_model(self, path, is_best): 51 | if is_best: 52 | checkpoint = os.path.join(path, "RIFEflow_best.pth.tar") 53 | else: 54 | checkpoint = os.path.join(path, "RIFEflow.pth.tar") 55 | torch.save(self.flownet.state_dict(), checkpoint) 56 | 57 | def freeze_model(self): 58 | for param in self.flownet.parameters(): 59 | param.requires_grad = False 60 | 61 | ''' 62 | def inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False, timestep=0.5): 63 | imgs = torch.cat((img0, img1), 1) 64 | flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep) 65 | if TTA == False: 66 | return merged[2] 67 | else: 68 | flow2, mask2, merged2, flow_teacher2, merged_teacher2, loss_distill2 = self.flownet(imgs.flip(2).flip(3), scale_list, timestep=timestep) 69 | return (merged[2] + merged2[2].flip(2).flip(3)) / 2 70 | ''' 71 | 72 | def inference(self, img0, img1, scale_list=[4, 2, 1], timestep=0.5): 73 | imgs = torch.cat((img0, img1), 1) 74 | flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep) 75 | return merged[2] 76 | 77 | 78 | 79 | def coding_inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False, timestep=0.5): 80 | h, w = img0.size(2), img0.size(3) 81 | p = 64 # maximum 6 strides of 2 82 | new_h = (h + p - 1) // p * p 83 | new_w = (w + p - 1) // p * p 84 | padding_left = (new_w - w) // 2 85 | padding_right = new_w - w - padding_left 86 | padding_top = (new_h - h) // 2 87 | padding_bottom = new_h - h - padding_top 88 | img0_padded = F.pad( 89 | img0, 90 | (padding_left, padding_right, padding_top, padding_bottom), 91 | mode="constant", 92 | value=0, 93 | ) 94 | img1_padded = F.pad( 95 | img1, 96 | (padding_left, padding_right, padding_top, padding_bottom), 97 | mode="constant", 98 | value=0, 99 | ) 100 | imgs = torch.cat((img0_padded, img1_padded), 1) 101 | flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep) 102 | if TTA == False: 103 | final_img = F.pad(merged[2], (-padding_left, -padding_right, -padding_top, -padding_bottom)) 104 | return final_img 105 | else: 106 | flow2, mask2, merged2, flow_teacher2, merged_teacher2, loss_distill2 = self.flownet(imgs.flip(2).flip(3), scale_list, timestep=timestep) 107 | final_img = F.pad((merged[2] + merged2[2].flip(2).flip(3)) / 2, (-padding_left, -padding_right, -padding_top, -padding_bottom)) 108 | return final_img 109 | 110 | -------------------------------------------------------------------------------- /models/loss.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch 4 | import math 5 | import torch.nn.functional as F 6 | from torch.autograd import Variable 7 | from math import exp 8 | 9 | class DVCLoss(nn.Module): 10 | """Custom rate distortion loss with a Lagrangian parameter.""" 11 | 12 | def __init__(self): 13 | super().__init__() 14 | self.mse = nn.MSELoss() 15 | 16 | def forward(self, output, target, lmbda): 17 | N, _, H, W = target.size() 18 | out = {} 19 | num_pixels = N * H * W 20 | 21 | out["bpp_loss"] = sum((torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)) for likelihoods in output["likelihoods"].values()) 22 | 23 | out["mse_loss"] = self.mse(output["x_hat"], target) 24 | if out["mse_loss"].item() > 0: 25 | psnr = 10 * (np.log(1 * 1 / out["mse_loss"].item()) / np.log(10)) 26 | if math.isinf(psnr): 27 | psnr = np.zeros(1) 28 | else: 29 | psnr = np.zeros(1) 30 | out["psnr"] = psnr 31 | 32 | out["loss"] = lmbda * 255**2 * out["mse_loss"] + out["bpp_loss"] 33 | return out 34 | 35 | 36 | class DVC_MS_SSIM_Loss(nn.Module): 37 | def __init__(self, device, size_average=True, max_val=1): 38 | super().__init__() 39 | self.ms_ssim = MS_SSIM(size_average, max_val).to(device) 40 | 41 | def forward(self, output, target, lmbda): 42 | N, _, H, W = target.size() 43 | out = {} 44 | num_pixels = N * H * W 45 | 46 | out["bpp_loss"] = sum((torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)) for likelihoods in output["likelihoods"].values()) 47 | out["ms_ssim_loss"] = 1 - self.ms_ssim(output["x_hat"], target) 48 | 49 | if out["ms_ssim_loss"] > 0: 50 | ms_ssim_db = 10 * (np.log(1 * 1 / out["ms_ssim_loss"].item()) / np.log(10)) 51 | else: 52 | ms_ssim_db = np.zeros(1) 53 | out["ms_ssim_db"] = ms_ssim_db 54 | 55 | out["loss"] = lmbda * out["ms_ssim_loss"] + out["bpp_loss"] 56 | return out 57 | 58 | 59 | def gaussian(window_size, sigma): 60 | gauss = torch.Tensor( 61 | [exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 62 | return gauss/gauss.sum() 63 | 64 | 65 | def create_window(window_size, sigma, channel): 66 | _1D_window = gaussian(window_size, sigma).unsqueeze(1) 67 | _2D_window = _1D_window.mm( 68 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0) 69 | window = Variable(_2D_window.expand( 70 | channel, 1, window_size, window_size).contiguous()) 71 | return window 72 | 73 | 74 | class MS_SSIM(nn.Module): 75 | def __init__(self, size_average=True, max_val=255, device_id=0): 76 | super(MS_SSIM, self).__init__() 77 | self.size_average = size_average 78 | self.channel = 3 79 | self.max_val = max_val 80 | self.device_id = device_id 81 | 82 | def _ssim(self, img1, img2): 83 | 84 | _, c, w, h = img1.size() 85 | window_size = min(w, h, 11) 86 | sigma = 1.5 * window_size / 11 87 | 88 | window = create_window(window_size, sigma, self.channel) 89 | if self.device_id != None: 90 | window = window.cuda(self.device_id) 91 | 92 | mu1 = F.conv2d(img1, window, padding=window_size // 93 | 2, groups=self.channel) 94 | mu2 = F.conv2d(img2, window, padding=window_size // 95 | 2, groups=self.channel) 96 | 97 | mu1_sq = mu1.pow(2) 98 | mu2_sq = mu2.pow(2) 99 | mu1_mu2 = mu1*mu2 100 | 101 | sigma1_sq = F.conv2d( 102 | img1*img1, window, padding=window_size//2, groups=self.channel) - mu1_sq 103 | sigma2_sq = F.conv2d( 104 | img2*img2, window, padding=window_size//2, groups=self.channel) - mu2_sq 105 | sigma12 = F.conv2d(img1*img2, window, padding=window_size // 106 | 2, groups=self.channel) - 
mu1_mu2 107 | 108 | C1 = (0.01*self.max_val)**2 109 | C2 = (0.03*self.max_val)**2 110 | V1 = 2.0 * sigma12 + C2 111 | V2 = sigma1_sq + sigma2_sq + C2 112 | ssim_map = ((2*mu1_mu2 + C1)*V1)/((mu1_sq + mu2_sq + C1)*V2) 113 | mcs_map = V1 / V2 114 | if self.size_average: 115 | return ssim_map.mean(), mcs_map.mean() 116 | 117 | def ms_ssim(self, img1, img2, levels=5): 118 | 119 | weight = Variable(torch.Tensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])) 120 | msssim=Variable(torch.Tensor(levels,)) 121 | mcs=Variable(torch.Tensor(levels,)) 122 | # if self.device_id != None: 123 | # weight = weight.cuda(self.device_id) 124 | # weight = msssim.cuda(self.device_id) 125 | # weight = mcs.cuda(self.device_id) 126 | # print(weight.device) 127 | 128 | for i in range(levels): 129 | ssim_map, mcs_map=self._ssim(img1, img2) 130 | msssim[i]=ssim_map 131 | mcs[i]=mcs_map 132 | filtered_im1=F.avg_pool2d(img1, kernel_size=2, stride=2) 133 | filtered_im2=F.avg_pool2d(img2, kernel_size=2, stride=2) 134 | img1=filtered_im1 135 | img2=filtered_im2 136 | 137 | value=(torch.prod(mcs[0:levels-1]**weight[0:levels-1]) * 138 | (msssim[levels-1]**weight[levels-1])) 139 | return value 140 | 141 | 142 | def forward(self, img1, img2, levels=5): 143 | return self.ms_ssim(img1, img2, levels) 144 | 145 | -------------------------------------------------------------------------------- /models/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 InterDigital Communications, Inc. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch import Tensor 18 | 19 | def ste_round(x: Tensor) -> Tensor: 20 | """ 21 | Rounding with non-zero gradients. Gradients are approximated by replacing 22 | the derivative by the identity function. 23 | 24 | Used in `"Lossy Image Compression with Compressive Autoencoders" 25 | `_ 26 | 27 | .. 
note:: 28 | 29 | Implemented with the pytorch `detach()` reparametrization trick: 30 | 31 | `x_round = x_round - x.detach() + x` 32 | """ 33 | return torch.round(x) - x.detach() + x 34 | 35 | 36 | def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module: 37 | """1x1 convolution.""" 38 | return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride) 39 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import argparse 4 | import json 5 | import math 6 | import sys 7 | import struct 8 | import time 9 | 10 | from collections import defaultdict 11 | from pathlib import Path 12 | from typing import Any, Dict, List, Tuple, Union 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | 19 | from pytorch_msssim import ms_ssim 20 | from torch import Tensor 21 | from torch.cuda import amp 22 | from torch.utils.model_zoo import tqdm 23 | import compressai 24 | from compressai.datasets import RawVideoSequence, VideoFormat 25 | 26 | from compressai.transforms.functional import ( 27 | rgb2ycbcr, 28 | ycbcr2rgb, 29 | yuv_420_to_444, 30 | yuv_444_to_420, 31 | ) 32 | 33 | from compressai.zoo.pretrained import load_pretrained 34 | from models import * 35 | from torch.hub import load_state_dict_from_url 36 | Frame = Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, ...]] 37 | 38 | 39 | def collect_videos(rootpath: str) -> List[str]: 40 | video_files = [] 41 | 42 | if 'UVG' in rootpath: 43 | video_files.extend(Path(rootpath).glob(f"1024/*.yuv")) #f"*/*{ext}" 44 | elif 'MCL_JCV' in rootpath: 45 | video_files.extend(sorted(Path(rootpath).glob("1024/*.yuv"))) 46 | 47 | return video_files 48 | 49 | 50 | # TODO (racapef) duplicate from bench 51 | def to_tensors( 52 | frame: Tuple[np.ndarray, np.ndarray, np.ndarray], 53 | max_value: int = 1, 54 | device: str = "cpu", 55 | ) -> Frame: 56 | return tuple( 57 | torch.from_numpy(np.true_divide(c, max_value, dtype=np.float32)).to(device) 58 | for c in frame 59 | ) 60 | 61 | def aggregate_results(filepaths: List[Path]) -> Dict[str, Any]: 62 | metrics = defaultdict(list) 63 | 64 | # sum 65 | for f in filepaths: 66 | with f.open("r") as fd: 67 | data = json.load(fd) 68 | for k, v in data["results"].items(): 69 | metrics[k].append(v) 70 | 71 | # normalize 72 | agg = {k: np.mean(v) for k, v in metrics.items()} 73 | return agg 74 | 75 | def convert_yuv420_to_rgb( 76 | frame: Tuple[np.ndarray, np.ndarray, np.ndarray], device: torch.device, max_val: int 77 | ) -> Tensor: 78 | # yuv420 [0, 2**bitdepth-1] to rgb 444 [0, 1] only for now 79 | out = to_tensors(frame, device=str(device), max_value=max_val) 80 | out = yuv_420_to_444( 81 | tuple(c.unsqueeze(0).unsqueeze(0) for c in out), mode="bicubic" # type: ignore 82 | ) 83 | return ycbcr2rgb(out) # type: ignore 84 | 85 | 86 | def convert_rgb_to_yuv420(frame: Tensor) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: 87 | # yuv420 [0, 2**bitdepth-1] to rgb 444 [0, 1] only for now 88 | return yuv_444_to_420(rgb2ycbcr(frame), mode="avg_pool") 89 | 90 | 91 | def pad(x: Tensor, p: int = 2 ** (4 + 3)) -> Tuple[Tensor, Tuple[int, ...]]: 92 | h, w = x.size(2), x.size(3) 93 | new_h = (h + p - 1) // p * p 94 | new_w = (w + p - 1) // p * p 95 | padding_left = (new_w - w) // 2 96 | padding_right = new_w - w - padding_left 97 | padding_top = (new_h - h) // 2 98 | padding_bottom = new_h - h - padding_top 99 | 
padding = (padding_left, padding_right, padding_top, padding_bottom) 100 | x = F.pad(x, padding, mode="replicate") 101 | return x, padding 102 | 103 | 104 | def crop(x: Tensor, padding: Tuple[int, ...]) -> Tensor: 105 | return F.pad(x, tuple(-p for p in padding)) 106 | 107 | 108 | def filesize(filepath: str) -> int: 109 | if not Path(filepath).is_file(): 110 | raise ValueError(f'Invalid file "{filepath}".') 111 | return Path(filepath).stat().st_size 112 | 113 | 114 | def write_uints(fd, values, fmt=">{:d}I"): 115 | fd.write(struct.pack(fmt.format(len(values)), *values)) 116 | return len(values) * 4 117 | 118 | 119 | def write_uchars(fd, values, fmt=">{:d}B"): 120 | fd.write(struct.pack(fmt.format(len(values)), *values)) 121 | return len(values) * 1 122 | 123 | 124 | def read_uints(fd, n, fmt=">{:d}I"): 125 | sz = struct.calcsize("I") 126 | return struct.unpack(fmt.format(n), fd.read(n * sz)) 127 | 128 | 129 | def read_uchars(fd, n, fmt=">{:d}B"): 130 | sz = struct.calcsize("B") 131 | return struct.unpack(fmt.format(n), fd.read(n * sz)) 132 | 133 | 134 | def write_bytes(fd, values, fmt=">{:d}s"): 135 | if len(values) == 0: 136 | return 137 | fd.write(struct.pack(fmt.format(len(values)), values)) 138 | return len(values) * 1 139 | 140 | 141 | def read_bytes(fd, n, fmt=">{:d}s"): 142 | sz = struct.calcsize("s") 143 | return struct.unpack(fmt.format(n), fd.read(n * sz))[0] 144 | 145 | 146 | def read_body(fd): 147 | lstrings = [] 148 | shape = read_uints(fd, 2) 149 | n_strings = read_uints(fd, 1)[0] 150 | for _ in range(n_strings): 151 | s = read_bytes(fd, read_uints(fd, 1)[0]) 152 | lstrings.append([s]) 153 | 154 | return lstrings, shape 155 | 156 | 157 | def write_body(fd, shape, out_strings): 158 | bytes_cnt = 0 159 | bytes_cnt = write_uints(fd, (shape[0], shape[1], len(out_strings))) 160 | for s in out_strings: 161 | bytes_cnt += write_uints(fd, (len(s[0]),)) 162 | bytes_cnt += write_bytes(fd, s[0]) 163 | return bytes_cnt 164 | 165 | 166 | def compute_metrics_for_frame( 167 | org_frame: Frame, 168 | rec_frame: Tensor, 169 | device: str = "cpu", 170 | max_val: int = 255, 171 | index: int = 1, 172 | ) -> Dict[str, Any]: 173 | out: Dict[str, Any] = {} 174 | 175 | # YCbCr metrics 176 | org_yuv = to_tensors(org_frame, device=str(device), max_value=max_val) 177 | org_yuv = tuple(p.unsqueeze(0).unsqueeze(0) for p in org_yuv) # type: ignore 178 | rec_yuv = convert_rgb_to_yuv420(rec_frame) 179 | for i, component in enumerate("yuv"): 180 | org = (org_yuv[i] * max_val).clamp(0, max_val).round() 181 | rec = (rec_yuv[i] * max_val).clamp(0, max_val).round() 182 | out[f"psnr-{component}"] = 20 * np.log10(max_val) - 10 * torch.log10( 183 | (org - rec).pow(2).mean() 184 | ) 185 | out["psnr-yuv"] = (4 * out["psnr-y"] + out["psnr-u"] + out["psnr-v"]) / 6 186 | 187 | # RGB metrics 188 | org_rgb = convert_yuv420_to_rgb( 189 | org_frame, device, max_val 190 | ) # ycbcr2rgb(yuv_420_to_444(org_frame, mode="bicubic")) # type: ignore 191 | org_rgb = (org_rgb * max_val).clamp(0, max_val).round() 192 | rec_frame = (rec_frame * max_val).clamp(0, max_val).round() 193 | mse_rgb = (org_rgb - rec_frame).pow(2).mean() 194 | psnr_rgb = 20 * np.log10(max_val) - 10 * torch.log10(mse_rgb) 195 | 196 | ms_ssim_rgb = ms_ssim(org_rgb, rec_frame, data_range=max_val) 197 | out.update({"ms-ssim-rgb": ms_ssim_rgb, "mse-rgb": mse_rgb, "psnr-rgb": psnr_rgb}) 198 | return out 199 | 200 | 201 | 202 | def compute_si_metrics_for_frame( 203 | org_frame: Frame, 204 | rec_frame: Tensor, 205 | device: str = "cpu", 206 | max_val: int = 255, 207 | 
) -> Dict[str, Any]: 208 | out: Dict[str, Any] = {} 209 | 210 | # RGB metrics 211 | org_rgb = convert_yuv420_to_rgb( 212 | org_frame, device, max_val 213 | ) # ycbcr2rgb(yuv_420_to_444(org_frame, mode="bicubic")) # type: ignore 214 | org_rgb = (org_rgb * max_val).clamp(0, max_val).round() 215 | rec_frame = (rec_frame * max_val).clamp(0, max_val).round() 216 | mse_rgb = (org_rgb - rec_frame).pow(2).mean() 217 | psnr_rgb = 20 * np.log10(max_val) - 10 * torch.log10(mse_rgb) 218 | 219 | ms_ssim_rgb = ms_ssim(org_rgb, rec_frame, data_range=max_val) 220 | out.update({"si-ms-ssim-rgb": ms_ssim_rgb, "si-mse-rgb": mse_rgb, "si-psnr-rgb": psnr_rgb}) 221 | 222 | return out 223 | 224 | def estimate_bits_frame(likelihoods) -> float: 225 | bpp = sum( 226 | (torch.log(lkl[k]).sum() / (-math.log(2))) 227 | for lkl in likelihoods.values() 228 | for k in ("y", "z") 229 | ) 230 | return bpp 231 | 232 | def compute_bpp(likelihoods, num_pixels: int) -> float: 233 | bits_per_frame = sum( 234 | (torch.log(lkl).sum() / (-math.log(2))) 235 | for lkl in likelihoods.values() 236 | ) 237 | bpp = bits_per_frame / num_pixels 238 | return bits_per_frame, bpp 239 | 240 | @torch.no_grad() 241 | def eval_model(interpolation_net, BFrameCompressor:nn.Module, IFrameCompressor:nn.Module, 242 | sequence: Path, binpath: Path, **args: Any) -> Dict[str, Any]: 243 | import time 244 | org_seq = RawVideoSequence.from_file(str(sequence)) 245 | 246 | if org_seq.format != VideoFormat.YUV420: 247 | raise NotImplementedError(f"Unsupported video format: {org_seq.format}") 248 | 249 | device = next(BFrameCompressor.parameters()).device 250 | max_val = 2**org_seq.bitdepth - 1 251 | results = defaultdict(list) 252 | keep_binaries = args["keep_binaries"] 253 | num_frames = args["vframes"] 254 | num_gop = args["GOP"] 255 | frame_arbitrary = args["frame_arbitrary"] 256 | with_interpolation = args["with_interpolation"] 257 | num_pixels = org_seq.height * org_seq.width 258 | print("frame rate:", org_seq.framerate) 259 | intra = args["intra"] 260 | 261 | if with_interpolation and not frame_arbitrary: 262 | frames_idx_list, ref_idx_dict = specific_frame_structure(num_gop) 263 | reconstructions = [] 264 | 265 | f = binpath.open("wb") 266 | 267 | print(f" encoding {sequence.stem}", file=sys.stderr) 268 | # write original image size 269 | write_uints(f, (org_seq.height, org_seq.width)) 270 | # write original bitdepth 271 | write_uchars(f, (org_seq.bitdepth,)) 272 | # write number of coded frames 273 | write_uints(f, (num_frames,)) 274 | with tqdm(total=num_frames) as pbar: 275 | for i in range(num_frames): 276 | x_cur = convert_yuv420_to_rgb(org_seq[i], device, max_val) 277 | x_cur, padding = pad(x_cur) 278 | 279 | if i % num_gop == 0: 280 | start = time.time() 281 | enc_info = IFrameCompressor.compress(x_cur) 282 | enc_time = time.time() - start 283 | write_body(f, enc_info["shape"], enc_info["strings"]) 284 | start = time.time() 285 | x_rec = IFrameCompressor.decompress(enc_info["strings"], enc_info["shape"])["x_hat"] 286 | dec_time = time.time() - start 287 | 288 | first_rec = x_rec 289 | last_key_frame = convert_yuv420_to_rgb(org_seq[i+num_gop], device, max_val) 290 | last_key_frame, _ = pad(last_key_frame) 291 | last_enc_info = IFrameCompressor.compress(last_key_frame) 292 | last_x_rec = IFrameCompressor.decompress(last_enc_info["strings"], last_enc_info["shape"])["x_hat"] 293 | reconstructions = [] 294 | reconstructions.append(x_rec) 295 | else: 296 | if with_interpolation: 297 | cur_interpolation_idx = frames_idx_list[i%num_gop-1] 298 | 
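eval_model visits the B frames of each GOP in the hierarchical order returned by specific_frame_structure (defined further down in this file), so both references of a frame are always decoded before the frame itself and the side information can be interpolated at timestep=0.5. For illustration, the structure for GOP = 8, with values copied from that function:

```
# Hierarchical-B layout for GOP = 8 (frames 0 and 8 are the key frames).
frames_idx_list = [4, 2, 1, 3, 6, 5, 7]        # coding order of the B frames
ref_idx_dict = {                               # frame -> [left ref, right ref]
    4: [0, 8],
    2: [0, 4], 6: [4, 8],
    1: [0, 2], 3: [2, 4], 5: [4, 6], 7: [6, 8],
}
# Frame 4 is interpolated from the two key frames, frame 2 from (0, 4), and so
# on; every reference pair brackets its frame symmetrically, hence timestep=0.5.
```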
left_ref_idx, right_ref_idx = ref_idx_dict[cur_interpolation_idx] 299 | if left_ref_idx == 0: 300 | left_x_rec = first_rec 301 | else: 302 | cur_pos_in_frame_idx_list = frames_idx_list.index(left_ref_idx) 303 | left_x_rec = reconstructions[cur_pos_in_frame_idx_list+1] 304 | 305 | if right_ref_idx == num_gop: 306 | right_x_rec = last_x_rec 307 | else: 308 | cur_pos_in_frame_idx_list = frames_idx_list.index(right_ref_idx) 309 | right_x_rec = reconstructions[cur_pos_in_frame_idx_list+1] 310 | x_cur = convert_yuv420_to_rgb(org_seq[cur_interpolation_idx+(i//num_gop)*num_gop], device, max_val) 311 | x_cur, padding = pad(x_cur) 312 | start = time.time() 313 | y, enc_info = BFrameCompressor.compress(x_cur) 314 | enc_time = time.time() - start 315 | write_body(f, enc_info["shape"], enc_info["strings"]) 316 | 317 | start = time.time() 318 | mid_key = interpolation_net.inference(left_x_rec, right_x_rec, timestep=0.5) 319 | x_rec = BFrameCompressor.decompress(enc_info["strings"], enc_info["shape"], mid_key)["x_hat"] 320 | dec_time = time.time() - start 321 | reconstructions.append(x_rec) 322 | else: 323 | start = time.time() 324 | y, enc_info = BFrameCompressor.compress(x_cur) 325 | enc_time = time.time() - start 326 | write_body(f, enc_info["shape"], enc_info["strings"]) 327 | start = time.time() 328 | mid_key = torch.cat((first_rec, last_x_rec), 1) 329 | x_rec = BFrameCompressor.decompress(enc_info["strings"], enc_info["shape"], mid_key)["x_hat"] 330 | dec_time = time.time() - start 331 | 332 | x_rec = x_rec.clamp(0, 1) 333 | if with_interpolation and (i % num_gop != 0): 334 | metrics = compute_metrics_for_frame(org_seq[cur_interpolation_idx+(i//num_gop)*num_gop], crop(x_rec, padding), device, max_val) 335 | else: 336 | metrics = compute_metrics_for_frame(org_seq[i], crop(x_rec, padding), device, max_val) 337 | 338 | if intra or i%num_gop==0: 339 | metrics["key_encoding_time"] = torch.tensor(enc_time) 340 | metrics["key_decoding_time"] = torch.tensor(dec_time) 341 | else: 342 | metrics["inter_encoding_time"] = torch.tensor(enc_time) 343 | metrics["inter_decoding_time"] = torch.tensor(dec_time) 344 | 345 | #print(metrics) 346 | for k, v in metrics.items(): 347 | results[k].append(v) 348 | pbar.update(1) 349 | f.close() 350 | 351 | seq_results: Dict[str, Any] = { 352 | k: torch.mean(torch.stack(v)) for k, v in results.items() 353 | } 354 | 355 | seq_results["bitrate"] = ( 356 | float(filesize(binpath)) * 8 * org_seq.framerate / (num_frames * 1000) 357 | ) 358 | seq_results["bpp"] = (float(filesize(binpath)) * 8 / (num_frames * num_pixels)) 359 | 360 | 361 | if not keep_binaries: 362 | binpath.unlink() 363 | 364 | for k, v in seq_results.items(): 365 | if isinstance(v, torch.Tensor): 366 | seq_results[k] = v.item() 367 | return seq_results 368 | 369 | 370 | def specific_frame_structure(num_gop): 371 | num_frames = num_gop + 1 #+1-->because add the next key frame 372 | frames_idx_dict = {3:[1], 5:[2,1,3], 9:[4,2,1,3,6,5,7], 17:[8,4,2,1,3,6,5,7,12,10,9,11,14,13,15], 373 | 33:[16, 8,4,2,1,3,6,5,7,12,10,9,11,14,13,15,24,20,18,17,19,22,21,23,28,26,25,27,30,29,31]} 374 | #timestep = 0.5 375 | #odd_number: -1, +1, even number 376 | ref_idx_dict = {1:[0,2], 3:[2, 4], 5:[4,6], 7:[6, 8], 9:[8, 10], 11:[10, 12], 13:[12, 14], 15:[14, 16], 377 | 17:[16, 18], 19:[18,20], 21:[20,22], 23:[22,24], 25:[24, 26], 27:[26, 28], 29:[28, 30], 31:[30, 32], 378 | 6:[4, 8], 10:[8, 12], 12:[8, 16], 14:[12, 16], 18:[16, 20], 20:[16, 24], 22:[20, 24], 24:[16, 32], 379 | 26:[24, 28], 28:[24, 32], 30:[28, 32], 380 | 2:[0, 4], 4:[0, 
8], 8:[0, 16], 16:[0, 32]} 381 | return frames_idx_dict[num_frames], ref_idx_dict 382 | 383 | 384 | @torch.no_grad() 385 | def eval_model_entropy_estimation(interpolation_net, BFrameCompressor:nn.Module, IFrameCompressor:nn.Module, 386 | sequence: Path, **args: Any) -> Dict[str, Any]: 387 | org_seq = RawVideoSequence.from_file(str(sequence)) 388 | 389 | if org_seq.format != VideoFormat.YUV420: 390 | raise NotImplementedError(f"Unsupported video format: {org_seq.format}") 391 | 392 | device = next(IFrameCompressor.parameters()).device 393 | num_frames = args["vframes"] 394 | print("video length:{}, frame rate:{}".format(len(org_seq), org_seq.framerate)) 395 | num_pixels = org_seq.height * org_seq.width 396 | max_val = 2**org_seq.bitdepth - 1 397 | results = defaultdict(list) 398 | print(f" encoding {sequence.stem}", file=sys.stderr) 399 | 400 | num_gop = args["GOP"] 401 | with_interpolation = args["with_interpolation"] 402 | frames_idx_list, ref_idx_dict = specific_frame_structure(num_gop) 403 | 404 | with tqdm(total=num_frames) as pbar: #97: 0-96 405 | for i in range(num_frames): 406 | x_cur = convert_yuv420_to_rgb(org_seq[i], device, max_val) 407 | x_cur, padding = pad(x_cur) 408 | 409 | if i % num_gop == 0: 410 | first_key = IFrameCompressor(x_cur) 411 | last_key_frame = convert_yuv420_to_rgb(org_seq[i+num_gop], device, max_val) 412 | last_key_frame, _ = pad(last_key_frame) 413 | last_key = IFrameCompressor(last_key_frame) 414 | 415 | x_rec, likelihoods = first_key["x_hat"], first_key["likelihoods"] 416 | reconstructions = [x_rec] 417 | current = [x_cur] 418 | side_info = [] 419 | else: 420 | cur_interpolation_idx = frames_idx_list[i%num_gop-1] 421 | left_ref_idx, right_ref_idx = ref_idx_dict[cur_interpolation_idx] 422 | if left_ref_idx == 0: 423 | left_x_rec = first_key["x_hat"] 424 | else: 425 | cur_pos_in_frame_idx_list = frames_idx_list.index(left_ref_idx) 426 | left_x_rec = reconstructions[cur_pos_in_frame_idx_list+1] 427 | 428 | if right_ref_idx == num_gop: 429 | right_x_rec = last_key["x_hat"] 430 | else: 431 | cur_pos_in_frame_idx_list = frames_idx_list.index(right_ref_idx) 432 | right_x_rec = reconstructions[cur_pos_in_frame_idx_list+1] 433 | x_cur = convert_yuv420_to_rgb(org_seq[cur_interpolation_idx+(i//num_gop)*num_gop], device, max_val) 434 | x_cur, padding = pad(x_cur) 435 | if with_interpolation: 436 | mid_key = interpolation_net.inference(left_x_rec, right_x_rec, timestep=0.5) 437 | side_info.append(mid_key.clamp(0, 1)) 438 | else: 439 | mid_key = torch.cat([left_x_rec, right_x_rec], dim=1) 440 | 441 | out = BFrameCompressor(x_cur, mid_key) 442 | x_rec, likelihoods = out["x_hat"], out["likelihoods"] 443 | reconstructions.append(x_rec) 444 | current.append(x_cur) 445 | 446 | x_rec = x_rec.clamp(0, 1) 447 | if i % num_gop != 0: 448 | org_frame = org_seq[cur_interpolation_idx+(i//num_gop)*num_gop] 449 | 450 | metrics = compute_metrics_for_frame(org_frame, crop(x_rec, padding), device, max_val, i) 451 | metrics["bitrate"], metrics["bpp"] = compute_bpp(likelihoods, num_pixels) 452 | if with_interpolation and i%num_gop!=0: 453 | mid_key = mid_key.clamp(0, 1) 454 | si_psnr_metrics = compute_si_metrics_for_frame(org_frame, crop(mid_key, padding), device, max_val) 455 | metrics.update(si_psnr_metrics) 456 | 457 | for k, v in metrics.items(): 458 | results[k].append(v) 459 | pbar.update(1) 460 | 461 | seq_results: Dict[str, Any] = { 462 | k: torch.mean(torch.stack(v)) for k, v in results.items() 463 | } 464 | seq_results["bitrate"] = float(seq_results["bitrate"]) * 
org_seq.framerate / 1000 465 | for k, v in seq_results.items(): 466 | if isinstance(v, torch.Tensor): 467 | seq_results[k] = v.item() 468 | return seq_results 469 | 470 | 471 | def run_inference( 472 | filepaths, 473 | interpolation_net, 474 | BFrameCompressor: nn.Module, 475 | IFrameCompressor: nn.Module, 476 | outputdir: Path, 477 | force: bool = False, 478 | entropy_estimation: bool = False, 479 | trained_net: str = "", 480 | description: str = "", 481 | **args: Any, 482 | ) -> Dict[str, Any]: 483 | results_paths = [] 484 | 485 | for filepath in filepaths: 486 | sequence_metrics_path = Path(outputdir) / f"{filepath.stem}-{trained_net}.json" 487 | results_paths.append(sequence_metrics_path) 488 | 489 | if force: 490 | sequence_metrics_path.unlink(missing_ok=True) 491 | if sequence_metrics_path.is_file(): 492 | continue 493 | 494 | with amp.autocast(enabled=args["half"]): 495 | with torch.no_grad(): 496 | if entropy_estimation: 497 | metrics = eval_model_entropy_estimation(interpolation_net, BFrameCompressor, IFrameCompressor, filepath, **args) 498 | else: 499 | encode_folder = os.path.join(outputdir, "encoded_files") 500 | Path(encode_folder).mkdir(parents=True, exist_ok=True) 501 | sequence_bin = Path(encode_folder) / f"{filepath.stem}-{trained_net}.bin" #sequence_metrics_path.with_suffix(".bin") 502 | print(sequence_bin) 503 | metrics = eval_model(interpolation_net, BFrameCompressor, IFrameCompressor, filepath, sequence_bin, **args) 504 | with sequence_metrics_path.open("wb") as f: 505 | output = { 506 | "source": filepath.stem, 507 | "name": args["BFrameModel"], 508 | "description": f"Inference ({description})", 509 | "results": metrics, 510 | } 511 | f.write(json.dumps(output, indent=2).encode()) 512 | results = aggregate_results(results_paths) 513 | return results 514 | 515 | def create_parser() -> argparse.ArgumentParser: 516 | parser = argparse.ArgumentParser( 517 | description="Video compression network evaluation.", 518 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 519 | ) 520 | parser.add_argument("-d", "--dataset", type=str, required=True, help="sequences directory") 521 | parser.add_argument("--output", type=str, help="output directory") 522 | parser.add_argument( 523 | "-im", 524 | "--IFrameModel", 525 | default="mbt2018", 526 | choices=models_arch.keys(), 527 | help="Model architecture (default: %(default)s)", 528 | ) 529 | parser.add_argument( 530 | "-bm", 531 | "--BFrameModel", 532 | default="DVC-ScalePrior", 533 | choices=models_arch.keys(), 534 | help="Model architecture (default: %(default)s)", 535 | ) 536 | parser.add_argument("-iq", "--IFrame_quality", type=int, default=4, help='Model quality') 537 | parser.add_argument("-bq", "--BFrame_quality", type=int, default=1, help='Model quality') 538 | parser.add_argument("--vframes", type=int, default=96, help='Model quality') 539 | parser.add_argument( 540 | "--GOP", 541 | type=int, 542 | default=8, 543 | help="GOP (default: %(default)s)", 544 | ) 545 | parser.add_argument("--b_model_path", type=str, help="Path to a checkpoint") 546 | parser.add_argument("--i_model_path", type=str, help="Path to a checkpoint") 547 | parser.add_argument("--flownet_model_path", type=str, default="../arXiv2020-RIFE/train_log/RIFE_m_train_log/flownet.pkl", help="Path to a checkpoint") 548 | parser.add_argument( 549 | "-f", "--force", action="store_true", help="overwrite previous runs" 550 | ) 551 | parser.add_argument("--cuda", action="store_true", help="use cuda") 552 | parser.add_argument("--half", action="store_true", help="use 
AMP") 553 | parser.add_argument( 554 | "--entropy-estimation", 555 | action="store_true", 556 | help="use evaluated entropy estimation (no entropy coding)", 557 | ) 558 | parser.add_argument( 559 | "-c", 560 | "--entropy-coder", 561 | choices=compressai.available_entropy_coders(), 562 | default=compressai.available_entropy_coders()[0], 563 | help="entropy coder (default: %(default)s)", 564 | ) 565 | parser.add_argument( 566 | "--keep_binaries", 567 | action="store_true", 568 | help="keep bitstream files in output directory", 569 | ) 570 | parser.add_argument( 571 | "-v", 572 | "--verbose", 573 | action="store_true", 574 | help="verbose mode", 575 | ) 576 | parser.add_argument("--metric", type=str, default="mse", help="metric: mse, ms-ssim") 577 | parser.add_argument("--side_input_channels", type=int, default=3, help="use cuda") 578 | parser.add_argument("--with_interpolation", action="store_true", help='whether use extrapolation network') 579 | parser.add_argument("--num_slices", type=int, default=8, help="use cuda") 580 | return parser 581 | 582 | 583 | def main(args: Any = None) -> None: 584 | if args is None: 585 | args = sys.argv[1:] 586 | parser = create_parser() 587 | args = parser.parse_args(args) 588 | 589 | 590 | description = ( 591 | "entropy-estimation" if args.entropy_estimation else args.entropy_coder 592 | ) 593 | filepaths = collect_videos(args.dataset) 594 | if len(filepaths) == 0: 595 | print("Error: no video found in directory.", file=sys.stderr) 596 | raise SystemExit(1) 597 | 598 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" 599 | #key frame compressor 600 | IFrameCompressor = models_arch[args.IFrameModel](*cfgs[args.IFrameModel][args.IFrame_quality]) 601 | IFrameCompressor = IFrameCompressor.to(device) 602 | url = model_urls[args.IFrameModel][args.metric][args.IFrame_quality] 603 | checkpoint = load_state_dict_from_url(url, progress=True, map_location=device) 604 | checkpoint = load_pretrained(checkpoint) 605 | IFrameCompressor.load_state_dict(checkpoint) 606 | IFrameCompressor.eval() 607 | 608 | 609 | if args.b_model_path: 610 | if args.with_interpolation: 611 | interpolation_net = VideoInterpolationNet(args, arbitrary=True) 612 | print("Loading Video Interpolation model:", args.flownet_model_path) 613 | interpolation_net.load_model(args.flownet_model_path) 614 | interpolation_net.device(device) 615 | interpolation_net.eval() 616 | else: 617 | interpolation_net = None 618 | 619 | #wyner-ziv encoder and decoder 620 | BFrameCompressor = models_arch[args.BFrameModel](*cfgs[args.BFrameModel][args.BFrame_quality], args.side_input_channels, num_slices=args.num_slices) 621 | print(args.BFrameModel, BFrameCompressor.num_slices) 622 | BFrameCompressor = BFrameCompressor.to(device) 623 | print("Loading B frame model: ", args.b_model_path) 624 | checkpoint = torch.load(args.b_model_path, map_location=device) 625 | BFrameCompressor.load_state_dict(checkpoint["state_dict"]) 626 | BFrameCompressor.update(force=True) 627 | BFrameCompressor.eval() 628 | else: 629 | interpolation_net = None 630 | BFrameCompressor = None 631 | 632 | # create output directory 633 | outputdir = args.output 634 | Path(outputdir).mkdir(parents=True, exist_ok=True) 635 | results = defaultdict(list) 636 | args_dict = vars(args) 637 | 638 | trained_net = f"{args.BFrameModel}-{args.metric}-{description}" 639 | 640 | 641 | metrics = run_inference(filepaths, interpolation_net, BFrameCompressor, IFrameCompressor, 642 | outputdir, trained_net=trained_net, description=description, 
**args_dict,) 643 | for k, v in metrics.items(): 644 | results[k].append(v) 645 | 646 | output = { 647 | "name": f"{args.BFrameModel}-{args.metric}", 648 | "description": f"Inference ({description})", 649 | "results": results, 650 | } 651 | 652 | with (Path(f"{outputdir}/{args.BFrameModel}-{description}.json")).open("wb") as f: 653 | f.write(json.dumps(output, indent=2).encode()) 654 | #print(json.dumps(output, indent=2)) 655 | 656 | 657 | if __name__ == "__main__": 658 | main(sys.argv[1:]) 659 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import random 4 | import shutil 5 | import sys 6 | import os 7 | import yaml 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.optim as optim 13 | from torch.utils.data import DataLoader 14 | from torchvision import transforms 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | from tqdm import tqdm 18 | from lib.utils import get_output_folder, Vimeo, AverageMeter, save_checkpoint, CustomDataParallel 19 | from models import * 20 | from torch.hub import load_state_dict_from_url 21 | from compressai.zoo.pretrained import load_pretrained 22 | 23 | 24 | def configure_optimizers(net, interpolation_net, args): 25 | """Separate parameters for the main optimizer and the auxiliary optimizer. 26 | Return two optimizers""" 27 | 28 | parameters = set( 29 | n 30 | for n, p in net.named_parameters() 31 | if not n.endswith(".quantiles") and p.requires_grad 32 | ) 33 | aux_parameters = set( 34 | n 35 | for n, p in net.named_parameters() 36 | if n.endswith(".quantiles") and p.requires_grad 37 | ) 38 | # Make sure we don't have an intersection of parameters 39 | params_dict = dict(p for p in net.named_parameters() if p[1].requires_grad) 40 | #dict(net.named_parameters()) 41 | inter_params = parameters & aux_parameters 42 | union_params = parameters | aux_parameters 43 | 44 | assert len(inter_params) == 0 45 | assert len(union_params) - len(params_dict.keys()) == 0 46 | if args.flow_finetune and args.with_interpolation: 47 | optimizer = optim.Adam([{'params': (params_dict[n] for n in sorted(parameters)), 'lr':args.learning_rate}, 48 | {'params': interpolation_net.flownet.parameters(), 'lr': args.flow_learning_rate}] 49 | ) 50 | else: 51 | optimizer = optim.Adam( 52 | (params_dict[n] for n in sorted(parameters)), 53 | lr=args.learning_rate, 54 | ) 55 | aux_optimizer = optim.Adam( 56 | (params_dict[n] for n in sorted(aux_parameters)), 57 | lr=args.aux_learning_rate, 58 | ) 59 | return optimizer, aux_optimizer 60 | 61 | def train_one_epoch(IFrameCompressor, BFrameCompressor, interpolation_net, criterion, train_dataloader, 62 | optimizer, aux_optimizer, epoch, iterations, clip_max_norm, args): 63 | IFrameCompressor.train() 64 | BFrameCompressor.train() 65 | if args.with_interpolation: 66 | interpolation_net.train() 67 | device = next(BFrameCompressor.parameters()).device 68 | 69 | loss = AverageMeter('Loss', ':.4e') 70 | bpp_loss = AverageMeter('BppLoss', ':.4e') 71 | aux_loss = AverageMeter('AuxLoss', ':.4e') 72 | 73 | metric_dB_name = 'psnr' if args.metric == "mse" else "ms_ssim_db" 74 | metric_name = "mse_loss" if args.metric == "mse" else "ms_ssim_loss" 75 | metric_dB = AverageMeter(metric_dB_name, ':.4e') 76 | metric_loss = AverageMeter(args.metric, ':.4e') 77 | 78 | train_dataloader = tqdm(train_dataloader) 79 | print('Train epoch:', epoch) 80 | for i, 
images in enumerate(train_dataloader): 81 | rand_num = random.randint(3, len(images)) 82 | images_index = random.sample(range(len(images)), rand_num) 83 | images_index.sort(reverse=False) #升序 84 | images = [images[idx].to(device) for idx in images_index] 85 | num_p = len(images)-1 86 | 87 | optimizer.zero_grad() 88 | aux_optimizer.zero_grad() 89 | 90 | for imgidx in range(num_p): 91 | if imgidx == 0: 92 | # I frame compression. 93 | first_key = IFrameCompressor(images[0]) 94 | last_key = IFrameCompressor(images[-1]) 95 | else: 96 | # B frame compression 97 | optimizer.zero_grad() 98 | aux_optimizer.zero_grad() 99 | if args.with_interpolation: 100 | mid_key = interpolation_net.inference(first_key["x_hat"], last_key["x_hat"], timestep=((imgidx)/num_p)) 101 | if not args.flow_finetune: 102 | mid_key = mid_key.detach() 103 | else: 104 | mid_key = torch.cat((first_key["x_hat"], last_key["x_hat"]), 1) 105 | 106 | out = BFrameCompressor(images[imgidx], mid_key) 107 | 108 | out_criterion = criterion(out, images[imgidx], args.lmbda) 109 | out_criterion["loss"].backward() 110 | if args.clip_max_norm > 0: 111 | torch.nn.utils.clip_grad_norm_(BFrameCompressor.parameters(), args.clip_max_norm) # mxh add 112 | optimizer.step() 113 | out_aux_loss = BFrameCompressor.aux_loss() 114 | out_aux_loss.backward() 115 | aux_optimizer.step() 116 | 117 | loss.update(out_criterion["loss"].item()) 118 | bpp_loss.update(out_criterion["bpp_loss"].item()) 119 | aux_loss.update(out_aux_loss.item()) 120 | metric_loss.update(out_criterion[metric_name].item()) 121 | metric_dB.update(out_criterion[metric_dB_name].item()) 122 | iterations += 1 123 | 124 | train_dataloader.set_description('[{}/{}]'.format(i, len(train_dataloader))) 125 | train_dataloader.set_postfix({"Loss":loss.avg, 'Bpp':bpp_loss.avg, args.metric: metric_loss.avg, 'Aux':aux_loss.avg, 126 | metric_dB_name:metric_dB.avg}) 127 | 128 | out = {"loss": loss.avg, metric_name: metric_loss.avg, "bpp_loss": bpp_loss.avg, 129 | "aux_loss":aux_loss.avg, metric_dB_name: metric_dB.avg, "iterations": iterations} 130 | 131 | return out 132 | 133 | 134 | def test_epoch(epoch, test_dataloader, IFrameCompressor, BFrameCompressor, interpolation_net, criterion, args): 135 | IFrameCompressor.eval() 136 | BFrameCompressor.eval() 137 | if args.with_interpolation: 138 | interpolation_net.eval() 139 | device = next(BFrameCompressor.parameters()).device 140 | 141 | loss = AverageMeter('Loss', ':.4e') 142 | bpp_loss = AverageMeter('BppLoss', ':.4e') 143 | aux_loss = AverageMeter('AuxLoss', ':.4e') 144 | metric_dB_name = 'psnr' if args.metric == "mse" else "ms_ssim_db" 145 | metric_name = "mse_loss" if args.metric == "mse" else "ms_ssim_loss" 146 | metric_dB = AverageMeter(metric_dB_name, ':.4e') 147 | metric_loss = AverageMeter(args.metric, ':.4e') 148 | 149 | test_dataloader = tqdm(test_dataloader) 150 | with torch.no_grad(): 151 | for i, images in enumerate(test_dataloader): 152 | rand_num = random.randint(3, len(images)) 153 | images_index = random.sample(range(len(images)), rand_num) 154 | images_index.sort(reverse=False) #升序 155 | images = [images[idx].to(device) for idx in images_index] 156 | num_p = len(images)-1 157 | 158 | for imgidx in range(num_p): 159 | if imgidx == 0: 160 | # I frame compression. 
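The training step above follows the usual CompressAI two-optimizer convention: the main optimizer minimizes the rate-distortion objective, while aux_optimizer only touches the *.quantiles parameters via aux_loss(), which fits the entropy bottleneck's CDF tables needed for actual entropy coding. A condensed sketch of one B-frame step, where net, criterion, x, side_info and lmbda are stand-ins for the objects used in train.py:

```
import torch

# Assumed names: `net` is the BFrameCompressor, `criterion` the DVCLoss,
# `x` the current frame, `side_info` the interpolated SI, `lmbda` the RD weight.
optimizer.zero_grad()
aux_optimizer.zero_grad()

out = net(x, side_info)                  # Wyner-Ziv coding with decoder-side SI
rd = criterion(out, x, lmbda)            # {"loss": lmbda * 255**2 * MSE + bpp, ...}
rd["loss"].backward()
torch.nn.utils.clip_grad_norm_(net.parameters(), 5.0)   # clip_max_norm default
optimizer.step()

net.aux_loss().backward()                # updates only the entropy-model quantiles
aux_optimizer.step()
```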
161 | first_key = IFrameCompressor(images[0]) 162 | last_key = IFrameCompressor(images[-1]) 163 | else: 164 | if args.with_interpolation: 165 | mid_key = interpolation_net.inference(first_key["x_hat"], last_key["x_hat"], timestep=((imgidx)/num_p)) 166 | else: 167 | mid_key = torch.cat((first_key["x_hat"], last_key["x_hat"]), 1) 168 | 169 | out = BFrameCompressor(images[imgidx], mid_key) 170 | 171 | out_criterion = criterion(out, images[imgidx], args.lmbda) 172 | 173 | loss.update(out_criterion["loss"].item()) 174 | bpp_loss.update(out_criterion["bpp_loss"].item()) 175 | aux_loss.update(BFrameCompressor.aux_loss().item()) 176 | metric_loss.update(out_criterion[metric_name].item()) 177 | metric_dB.update(out_criterion[metric_dB_name].item()) 178 | 179 | test_dataloader.set_description('[{}/{}]'.format(i, len(test_dataloader))) 180 | test_dataloader.set_postfix({"Loss":loss.avg, 'Bpp':bpp_loss.avg, args.metric: metric_loss.avg, 'Aux':aux_loss.avg, 181 | metric_dB_name:metric_dB.avg}) 182 | 183 | 184 | out = {"loss": loss.avg, metric_name: metric_loss.avg, "bpp_loss": bpp_loss.avg, 185 | "aux_loss":aux_loss.avg, metric_dB_name: metric_dB.avg} 186 | 187 | return out 188 | 189 | def parse_args(argv): 190 | parser = argparse.ArgumentParser(description="Example training script.") 191 | parser.add_argument( 192 | "-im", 193 | "--IFrameModel", 194 | default="mbt2018", 195 | choices=models_arch.keys(), 196 | help="Model architecture (default: %(default)s)", 197 | ) 198 | parser.add_argument( 199 | "-bm", 200 | "--BFrameModel", 201 | default="DVC-Hyperprior", 202 | choices=models_arch.keys(), 203 | help="Model architecture (default: %(default)s)", 204 | ) 205 | parser.add_argument("-iq", "--IFrame_quality", type=int, default=4, help='Model quality') 206 | parser.add_argument("-bq", "--BFrame_quality", type=int, default=1, help='Model quality') 207 | parser.add_argument( 208 | "-d", "--dataset", type=str, required=True, help="Training dataset" 209 | ) 210 | parser.add_argument( 211 | "-e", 212 | "--epochs", 213 | default=100, 214 | type=int, 215 | help="Number of epochs (default: %(default)s)", 216 | ) 217 | parser.add_argument( 218 | "-lr", 219 | "--learning-rate", 220 | default=1e-4, 221 | type=float, 222 | help="Learning rate (default: %(default)s)", 223 | ) 224 | parser.add_argument( 225 | "--flow_learning_rate", 226 | default=1e-4, 227 | type=float, 228 | help="Learning rate (default: %(default)s)", 229 | ) 230 | parser.add_argument( 231 | "-n", 232 | "--num-workers", 233 | type=int, 234 | default=4, 235 | help="Dataloaders threads (default: %(default)s)", 236 | ) 237 | parser.add_argument( 238 | "--lambda", 239 | dest="lmbda", 240 | type=float, 241 | default=0, 242 | help="Bit-rate distortion parameter (default: %(default)s)", 243 | ) #0.0018; λ2 = 0.0035; λ3 = 0.0067; λ4 = 0.0130, λ5 =0.025 244 | 245 | parser.add_argument( 246 | "--batch-size", type=int, default=16, help="Batch size (default: %(default)s)" 247 | ) 248 | parser.add_argument( 249 | "--test-batch-size", 250 | type=int, 251 | default=64, 252 | help="Test batch size (default: %(default)s)", 253 | ) 254 | parser.add_argument( 255 | "--aux-learning-rate", 256 | default=1e-3, 257 | help="Auxiliary loss learning rate (default: %(default)s)", 258 | ) 259 | parser.add_argument( 260 | "--patch-size", 261 | type=int, 262 | nargs=2, 263 | default=(256, 256), 264 | help="Size of the patches to be cropped (default: %(default)s)", 265 | ) 266 | parser.add_argument("--cuda", action="store_true", help="Use cuda") 267 | parser.add_argument( 268 | 
"--save", action="store_true", default=True, help="Save IFrameCompressor to disk" 269 | ) 270 | parser.add_argument( 271 | "--seed", type=float, default=1, help="Set random seed for reproducibility" 272 | ) 273 | parser.add_argument( 274 | "--clip_max_norm", 275 | default=5.0, 276 | type=float, 277 | help="gradient clipping max norm (default: %(default)s", 278 | ) 279 | parser.add_argument("--i-model-path", type=str, default="/home/xzhangga/.cache/torch/hub/checkpoints/mbt2018-4-456e2af9.pth.tar", help="Path to a checkpoint") 280 | parser.add_argument("--b_model_path", type=str, help="Path to a checkpoint") 281 | parser.add_argument("--flownet_model_path", type=str, default="./flownet_model/RIFE_m_train_log/flownet.pkl", help="Path to a checkpoint") 282 | parser.add_argument("--metric", type=str, default="ms-ssim", help="metric: mse, ms-ssim") 283 | parser.add_argument("--flow_finetune", action="store_true", help='whether flownet is finetuned')#default: False 284 | parser.add_argument("--with_interpolation", action="store_true", help='whether use extrapolation network') 285 | parser.add_argument("--use_pretrained_bmodel", action="store_true", help='use pretrained high rate model to train low rate model') 286 | parser.add_argument("--side_input_channels", type=int, default=3) 287 | parser.add_argument("--num_slices", type=int, default=10) 288 | args = parser.parse_args(argv) 289 | return args 290 | 291 | 292 | def main(argv): 293 | args = parse_args(argv) 294 | # Cache the args as a text string to save them in the output dir later 295 | args_text = yaml.safe_dump(args.__dict__, default_flow_style=False) 296 | 297 | if args.seed is not None: 298 | torch.manual_seed(args.seed) 299 | random.seed(args.seed) 300 | 301 | 302 | train_dataset = Vimeo(args.dataset, is_training=True, crop_size=args.patch_size) 303 | test_dataset = Vimeo(args.dataset, is_training=False, crop_size=args.patch_size) 304 | 305 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" 306 | 307 | train_dataloader = DataLoader( 308 | train_dataset, 309 | batch_size=args.batch_size, 310 | num_workers=args.num_workers, 311 | shuffle=True, 312 | pin_memory=(device == "cuda"), 313 | ) 314 | 315 | test_dataloader = DataLoader( 316 | test_dataset, 317 | batch_size=args.test_batch_size, 318 | num_workers=args.num_workers, 319 | shuffle=False, 320 | pin_memory=(device == "cuda"), 321 | ) 322 | 323 | #key frame compressor 324 | IFrameCompressor = models_arch[args.IFrameModel](*cfgs[args.IFrameModel][args.IFrame_quality]) 325 | IFrameCompressor = IFrameCompressor.to(device) 326 | for p in IFrameCompressor.parameters(): 327 | p.requires_grad = False 328 | 329 | url = model_urls[args.IFrameModel][args.metric][args.IFrame_quality] 330 | checkpoint = load_state_dict_from_url(url, progress=True, map_location=device) 331 | checkpoint = load_pretrained(checkpoint) 332 | IFrameCompressor.load_state_dict(checkpoint) 333 | 334 | #wyner-ziv encoder and decoder 335 | BFrameCompressor = models_arch[args.BFrameModel](*cfgs[args.BFrameModel][args.BFrame_quality], args.side_input_channels) 336 | BFrameCompressor = BFrameCompressor.to(device) 337 | 338 | if args.with_interpolation: 339 | interpolation_net = VideoInterpolationNet(args, arbitrary=True) 340 | interpolation_net.load_model(args.flownet_model_path) 341 | if not args.flow_finetune: 342 | interpolation_net.freeze_model() 343 | interpolation_net.device(device) 344 | model_name = "NDVC_Interpolation" 345 | else: 346 | interpolation_net = None 347 | model_name = 
"NDVC_WO_Interpolation" 348 | 349 | if args.metric == "mse": 350 | criterion = DVCLoss() 351 | else: 352 | criterion = DVC_MS_SSIM_Loss(device, size_average=True, max_val=1) 353 | 354 | optimizer, aux_optimizer = configure_optimizers(BFrameCompressor, interpolation_net, args) 355 | patience = 5 if args.flow_finetune and args.with_interpolation else 10 356 | lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min", patience=patience, factor=0.2) 357 | 358 | last_epoch = 0 359 | last_iterations = 0 360 | best_loss = float("inf") 361 | use_previous_bmode_path = False 362 | 363 | if args.b_model_path: 364 | print("Loading B frame model: ", args.b_model_path) 365 | checkpoint = torch.load(args.b_model_path, map_location=device) 366 | if args.use_pretrained_bmodel: 367 | print("load pretrained model") 368 | BFrameCompressor.load_state_dict(checkpoint["state_dict"]) 369 | else: 370 | print("load pretrained model and optimizer!") 371 | BFrameCompressor.load_state_dict(checkpoint["state_dict"]) 372 | 373 | last_epoch = checkpoint["epoch"] + 1 374 | last_iterations = checkpoint["iterations"] 375 | optimizer.load_state_dict(checkpoint["optimizer"]) 376 | aux_optimizer.load_state_dict(checkpoint["aux_optimizer"]) 377 | lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) 378 | best_b_model_path = os.path.join(os.path.split(args.b_model_path)[0], 'ckpt.best.pth.tar') 379 | best_loss = torch.load(best_b_model_path)["loss"] #checkpoint["loss"] 380 | 381 | use_previous_bmode_path = True 382 | 383 | if args.cuda and torch.cuda.device_count() > 1: 384 | IFrameCompressor = CustomDataParallel(IFrameCompressor) 385 | BFrameCompressor = CustomDataParallel(BFrameCompressor) 386 | 387 | stage = 2 if args.flow_finetune else 1 388 | if use_previous_bmode_path: 389 | log_dir = os.path.split(args.b_model_path)[0] 390 | else: 391 | log_dir = get_output_folder('./checkpoints/{}/{}/{}/stage{}'.format(args.metric, model_name, args.BFrameModel, stage), 'train') 392 | 393 | print(log_dir) 394 | with open(os.path.join(log_dir, 'args.yaml'), 'w') as f: 395 | f.write(args_text) 396 | writer = SummaryWriter(log_dir) 397 | 398 | metric_dB_name = 'psnr' if args.metric == "mse" else "ms_ssim_db" 399 | metric_name = "mse_loss" if args.metric == "mse" else "ms_ssim_loss" 400 | iterations = last_iterations 401 | #val_loss = test_epoch(0, test_dataloader, IFrameCompressor, BFrameCompressor, interpolation_net, criterion, args) 402 | for epoch in range(last_epoch, args.epochs): 403 | print(f"Learning rate: {optimizer.param_groups[0]['lr']}") 404 | train_loss = train_one_epoch(IFrameCompressor, BFrameCompressor, interpolation_net, criterion, train_dataloader, 405 | optimizer, aux_optimizer, epoch, iterations, args.clip_max_norm, args) 406 | val_loss = test_epoch(epoch, test_dataloader, IFrameCompressor, BFrameCompressor, interpolation_net, criterion, args) 407 | 408 | writer.add_scalar('train/loss', train_loss["loss"], epoch) 409 | writer.add_scalar('train/bpp_loss', train_loss["bpp_loss"], epoch) 410 | writer.add_scalar('train/aux_loss', train_loss["aux_loss"], epoch) 411 | 412 | writer.add_scalar('val/loss', val_loss["loss"], epoch) 413 | writer.add_scalar('val/bpp_loss', val_loss["bpp_loss"], epoch) 414 | writer.add_scalar('val/aux_loss', val_loss["aux_loss"], epoch) 415 | 416 | writer.add_scalar('train/'+metric_dB_name, train_loss[metric_dB_name], epoch) 417 | writer.add_scalar('train/'+metric_name, train_loss[metric_name], epoch) 418 | writer.add_scalar('val/'+metric_name, val_loss[metric_name], epoch) 419 | 
writer.add_scalar('val/'+metric_dB_name, val_loss[metric_dB_name], epoch) 420 | 421 | iterations = train_loss["iterations"] 422 | loss = val_loss["loss"] 423 | lr_scheduler.step(loss) 424 | 425 | is_best = loss < best_loss 426 | best_loss = min(loss, best_loss) 427 | 428 | if args.save: 429 | save_checkpoint( 430 | { 431 | "epoch": epoch, 432 | "iterations": iterations, 433 | "state_dict": BFrameCompressor.state_dict(), 434 | "loss": loss, 435 | "bpp": val_loss["bpp_loss"], 436 | metric_dB_name: val_loss[metric_dB_name], 437 | "optimizer": optimizer.state_dict(), 438 | "aux_optimizer": aux_optimizer.state_dict(), 439 | "lr_scheduler": lr_scheduler.state_dict(), 440 | }, 441 | is_best, log_dir 442 | ) 443 | 444 | if args.flow_finetune: 445 | interpolation_net.save_model(log_dir, is_best) 446 | 447 | if __name__ == "__main__": 448 | main(sys.argv[1:]) 449 | --------------------------------------------------------------------------------
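The rate term used throughout the repository (bpp_loss in models/loss.py and compute_bpp in test.py) is the same quantity: the negative base-2 log-likelihood of the latents under the learned entropy model, normalized by the number of pixels. A standalone sketch with dummy likelihood tensors of arbitrary shape:

```
import math
import torch

likelihoods = {"y": torch.full((1, 192, 16, 16), 0.5),   # dummy latent likelihoods
               "z": torch.full((1, 128, 4, 4), 0.9)}
num_pixels = 256 * 256                                    # N * H * W of the frame

bits = sum(torch.log(l).sum() / (-math.log(2)) for l in likelihoods.values())
bpp = bits / num_pixels
print(f"{bits.item():.0f} bits, {bpp.item():.4f} bpp")    # ~49463 bits, ~0.75 bpp
```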