├── README.md ├── config └── mt.yaml ├── data └── datasets.py ├── layer ├── layer_utils.py └── transformer.py ├── loss └── distortion.py ├── net ├── elic.py └── mit.py ├── results └── mt_rd.png ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | This repo is a **PyTorch** implementation of [M2T: Masking Transformers Twice for Faster Decoding](https://openaccess.thecvf.com/content/ICCV2023/html/Mentzer_M2T_Masking_Transformers_Twice_for_Faster_Decoding_ICCV_2023_paper.html). 2 | 3 | ## Install 4 | 5 | The latest code is tested on Ubuntu 18.04 LTS, CUDA 11.7, PyTorch 1.9, and Python 3.7. 6 | Some additional libraries are required to run the code in this repo, including [constriction](https://bamler-lab.github.io/constriction/), [compressai](https://github.com/InterDigitalInc/CompressAI), and [timm](https://timm.fast.ai/). 7 | 8 | ### Train 9 | ```bash 10 | python train.py --config config/mt.yaml # add --wandb if you want to use wandb 11 | ``` 12 | Model checkpoints and logs will be saved in `./history/MT`. 13 | 14 | ### Test 15 | ```bash 16 | python train.py --config config/mt.yaml --test-only --eval-dataset-path path_to_kodak 17 | ``` 18 | ### Performance 19 | The red dot is our reproduction with distortion lambda $\lambda=0.0035$. 20 |

21 | ![RD curve](results/mt_rd.png) 22 |
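For programmatic use outside `train.py`, the end-to-end model is `MaskedImageModelingTransformer` in `net/mit.py`: `forward` applies random masking for training, `inference` only estimates likelihoods, and `real_inference` performs actual range coding with constriction. A minimal sketch (assuming a CUDA device, a trained checkpoint, and that the weights sit under a `'state_dict'` key; the repo's own loading helper is `utils.load_weights`):

```python
import torch
from net.mit import MaskedImageModelingTransformer

net = MaskedImageModelingTransformer().cuda().eval()
ckpt = torch.load('checkpoint_best_loss.pth.tar', map_location='cuda')
net.load_state_dict(ckpt['state_dict'])   # key name is an assumption; adjust to your checkpoint

x = torch.rand(1, 3, 512, 768).cuda()     # image in [0, 1]; H and W divisible by 16
with torch.no_grad():
    out = net.real_inference(x)           # range-codes the latent in 12 qlds steps via constriction
print(out['bpp'], out['x_hat'].shape)
```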

23 | 24 | ### Pretrained models 25 | To be released. 26 | 27 | ### Acknowledgements 28 | We use constriction for actual entropy coding. Thanks for Fabian Mentzer's help for the clarification of the details of the paper. -------------------------------------------------------------------------------- /config/mt.yaml: -------------------------------------------------------------------------------- 1 | gpu-id: '0' 2 | multi-gpu: False 3 | test-only: True 4 | 5 | batch-size: 12 6 | eval-dataset-path: '/media/Dataset/kodak' 7 | num-workers: 12 8 | training-img-size: (384, 384) 9 | 10 | checkpoint: '/media/D/wangsixian/MT/history/MT/MT 2023-10-12 16:30:30/checkpoint_best_loss.pth.tar' 11 | init_lr: 1e-4 12 | min_lr: 1e-4 13 | max_lr: 1e-4 14 | 15 | 16 | epochs: 10 17 | warmup: False 18 | print-every: 100 19 | test-every: 2500 20 | distortion_metric: 'MSE' 21 | lambda_value: 0.0035 22 | #beta_value: 0 23 | -------------------------------------------------------------------------------- /data/datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from glob import glob 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data.dataset import Dataset 9 | from torchvision import transforms 10 | 11 | class OpenImages(Dataset): 12 | files = {"train": "train", "test": "test", "val": "validation"} 13 | 14 | def __init__(self, image_dims, data_dir): 15 | self.imgs = [] 16 | for dir in data_dir: 17 | self.imgs += glob(os.path.join(dir, '*.jpg')) 18 | self.imgs += glob(os.path.join(dir, '*.png')) 19 | _, self.im_height, self.im_width = image_dims 20 | self.crop_size = self.im_height 21 | self.image_dims = (3, self.im_height, self.im_width) 22 | self.transform = self._transforms() 23 | 24 | def _transforms(self, scale=0, H=0, W=0): 25 | """ 26 | Up(down)scale and randomly crop to `crop_size` x `crop_size` 27 | """ 28 | transforms_list = [ 29 | # transforms.ToPILImage(), 30 | # transforms.RandomHorizontalFlip(), 31 | # transforms.Resize((math.ceil(scale * H), math.ceil(scale * W))), 32 | transforms.ToTensor(), 33 | transforms.RandomCrop((self.im_height, self.im_width), pad_if_needed=True), 34 | # transforms.Resize((self.im_height, self.im_width)), 35 | ] 36 | 37 | # if self.normalize is True: 38 | # transforms_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 39 | 40 | return transforms.Compose(transforms_list) 41 | 42 | def __getitem__(self, idx): 43 | img_path = self.imgs[idx] 44 | img = Image.open(img_path) 45 | img = img.convert('RGB') 46 | transformed = self.transform(img) 47 | return transformed 48 | 49 | def __len__(self): 50 | return len(self.imgs) 51 | 52 | 53 | class Datasets(Dataset): 54 | def __init__(self, dataset_path): 55 | self.data_dir = dataset_path 56 | self.imgs = [] 57 | self.imgs += glob(os.path.join(self.data_dir, '*.jpg')) 58 | self.imgs += glob(os.path.join(self.data_dir, '*.png')) 59 | self.imgs.sort() 60 | self.transform = transforms.Compose([ 61 | # transforms.RandomCrop((384, 384), pad_if_needed=True), 62 | transforms.ToTensor()]) 63 | 64 | def __getitem__(self, item): 65 | image_ori = self.imgs[item] 66 | image = Image.open(image_ori).convert('RGB') 67 | img = self.transform(image) 68 | return img 69 | 70 | def __len__(self): 71 | return len(self.imgs) 72 | 73 | 74 | def get_loader(train_dir, test_dir, num_workers, batch_size): 75 | train_dataset = OpenImages((3, 384, 384), train_dir) 76 | test_dataset = Datasets(test_dir) 77 | 78 | def 
worker_init_fn_seed(worker_id): 79 | seed = 10 80 | seed += worker_id 81 | np.random.seed(seed) 82 | 83 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 84 | num_workers=num_workers, 85 | pin_memory=True, 86 | batch_size=batch_size, 87 | worker_init_fn=worker_init_fn_seed, 88 | shuffle=True) 89 | 90 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 91 | batch_size=1, 92 | shuffle=False) 93 | 94 | return train_loader, test_loader 95 | 96 | 97 | def get_dataset(train_dir, test_dir): 98 | train_dataset = OpenImages((3, 384, 384), train_dir) 99 | test_dataset = Datasets(test_dir) 100 | return train_dataset, test_dataset 101 | 102 | 103 | def get_test_loader(test_dir): 104 | test_dataset = Datasets(test_dir) 105 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 106 | batch_size=1, 107 | shuffle=False) 108 | return test_loader 109 | -------------------------------------------------------------------------------- /layer/layer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Optional, Tuple 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch import Tensor 11 | 12 | 13 | def make_conv( 14 | in_channels: int, 15 | out_channels: int, 16 | kernel_size: int, 17 | stride: int = 1, 18 | ) -> nn.Conv2d: 19 | return nn.Conv2d( 20 | in_channels, 21 | out_channels, 22 | kernel_size=kernel_size, 23 | stride=stride, 24 | padding=kernel_size // 2, 25 | ) 26 | 27 | 28 | def make_deconv( 29 | in_channels: int, 30 | out_channels: int, 31 | kernel_size: int, 32 | stride: int = 1, 33 | ) -> nn.ConvTranspose2d: 34 | return nn.ConvTranspose2d( 35 | in_channels, 36 | out_channels, 37 | kernel_size=kernel_size, 38 | stride=stride, 39 | output_padding=stride - 1, 40 | padding=kernel_size // 2, 41 | ) 42 | 43 | 44 | def make_embedding(input_dim: int, hidden_dim: int) -> nn.Module: 45 | """ 46 | Constructs and returns an embedding layer, which is a simple (dense) linear layer 47 | 48 | Args: 49 | input_dim: input dimensions (input to linear layer) 50 | hidden_dim: output size of the linear layer 51 | 52 | Returns: 53 | a linear nn.Module, initialized with random uniform weights and biases 54 | """ 55 | scale = 1 / input_dim ** 0.5 56 | linear = torch.nn.Linear(input_dim, hidden_dim, bias=True) 57 | # initialise weights in the same was as vct 58 | torch.nn.init.uniform_(linear.weight, -scale, scale) 59 | torch.nn.init.uniform_(linear.bias, -scale, scale) 60 | return linear 61 | 62 | 63 | def init_weights_truncated_normal(m) -> None: 64 | """ 65 | Initialise weights with truncated normal. 66 | Weights that fall outside 2 stds are resampled. 67 | See torch.nn.init.trunc_normal_ for details. 
68 | 69 | Args: 70 | m: weights 71 | 72 | Examples: 73 | >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) 74 | >>> net.apply(init_weights_truncated_normal) 75 | """ 76 | std = 0.02 77 | if isinstance(m, nn.Linear): 78 | torch.nn.init.trunc_normal_(m.weight, std=std, a=-2 * std, b=2 * std) 79 | m.bias.data.fill_(0.01) 80 | 81 | 82 | class MLP(nn.Module): 83 | """MLP head for transformer blocks""" 84 | 85 | def __init__( 86 | self, 87 | in_features: int, 88 | mlp_dim: int, 89 | dropout: float, 90 | ) -> None: 91 | """ 92 | MLP head for transformer blocks 93 | Args: 94 | expansion_rate: rate at which the input tensor is expanded 95 | dropout_rate: dropout rate 96 | input_shape: shape of the input tensor, with the last dimension as the 97 | size of channel -- [N1, ... Nn, C] 98 | """ 99 | super().__init__() 100 | 101 | # Initialize linear layers with truncated normal 102 | self.fc1 = nn.Linear(in_features=in_features, out_features=mlp_dim) 103 | init_weights_truncated_normal(self.fc1) 104 | self.fc2 = torch.nn.Linear(in_features=mlp_dim, out_features=in_features) 105 | init_weights_truncated_normal(self.fc2) 106 | 107 | self.act = nn.GELU() 108 | self.dropout = nn.Dropout(p=dropout) 109 | 110 | def forward(self, input: Tensor) -> Tensor: 111 | """Forward pass. 112 | 113 | Args: 114 | features: tensor of shape (batch_size, seq_len, hidden_dim) 115 | 116 | Returns: 117 | tensor of shape (batch_size, seq_len, hidden_dim) 118 | """ 119 | input = self.fc1(input) 120 | input = self.act(input) 121 | input = self.dropout(input) 122 | input = self.fc2(input) 123 | return self.dropout(input) 124 | 125 | 126 | class WindowMultiHeadAttention(nn.Module): 127 | def __init__( 128 | self, 129 | hidden_dim: int, 130 | num_heads: int, 131 | attn_drop: float = 0.0, 132 | proj_drop: float = 0.0, 133 | ) -> None: 134 | """Windowed multi-head self-attention 135 | 136 | Args: 137 | hidden_dim: size of the hidden units 138 | num_heads: number of attention heads 139 | attn_drop: dropout rate of the attention layer. Defaults to 0.0. 140 | proj_drop: dropout rate of the projection layer. Defaults to 0.0. 141 | 142 | Raises: 143 | ValueError: if `hidden_dim` is not a multiple of `num_heads` 144 | """ 145 | super().__init__() 146 | self.hidden_dim = hidden_dim 147 | self.num_heads = num_heads 148 | if hidden_dim % num_heads != 0: 149 | raise ValueError( 150 | f"Size of hidden units ({hidden_dim}) not divisible by number " 151 | f"of heads ({num_heads})." 
152 | ) 153 | # constnat to scale output before softmax-ing 154 | self.attn_scale = (hidden_dim // num_heads) ** (-0.5) 155 | # All linear layers are initialized with truncated normal 156 | self.q_linear = torch.nn.Linear( 157 | in_features=hidden_dim, out_features=hidden_dim, bias=True 158 | ) 159 | init_weights_truncated_normal(self.q_linear) 160 | self.k_linear = torch.nn.Linear( 161 | in_features=hidden_dim, out_features=hidden_dim, bias=True 162 | ) 163 | init_weights_truncated_normal(self.k_linear) 164 | self.v_linear = torch.nn.Linear( 165 | in_features=hidden_dim, out_features=hidden_dim, bias=True 166 | ) 167 | init_weights_truncated_normal(self.v_linear) 168 | self.attn_dropout = nn.Dropout(p=attn_drop) 169 | self.proj = torch.nn.Linear(in_features=hidden_dim, out_features=hidden_dim) 170 | init_weights_truncated_normal(self.proj) 171 | self.proj_dropout = torch.nn.Dropout(p=proj_drop) 172 | self.softmax = torch.nn.Softmax(dim=-1) # -1 is the default in tf 173 | 174 | def forward( 175 | self, 176 | query: Tensor, 177 | key: Tensor, 178 | value: Tensor, 179 | mask: Optional[torch.Tensor] = None, 180 | ) -> Tuple[Tensor, Tensor]: 181 | """Compute the windowed multi-head self-attention. 182 | 183 | Note: seq_length of keys and values (`seq_len_kv`) must be an integer 184 | multiple of the seq_length of the query (`seq_len_q`). 185 | 186 | Args: 187 | query: tensor of shape (b', seq_len_q, hidden_dim), 188 | representing the query 189 | key: tensor of shape (b', seq_len_kv, hidden_dim), 190 | representing the key 191 | value: tensor of shape (b', seq_len_kv, hidden_dim), 192 | representing the value 193 | mask: optional tensor of shape (b', seq_len_q, seq_len_q), 194 | representing the mask to apply to the attention; 1s will be masked 195 | -> b' is an augmented batch size that includes the total number of patches 196 | -> by default, hidden_dim is 768 197 | 198 | Returns: 199 | tensor of feautures of shape (b', seq_len_q, hidden_dim), as well as the 200 | attention matrix used, shape (b', num_heads, seq_len_q, seq_len_kv). 
201 | """ 202 | *b, seq_len_q, c = query.shape 203 | assert c == self.hidden_dim, f"Shape mismatch, {c} != {self.hidden_dim}" 204 | seq_len_kv = value.shape[-2] 205 | assert seq_len_kv % seq_len_q == 0, ( 206 | f"seq_length of keys and values = {seq_len_kv} must be an integer multiple " 207 | f"of the seq_length of the query = {seq_len_q}" 208 | ) 209 | blowup = seq_len_kv // seq_len_q 210 | 211 | query = self.q_linear(query) # [b', seq_len_q, hidden_dim] 212 | key = self.k_linear(key) # [b', seq_len_kv, hidden_dim] 213 | value = self.v_linear(value) # [b', seq_len_kv, hidden_dim] 214 | 215 | # reshape by splitting channels into num_heads, then permute: 216 | query = ( 217 | query.reshape(*b, seq_len_q, self.num_heads, c // self.num_heads) 218 | .permute(0, 2, 1, 3) 219 | .contiguous() 220 | ) # [b', num_heads, seq_len_q, c // num_heads] 221 | key = ( 222 | key.reshape(*b, seq_len_kv, self.num_heads, c // self.num_heads) 223 | .permute(0, 2, 1, 3) 224 | .contiguous() 225 | ) # [b', num_heads, seq_len_kv, c // num_heads] 226 | value = ( 227 | value.reshape(*b, seq_len_kv, self.num_heads, c // self.num_heads) 228 | .permute(0, 2, 1, 3) 229 | .contiguous() 230 | ) # [b', num_heads, seq_len_kv, c // num_heads] 231 | 232 | # b', num_heads, seq_len_q, seq_len_kv 233 | attn = ( 234 | torch.matmul( 235 | query, # [b', num_heads, seq_len_q, c // num_heads] 236 | key.transpose(2, 3), # [b', num_heads, c // num_heads, seq_len_kv] 237 | ) # [b', num_heads,seq_len_q, seq_len_kv] 238 | * self.attn_scale 239 | ) 240 | 241 | if mask is not None: 242 | if mask.shape[-2:] != (seq_len_q, seq_len_q): 243 | # Note that we only mask for self-attention in the decoder, 244 | # where the attention matrix has shape (..., seq_len_q, seq_len_q). 245 | raise ValueError(f"Invalid mask shape: {mask.shape}.") 246 | 247 | # Here, we add the mask to the attention with a large negative multiplier, 248 | # as this goes into a softmax it will become 0 249 | tile_pattern = [1] * mask.dim() 250 | tile_pattern[-1] = blowup 251 | attn = attn + torch.tile(mask, tile_pattern) * -1e6 252 | else: 253 | tile_pattern = None 254 | 255 | attn = self.softmax(attn) # [b', num_heads, seq_length_q, seq_length_kv] 256 | 257 | if mask is not None and tile_pattern is not None: 258 | # We use the mask again, to be ensure that no masked dimension 259 | # affects the output. 260 | keep = 1 - mask 261 | attn = attn * torch.tile(keep, tile_pattern) 262 | 263 | attn = self.attn_dropout(attn) # [b', num_heads, seq_length_q, seq_length_kv] 264 | 265 | features = torch.matmul( 266 | attn, value 267 | ) # [b', num_heads, seq_len_q, d_model//num_heads], c=d_model 268 | assert features.shape == (*b, self.num_heads, seq_len_q, c // self.num_heads) 269 | 270 | features = ( 271 | features.permute(0, 2, 1, 3) # switch num_heads, seq_len_q dimensions. 272 | .contiguous() 273 | .reshape(*b, -1, self.hidden_dim) # flatten num_heads, seq_len_q dimensions 274 | ) 275 | features = self.proj(features) 276 | features = self.proj_dropout(features) 277 | assert features.shape == (*b, seq_len_q, c) 278 | return features, attn 279 | 280 | 281 | class StochasticDepth(nn.Module): 282 | """Creates a stochastic depth layer.""" 283 | 284 | def __init__(self, stochastic_depth_drop_rate: float) -> None: 285 | """Initializes a stochastic depth layer. 286 | 287 | Args: 288 | stochastic_depth_drop_rate: A `float` of drop rate. 289 | 290 | Returns: 291 | a tensor of the same shape as input. 
292 | """ 293 | super().__init__() 294 | self._drop_rate = stochastic_depth_drop_rate 295 | 296 | def forward(self, input: Tensor) -> Tensor: 297 | if not self.training or self._drop_rate == 0.0: 298 | return input 299 | 300 | keep_prob = 1.0 - self._drop_rate 301 | batch_size = input.shape[0] 302 | random_tensor = keep_prob 303 | random_tensor = random_tensor + torch.rand( 304 | [batch_size] + [1] * (input.dim() - 1), dtype=input.dtype 305 | ) 306 | binary_tensor = torch.floor(random_tensor) 307 | return input / keep_prob * binary_tensor 308 | 309 | 310 | def quantize_ste(x): 311 | """Differentiable quantization via the Straight-Through-Estimator.""" 312 | # STE (straight-through estimator) trick: x_hard - x_soft.detach() + x_soft 313 | return (torch.round(x) - x).detach() + x 314 | -------------------------------------------------------------------------------- /layer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 5 | 6 | 7 | class Mlp(nn.Module): 8 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 9 | super().__init__() 10 | out_features = out_features or in_features 11 | hidden_features = hidden_features or in_features 12 | self.fc1 = nn.Linear(in_features, hidden_features) 13 | self.act = act_layer() 14 | self.fc2 = nn.Linear(hidden_features, out_features) 15 | self.drop = nn.Dropout(drop) 16 | 17 | def forward(self, x): 18 | x = self.fc1(x) 19 | x = self.act(x) 20 | x = self.drop(x) 21 | x = self.fc2(x) 22 | x = self.drop(x) 23 | return x 24 | 25 | 26 | def window_partition(x, window_size): 27 | """ 28 | Args: 29 | x: (B, H, W, C) 30 | window_size (int): window size 31 | Returns: 32 | windows: (num_windows*B, window_size, window_size, C) 33 | """ 34 | B, H, W, C = x.shape 35 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 36 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | return windows 38 | 39 | 40 | def window_reverse(windows, window_size, H, W): 41 | """ 42 | Args: 43 | windows: (num_windows*B, window_size, window_size, C) 44 | window_size (int): Window size 45 | H (int): Height of image 46 | W (int): Width of image 47 | Returns: 48 | x: (B, H, W, C) 49 | """ 50 | # print("windows.shape[0]", windows.shape[0], H * W) 51 | B = int(windows.shape[0] / (H * W / window_size / window_size)) 52 | # print(B) 53 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 54 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 55 | return x 56 | 57 | 58 | class WindowAttention(nn.Module): 59 | r""" Window based multi-head self attention (W-MSA) module with relative position bias. 60 | It supports both of shifted and non-shifted window. 61 | Args: 62 | dim (int): Number of input channels. 63 | window_size (tuple[int]): The height and width of the window. 64 | num_heads (int): Number of attention heads. 65 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 66 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 67 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 68 | proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 69 | """ 70 | 71 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 72 | 73 | super().__init__() 74 | self.dim = dim 75 | self.window_size = window_size # Wh, Ww 76 | self.num_heads = num_heads 77 | head_dim = dim // num_heads 78 | self.scale = qk_scale or head_dim ** -0.5 79 | 80 | # define a parameter table of relative position bias 81 | self.relative_position_bias_table = nn.Parameter( 82 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 83 | 84 | # get pair-wise relative position index for each token inside the window 85 | coords_h = torch.arange(self.window_size[0]) 86 | coords_w = torch.arange(self.window_size[1]) 87 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 88 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 89 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 90 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 91 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 92 | relative_coords[:, :, 1] += self.window_size[1] - 1 93 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 94 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 95 | self.register_buffer("relative_position_index", relative_position_index) 96 | 97 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 98 | self.attn_drop = nn.Dropout(attn_drop) 99 | self.proj = nn.Linear(dim, dim) 100 | self.proj_drop = nn.Dropout(proj_drop) 101 | 102 | trunc_normal_(self.relative_position_bias_table, std=.02) 103 | self.softmax = nn.Softmax(dim=-1) 104 | 105 | def forward(self, x, mask=None): 106 | """ 107 | Args: 108 | x: input features with shape of (num_windows*B, N, C) 109 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 110 | """ 111 | B_, N, C = x.shape 112 | 113 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 114 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 115 | 116 | q = q * self.scale 117 | attn = (q @ k.transpose(-2, -1)) # (N+1)x(N+1) 118 | 119 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 120 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 121 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 122 | 123 | attn = attn + relative_position_bias.unsqueeze(0) 124 | 125 | if mask is not None: 126 | mask = mask.to(attn.get_device()) 127 | nW = mask.shape[0] 128 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 129 | attn = attn.view(-1, self.num_heads, N, N) 130 | attn = self.softmax(attn) 131 | else: 132 | attn = self.softmax(attn) 133 | 134 | attn = self.attn_drop(attn) 135 | 136 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 137 | x = self.proj(x) 138 | x = self.proj_drop(x) 139 | return x 140 | 141 | def extra_repr(self) -> str: 142 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' 143 | 144 | def flops(self, N): 145 | # calculate flops for 1 window with token length of N 146 | flops = 0 147 | # qkv = self.qkv(x) 148 | flops += N * self.dim * 3 * self.dim 149 | # attn = (q @ k.transpose(-2, -1)) 150 | flops += self.num_heads * N * (self.dim // self.num_heads) * N 151 | # x = (attn @ v) 
152 | flops += self.num_heads * N * N * (self.dim // self.num_heads) 153 | # x = self.proj(x) 154 | flops += N * self.dim * self.dim 155 | return flops 156 | 157 | 158 | class TransformerBlock(nn.Module): 159 | r""" Transformer Block. 160 | Args: 161 | dim (int): Number of input channels. 162 | num_heads (int): Number of attention heads. 163 | window_size (int): Window size. 164 | shift_size (int): Shift size for SW-MSA. 165 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 166 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 167 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 168 | drop (float, optional): Dropout rate. Default: 0.0 169 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 170 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 171 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 172 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 173 | """ 174 | 175 | def __init__(self, dim, num_heads, window_size=7, 176 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 177 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 178 | super().__init__() 179 | self.dim = dim 180 | self.num_heads = num_heads 181 | self.window_size = window_size 182 | self.mlp_ratio = mlp_ratio 183 | self.norm1 = norm_layer(dim) 184 | self.attn = WindowAttention( 185 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 186 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 187 | 188 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 189 | self.norm2 = norm_layer(dim) 190 | mlp_hidden_dim = int(dim * mlp_ratio) 191 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 192 | 193 | def forward(self, x, input_resolution): 194 | H0, W0 = input_resolution 195 | B, L, C = x.shape 196 | shortcut = x 197 | x = self.norm1(x) 198 | if H0 % self.window_size != 0 or W0 % self.window_size != 0: 199 | x = x.view(B, H0, W0, C).permute(0, 3, 1, 2) 200 | # reflect pad feature maps to multiples of window size 201 | pad_l = pad_t = 0 202 | pad_b = (self.window_size - H0 % self.window_size) % self.window_size 203 | pad_r = (self.window_size - W0 % self.window_size) % self.window_size 204 | x = F.pad(x, (pad_l, pad_r, pad_t, pad_b)) 205 | # print("padding", pad_l, pad_r, pad_t, pad_b) 206 | x = x.permute(0, 2, 3, 1) 207 | H = H0 + pad_b 208 | W = W0 + pad_r 209 | else: 210 | H = H0 211 | W = W0 212 | x = x.view(B, H0, W0, C) 213 | 214 | # partition windows 215 | x_windows = window_partition(x, self.window_size) # nW*B, window_size, window_size, C 216 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 217 | B_, N, C = x_windows.shape 218 | 219 | attn_windows = self.attn(x_windows, mask=None) # nW*B, window_size*window_size, C 220 | 221 | # merge windows 222 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 223 | x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 224 | 225 | # remove paddings 226 | if L != H * W: 227 | x = x[:, :H0, :W0, :].contiguous() 228 | 229 | x = x.view(B, H0 * W0, C) 230 | # FFN 231 | x = shortcut + self.drop_path(x) 232 | x = x + self.drop_path(self.mlp(self.norm2(x))) 233 | 234 | return x 235 | 236 | def extra_repr(self) -> str: 237 | return f"dim={self.dim}, 
input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 238 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 239 | 240 | 241 | class LearnedPosition(nn.Module): 242 | """ 243 | Learned poisitional encoding 244 | """ 245 | 246 | def __init__(self, seq_length: int, hidden_dim: int) -> None: 247 | """ 248 | Learned positional encoding 249 | 250 | Args: 251 | seq_length: sequence length. 252 | hidden_dim: hidden (model) dimension 253 | """ 254 | super().__init__() 255 | self._emb = torch.nn.parameter.Parameter( 256 | torch.empty(1, seq_length, hidden_dim).normal_(std=0.02) 257 | ) # [1, seq_length, hidden_dim] 258 | self._seq_len = seq_length 259 | self._hidden_dim = hidden_dim 260 | 261 | def forward(self, x): 262 | """Adds positional encodings to an input 263 | 264 | Args: 265 | x: tensor to which the positional encodings will be added, 266 | expected shape is [B, seq_len, hidden_dim] 267 | 268 | Returns: 269 | the input tensor with the positional encodings added. 270 | """ 271 | assert x.dim() == 3 and x.shape[1:] == torch.Size( 272 | [self._seq_len, self._hidden_dim] 273 | ), f"Expected [B, seq_length, hidden_dim] got {x.shape}" 274 | return x + self._emb 275 | 276 | 277 | class SwinTransformerBlock(nn.Module): 278 | r""" Swin Transformer Block. 279 | Args: 280 | dim (int): Number of input channels. 281 | num_heads (int): Number of attention heads. 282 | window_size (int): Window size. 283 | shift_size (int): Shift size for SW-MSA. 284 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 285 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 286 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 287 | drop (float, optional): Dropout rate. Default: 0.0 288 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 289 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 290 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 291 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 292 | """ 293 | 294 | def __init__(self, dim, num_heads, window_size=7, shift_size=0, 295 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 296 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 297 | super().__init__() 298 | self.dim = dim 299 | self.input_resolution = (24, 24) 300 | self.num_heads = num_heads 301 | self.window_size = window_size 302 | self.shift_size = shift_size 303 | self.mlp_ratio = mlp_ratio 304 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 305 | 306 | self.norm1 = norm_layer(dim) 307 | self.attn = WindowAttention( 308 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 309 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 310 | 311 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 312 | self.norm2 = norm_layer(dim) 313 | mlp_hidden_dim = int(dim * mlp_ratio) 314 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 315 | 316 | if self.shift_size > 0: 317 | # calculate attention mask for SW-MSA 318 | H, W = self.input_resolution 319 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 320 | h_slices = (slice(0, -self.window_size), 321 | slice(-self.window_size, -self.shift_size), 322 | slice(-self.shift_size, None)) 323 | w_slices = (slice(0, -self.window_size), 324 | slice(-self.window_size, -self.shift_size), 325 | slice(-self.shift_size, None)) 326 | cnt = 0 327 | for h in h_slices: 328 | for w in w_slices: 329 | img_mask[:, h, w, :] = cnt 330 | cnt += 1 331 | 332 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 333 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 334 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 335 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 336 | else: 337 | attn_mask = None 338 | 339 | self.register_buffer("attn_mask", attn_mask) 340 | 341 | def forward(self, x, input_resolution, slice_size=None): 342 | H0, W0 = input_resolution 343 | device = x.get_device() 344 | if input_resolution != self.input_resolution: 345 | self.update_mask(input_resolution, device) 346 | if slice_size is not None: 347 | inter_slice_mask = self.get_inter_slice_mask(input_resolution, slice_size, device) 348 | B, L, C = x.shape 349 | # assert L == H * W, "input feature has wrong size, input size{}x{}x{}, H:{}, W:{}".format(B, L, C, H, W) 350 | shortcut = x 351 | x = self.norm1(x) 352 | if H0 % self.window_size != 0 or W0 % self.window_size != 0: 353 | x = x.view(B, H0, W0, C).permute(0, 3, 1, 2) 354 | # reflect pad feature maps to multiples of window size 355 | pad_l = pad_t = 0 356 | pad_b = (self.window_size - H0 % self.window_size) % self.window_size 357 | pad_r = (self.window_size - W0 % self.window_size) % self.window_size 358 | x = F.pad(x, (pad_l, pad_r, pad_t, pad_b)) 359 | x = x.permute(0, 2, 3, 1) 360 | H = H0 + pad_b 361 | W = W0 + pad_r 362 | else: 363 | H = H0 364 | W = W0 365 | x = x.view(B, H0, W0, C) 366 | 367 | # cyclic shift 368 | if self.shift_size > 0: 369 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 370 | if slice_size is None: 371 | mask = self.attn_mask 372 | else: 373 | mask = (self.attn_mask + inter_slice_mask).clip(-100.0, 0.0) 374 | else: 375 | shifted_x = x 376 | assert (slice_size is None) or (slice_size[0] % self.window_size == 0) 377 | mask = None 378 | 379 | # partition windows 380 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 381 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 382 | B_, N, C = x_windows.shape 383 | 384 | attn_windows = self.attn(x_windows, 385 | mask=mask) # nW*B, window_size*window_size, C 386 | 387 | # merge windows 388 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 389 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 390 | 391 | # reverse cyclic shift 392 | if self.shift_size > 0: 393 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 394 | else: 395 | x = shifted_x 396 | 397 | # remove paddings 398 | if L != H * W: 399 | x = x[:, :H0, :W0, :].contiguous() 400 | 
401 | # x = x.view(B, H0 * W0, C) 402 | x = x.view(B, H * W, C) 403 | 404 | # FFN 405 | x = shortcut + self.drop_path(x) 406 | x = x + self.drop_path(self.mlp(self.norm2(x))) 407 | 408 | return x 409 | 410 | def extra_repr(self) -> str: 411 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ 412 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" 413 | 414 | def flops(self): 415 | flops = 0 416 | H, W = self.input_resolution 417 | # norm1 418 | flops += self.dim * H * W 419 | # W-MSA/SW-MSA 420 | nW = H * W / self.window_size / self.window_size 421 | flops += nW * self.attn.flops(self.window_size * self.window_size) 422 | # mlp 423 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio 424 | # norm2 425 | flops += self.dim * H * W 426 | return flops 427 | 428 | def update_mask(self, input_resolution, device): 429 | self.input_resolution = input_resolution 430 | if self.shift_size > 0: 431 | # calculate attention mask for SW-MSA 432 | H, W = input_resolution 433 | H = H + (self.window_size - H % self.window_size) % self.window_size 434 | W = W + (self.window_size - W % self.window_size) % self.window_size 435 | 436 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 437 | h_slices = (slice(0, -self.window_size), 438 | slice(-self.window_size, -self.shift_size), 439 | slice(-self.shift_size, None)) 440 | w_slices = (slice(0, -self.window_size), 441 | slice(-self.window_size, -self.shift_size), 442 | slice(-self.shift_size, None)) 443 | cnt = 0 444 | for h in h_slices: 445 | for w in w_slices: 446 | img_mask[:, h, w, :] = cnt 447 | cnt += 1 448 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 449 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 450 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 451 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 452 | self.attn_mask = attn_mask.to(device) 453 | # print(attn_mask[0]) 454 | # print(attn_mask[1]) 455 | # print(attn_mask.shape) 456 | else: 457 | pass 458 | 459 | def get_inter_slice_mask(self, input_resolution, slice_size, device): 460 | H, W = input_resolution 461 | Hs, Ws = slice_size 462 | nH = H // Hs 463 | nW = W // Ws 464 | H = H + (self.window_size - H % self.window_size) % self.window_size 465 | W = W + (self.window_size - W % self.window_size) % self.window_size 466 | img_mask = torch.arange(nH * nW).reshape(nH, nW).repeat_interleave(Hs, dim=0).repeat_interleave(Ws, dim=1) 467 | img_mask = img_mask.reshape(1, H, W, 1) 468 | if self.shift_size > 0: 469 | img_mask = torch.roll(img_mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 470 | img_mask = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 471 | img_mask = img_mask.view(-1, self.window_size * self.window_size) 472 | attn_mask = img_mask.unsqueeze(1) - img_mask.unsqueeze(2) 473 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 474 | inter_slice_attn_mask = attn_mask.to(device) 475 | return inter_slice_attn_mask 476 | 477 | 478 | if __name__ == '__main__': 479 | layer = SwinTransformerBlock(dim=96, num_heads=3, window_size=4, shift_size=2).cuda() 480 | x = torch.randn(1, 64, 96).cuda() 481 | layer(x, (8, 8)) 482 | layer.get_windows_mask((8, 8), (2, 2)) 483 | -------------------------------------------------------------------------------- 
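Note that the `__main__` smoke test above calls `layer.get_windows_mask((8, 8), (2, 2))`, but `SwinTransformerBlock` defines no such method; the slice-mask helper it provides is `get_inter_slice_mask(input_resolution, slice_size, device)`. A corrected check might look like the following sketch (it assumes a CUDA device, since the block places its masks with `x.get_device()`):

```python
import torch
from layer.transformer import SwinTransformerBlock

layer = SwinTransformerBlock(dim=96, num_heads=3, window_size=4, shift_size=2).cuda()
x = torch.randn(1, 64, 96).cuda()                  # B=1, 8x8=64 tokens, dim=96
out = layer(x, (8, 8))                             # shifted-window attention pass
mask = layer.get_inter_slice_mask((8, 8), (2, 2), x.get_device())
print(out.shape, mask.shape)                       # [1, 64, 96] and one 16x16 mask per 4x4 window
```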
/loss/distortion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from pytorch_msssim import ms_ssim 4 | 5 | 6 | class Distortion(torch.nn.Module): 7 | def __init__(self, distortion_type): 8 | super(Distortion, self).__init__() 9 | if distortion_type == 'MSE': 10 | self.metric = nn.MSELoss() 11 | elif distortion_type == 'MS-SSIM': 12 | self.metric = ms_ssim 13 | else: 14 | print("Unknown distortion type!") 15 | raise ValueError 16 | 17 | def forward(self, X, Y): 18 | if self.metric == ms_ssim: 19 | return 1 - self.metric(X, Y, data_range=1) 20 | else: 21 | return self.metric(X, Y) -------------------------------------------------------------------------------- /net/elic.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | ## This module contains modules implementing standard synthesis and analysis transforms 7 | 8 | from typing import List, Optional 9 | 10 | import torch 11 | import torch.nn as nn 12 | from compressai.layers import GDN, AttentionBlock 13 | from torch import Tensor 14 | 15 | from layer.layer_utils import make_conv, make_deconv 16 | 17 | 18 | class ConvGDNAnalysis(nn.Module): 19 | def __init__( 20 | self, network_channels: int = 128, compression_channels: int = 192 21 | ) -> None: 22 | """ 23 | Analysis transfrom from scale hyperprior (https://arxiv.org/abs/1802.01436) 24 | 25 | Encodes each image in a video independently. 26 | """ 27 | super().__init__() 28 | self._compression_channels = compression_channels 29 | 30 | self.transforms = nn.Sequential( 31 | make_conv(3, network_channels, kernel_size=5, stride=2), 32 | GDN(network_channels), 33 | make_conv(network_channels, network_channels, kernel_size=5, stride=2), 34 | GDN(network_channels), 35 | make_conv(network_channels, network_channels, kernel_size=5, stride=2), 36 | GDN(network_channels), 37 | make_conv(network_channels, compression_channels, kernel_size=5, stride=2), 38 | ) 39 | self._num_down = 4 40 | 41 | @property 42 | def compression_channels(self) -> int: 43 | return self._compression_channels 44 | 45 | @property 46 | def num_downsampling_layers(self) -> int: 47 | return self._num_down 48 | 49 | def forward(self, video_frames: Tensor) -> Tensor: 50 | """ 51 | Args: 52 | video_frames: frames of a batch of clips. Expected shape [B, T, C, H, W], 53 | which is reshaped to [BxT, C, H, W], hyperprior model encoder is applied 54 | and output is reshaped back to [B, T, , h, w]. 55 | Returns: 56 | embeddings: embeddings of shape [B, T, , h, w], obtained 57 | by running ScaleHyperprior.image_analysis(). 
58 | """ 59 | assert ( 60 | video_frames.dim() == 5 61 | ), f"Expected [B, T, C, H, W] got {video_frames.shape}" 62 | embeddings = self.transforms(video_frames.reshape(-1, *video_frames.shape[2:])) 63 | return embeddings.reshape(*video_frames.shape[:2], *embeddings.shape[1:]) 64 | 65 | 66 | class ConvGDNSynthesis(nn.Module): 67 | def __init__( 68 | self, network_channels: int = 128, compression_channels: int = 192 69 | ) -> None: 70 | """ 71 | Synthesis transfrom from scale hyperprior (https://arxiv.org/abs/1802.01436) 72 | 73 | Decodes each image in a video independently 74 | """ 75 | 76 | super().__init__() 77 | self._compression_channels = 192 78 | 79 | self.transforms = nn.Sequential( 80 | make_deconv( 81 | compression_channels, network_channels, kernel_size=5, stride=2 82 | ), 83 | GDN(network_channels, inverse=True), 84 | make_deconv(network_channels, network_channels, kernel_size=5, stride=2), 85 | GDN(network_channels, inverse=True), 86 | make_deconv(network_channels, network_channels, kernel_size=5, stride=2), 87 | GDN(network_channels, inverse=True), 88 | make_deconv(network_channels, 3, kernel_size=5, stride=2), 89 | ) 90 | 91 | @property 92 | def compression_channels(self) -> int: 93 | return self._compression_channels 94 | 95 | def forward(self, x: Tensor, frames_shape: torch.Size) -> Tensor: 96 | """ 97 | Args: 98 | x: the (reconstructed) latent embdeddings to be decoded to images, 99 | expected shape [B, T, C, H, W] 100 | frames_shape: shape of the video clip to be reconstructed. 101 | Returns: 102 | reconstruction: reconstruction of the original video clip with shape 103 | [B, T, C, H, W] = frames_shape. 104 | """ 105 | assert x.dim() == 5, f"Expected [B, T, C, H, W] got {x.shape}" 106 | # Treat T as part of the Batch dimension, storing values to reshape back 107 | B, T, *_ = x.shape 108 | x = x.reshape(-1, *x.shape[2:]) 109 | assert len(frames_shape) == 5 110 | 111 | x = self.transforms(x) # final reconstruction 112 | x = x.reshape(B, T, *x.shape[1:]) 113 | return x[..., : frames_shape[-2], : frames_shape[-1]] 114 | 115 | 116 | class ResidualUnit(nn.Module): 117 | """Simple residual unit""" 118 | 119 | def __init__(self, N: int) -> None: 120 | super().__init__() 121 | self.conv = nn.Sequential( 122 | make_conv(N, N // 2, kernel_size=1), 123 | nn.ReLU(inplace=True), 124 | make_conv(N // 2, N // 2, kernel_size=3), 125 | nn.ReLU(inplace=True), 126 | make_conv(N // 2, N, kernel_size=1), 127 | ) 128 | self.activation = nn.ReLU(inplace=True) 129 | 130 | def forward(self, x: Tensor) -> Tensor: 131 | identity = x 132 | out = self.conv(x) 133 | out += identity 134 | out = self.activation(out) 135 | return out 136 | 137 | 138 | def sinusoidal_embedding(values: torch.Tensor, dim=256, max_period=64): 139 | assert values.dim() == 1 and (dim % 2) == 0 140 | exponents = torch.linspace(0, 1, steps=(dim // 2)) 141 | freqs = torch.pow(max_period, -1.0 * exponents).to(device=values.device) 142 | args = values.view(-1, 1) * freqs.view(1, dim // 2) 143 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 144 | return embedding 145 | 146 | 147 | class AdaLN(nn.Module): 148 | default_embedding_dim = 256 149 | 150 | def __init__(self, dim, embed_dim=None): 151 | super().__init__() 152 | embed_dim = embed_dim or self.default_embedding_dim 153 | self.embedding_layer = nn.Sequential( 154 | nn.GELU(), 155 | nn.Linear(embed_dim, 2 * dim), 156 | nn.Unflatten(1, unflattened_size=(1, 1, 2 * dim)) 157 | ) 158 | 159 | def forward(self, x, emb): 160 | # layer norm 161 | x = x.permute(0, 2, 3, 
1).contiguous() 162 | # x = self.norm(x) 163 | # AdaLN 164 | embedding = self.embedding_layer(emb) 165 | shift, scale = torch.chunk(embedding, chunks=2, dim=-1) 166 | x = x * (1 + scale) + shift 167 | x = x.permute(0, 3, 1, 2).contiguous() 168 | return x 169 | 170 | 171 | class ConditionalResidualUnit(nn.Module): 172 | def __init__(self, N: int) -> None: 173 | super().__init__() 174 | self.pre_conv = make_conv(N, N // 2, kernel_size=1) 175 | self.adaLN = AdaLN(N // 2) 176 | self.post_conv = nn.Sequential( 177 | nn.ReLU(inplace=True), 178 | make_conv(N // 2, N // 2, kernel_size=3), 179 | # nn.ReLU(inplace=True), 180 | make_conv(N // 2, N, kernel_size=1), 181 | ) 182 | self.activation = nn.ReLU(inplace=True) 183 | 184 | def forward(self, x: Tensor, emb: Tensor) -> Tensor: 185 | identity = x 186 | out = self.pre_conv(x) 187 | out = self.adaLN(out, emb) 188 | out = self.post_conv(out) 189 | out += identity 190 | out = self.activation(out) 191 | return out 192 | 193 | 194 | class ELICAnalysis(nn.Module): 195 | def __init__( 196 | self, 197 | num_residual_blocks=3, 198 | channels: List[int] = [256, 256, 256, 192], 199 | compression_channels: Optional[int] = None, 200 | max_frames: Optional[int] = None, 201 | ) -> None: 202 | """Analysis transform from ELIC (https://arxiv.org/abs/2203.10886), which 203 | can be configured to match the one from "Devil's in the Details" 204 | (https://arxiv.org/abs/2203.08450). 205 | 206 | Args: 207 | num_residual_blocks: defaults to 3. 208 | channels: defaults to [128, 160, 192, 192]. 209 | compression_channels: optional, defaults to None. If provided, it must equal 210 | the last element of `channels`. 211 | max_frames: optional, defaults to None. If provided, the input is chunked 212 | into max_frames elements, otherwise the entire batch is processed at 213 | once. This is useful when large sequences are to be processed and can 214 | be used to manage memory a bit better. 215 | """ 216 | super().__init__() 217 | if len(channels) != 4: 218 | raise ValueError(f"ELIC uses 4 conv layers (not {len(channels)}).") 219 | if compression_channels is not None and compression_channels != channels[-1]: 220 | raise ValueError( 221 | "output_channels specified but does not match channels: " 222 | f"{compression_channels} vs. 
{channels}" 223 | ) 224 | self._compression_channels = ( 225 | compression_channels if compression_channels is not None else channels[-1] 226 | ) 227 | self._max_frames = max_frames 228 | 229 | def res_units(N): 230 | return [ResidualUnit(N) for _ in range(num_residual_blocks)] 231 | 232 | channels = [3] + channels 233 | 234 | self.transforms = nn.Sequential( 235 | make_conv(channels[0], channels[1], kernel_size=5, stride=2), 236 | *res_units(channels[1]), 237 | make_conv(channels[1], channels[2], kernel_size=5, stride=2), 238 | *res_units(channels[2]), 239 | AttentionBlock(channels[2]), 240 | make_conv(channels[2], channels[3], kernel_size=5, stride=2), 241 | *res_units(channels[3]), 242 | make_conv(channels[3], channels[4], kernel_size=5, stride=2), 243 | AttentionBlock(channels[4]), 244 | ) 245 | 246 | @property 247 | def compression_channels(self) -> int: 248 | return self._compression_channels 249 | 250 | def forward(self, x: Tensor) -> Tensor: 251 | x = self.transforms(x) 252 | return x 253 | 254 | 255 | class ELICSynthesis(nn.Module): 256 | def __init__( 257 | self, 258 | num_residual_blocks=3, 259 | channels: List[int] = [192, 256, 256, 3], 260 | output_channels: Optional[int] = None 261 | ) -> None: 262 | """ 263 | Synthesis transform from ELIC (https://arxiv.org/abs/2203.10886). 264 | 265 | Args: 266 | num_residual_blocks: defaults to 3. 267 | channels: _defaults to [192, 160, 128, 3]. 268 | output_channels: optional, defaults to None. If provided, it must equal 269 | the last element of `channels`. 270 | """ 271 | super().__init__() 272 | if len(channels) != 4: 273 | raise ValueError(f"ELIC uses 4 conv layers (not {channels}).") 274 | if output_channels is not None and output_channels != channels[-1]: 275 | raise ValueError( 276 | "output_channels specified but does not match channels: " 277 | f"{output_channels} vs. 
{channels}" 278 | ) 279 | 280 | self._compression_channels = channels[0] 281 | 282 | def res_units(N: int) -> List: 283 | return [ResidualUnit(N) for _ in range(num_residual_blocks)] 284 | 285 | channels = [channels[0]] + channels 286 | self.transforms = nn.Sequential( 287 | AttentionBlock(channels[0]), 288 | make_deconv(channels[0], out_channels=channels[1], kernel_size=5, stride=2), 289 | *res_units(channels[1]), 290 | make_deconv(channels[1], out_channels=channels[2], kernel_size=5, stride=2), 291 | AttentionBlock(channels[2]), 292 | *res_units(channels[2]), 293 | make_deconv(channels[2], out_channels=channels[3], kernel_size=5, stride=2), 294 | *res_units(channels[3]), 295 | make_deconv(channels[3], out_channels=channels[4], kernel_size=5, stride=2), 296 | ) 297 | 298 | @property 299 | def compression_channels(self) -> int: 300 | return self._compression_channels 301 | 302 | def forward(self, x: Tensor) -> Tensor: 303 | x = self.transforms(x) 304 | return x 305 | 306 | 307 | class ConditionalAttentionBlock(nn.Module): 308 | def __init__(self, N: int): 309 | super().__init__() 310 | self.conv_a = nn.Sequential(ConditionalResidualUnit(N), 311 | ConditionalResidualUnit(N), 312 | ConditionalResidualUnit(N)) 313 | 314 | self.conv_b1 = nn.Sequential( 315 | ConditionalResidualUnit(N), 316 | ConditionalResidualUnit(N), 317 | ConditionalResidualUnit(N), 318 | ) 319 | self.conv_b2 = make_conv(N, N, kernel_size=1) 320 | 321 | def forward(self, x: Tensor, emb: Tensor) -> Tensor: 322 | identity = x 323 | a = x 324 | for conv_a in self.conv_a: 325 | a = conv_a(a, emb) 326 | for conv_b1 in self.conv_b1: 327 | x = conv_b1(x, emb) 328 | b = self.conv_b2(x) 329 | out = a * torch.sigmoid(b) 330 | out += identity 331 | return out 332 | 333 | 334 | class AdaptiveELICSynthesis(nn.Module): 335 | def __init__( 336 | self, 337 | num_residual_blocks=3, 338 | channels: List[int] = [192, 256, 256, 3], 339 | output_channels: Optional[int] = None 340 | ) -> None: 341 | super().__init__() 342 | if len(channels) != 4: 343 | raise ValueError(f"ELIC uses 4 conv layers (not {channels}).") 344 | if output_channels is not None and output_channels != channels[-1]: 345 | raise ValueError( 346 | "output_channels specified but does not match channels: " 347 | f"{output_channels} vs. 
{channels}" 348 | ) 349 | 350 | self._compression_channels = channels[0] 351 | 352 | def cond_res_units(N: int) -> List: 353 | return [ConditionalResidualUnit(N) for _ in range(num_residual_blocks)] 354 | 355 | channels = [channels[0]] + channels 356 | self.transforms_1 = ConditionalAttentionBlock(channels[0]) 357 | self.deconv_1 = make_deconv(channels[0], out_channels=channels[1], kernel_size=5, stride=2) 358 | self.transforms_2 = nn.Sequential(*cond_res_units(channels[1])) 359 | self.deconv_2 = make_deconv(channels[1], out_channels=channels[2], kernel_size=5, stride=2) 360 | self.transforms_3 = nn.Sequential(ConditionalAttentionBlock(channels[2]), 361 | *cond_res_units(channels[2])) 362 | self.deconv_3 = make_deconv(channels[2], out_channels=channels[3], kernel_size=5, stride=2) 363 | self.transforms_4 = nn.Sequential(*cond_res_units(channels[3])) 364 | self.deconv_4 = make_deconv(channels[3], out_channels=channels[4], kernel_size=5, stride=2) 365 | 366 | def forward(self, x: Tensor, emb: Tensor) -> Tensor: 367 | x = self.transforms_1(x, emb) 368 | x = self.deconv_1(x) 369 | for transforms_2 in self.transforms_2: 370 | x = transforms_2(x, emb) 371 | x = self.deconv_2(x) 372 | for transforms_3 in self.transforms_3: 373 | x = transforms_3(x, emb) 374 | x = self.deconv_3(x) 375 | for transforms_4 in self.transforms_4: 376 | x = transforms_4(x, emb) 377 | x = self.deconv_4(x) 378 | return x 379 | 380 | 381 | if __name__ == '__main__': 382 | g_s = AdaptiveELICSynthesis() 383 | y = torch.randn(1, 192, 8, 8) 384 | emb = torch.randn(1, 256) 385 | x = g_s(y, emb) 386 | print(x.shape) 387 | -------------------------------------------------------------------------------- /net/mit.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import time 4 | from functools import lru_cache 5 | 6 | import constriction 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | from layer.layer_utils import quantize_ste, make_conv 12 | from layer.transformer import TransformerBlock, SwinTransformerBlock 13 | from net.elic import ELICAnalysis, ELICSynthesis 14 | 15 | 16 | # from compressai.entropy_models import GaussianConditional 17 | 18 | 19 | class GaussianMixtureEntropyModel(nn.Module): 20 | def __init__( 21 | self, 22 | minmax: int = 64 23 | ): 24 | super().__init__() 25 | self.minmax = minmax 26 | self.samples = torch.arange(-minmax, minmax + 1, 1, dtype=torch.float32).cuda() 27 | self.laplace = torch.distributions.Laplace(0, 1) 28 | self.pmf_laplace = self.laplace.cdf(self.samples + 0.5) - self.laplace.cdf(self.samples - 0.5) 29 | # self.gaussian_conditional = GaussianConditional(None) 30 | 31 | def update_minmax(self, minmax): 32 | self.minmax = minmax 33 | self.samples = torch.arange(-minmax, minmax + 1, 1, dtype=torch.float32).cuda() 34 | self.pmf_laplace = self.laplace.cdf(self.samples + 0.5) - self.laplace.cdf(self.samples - 0.5) 35 | 36 | def get_GMM_likelihood(self, latent_hat, probs, means, scales): 37 | gaussian1 = torch.distributions.Normal(means[0], scales[0]) 38 | gaussian2 = torch.distributions.Normal(means[1], scales[1]) 39 | gaussian3 = torch.distributions.Normal(means[2], scales[2]) 40 | likelihoods_0 = gaussian1.cdf(latent_hat + 0.5) - gaussian1.cdf(latent_hat - 0.5) 41 | likelihoods_1 = gaussian2.cdf(latent_hat + 0.5) - gaussian2.cdf(latent_hat - 0.5) 42 | likelihoods_2 = gaussian3.cdf(latent_hat + 0.5) - gaussian3.cdf(latent_hat - 0.5) 43 | 44 | likelihoods = 0.999 * (probs[0] * likelihoods_0 + probs[1] * 
likelihoods_1 + probs[2] * likelihoods_2) 45 | + 0.001 * (self.laplace.cdf(latent_hat + 0.5) - self.laplace.cdf(latent_hat - 0.5)) 46 | likelihoods = likelihoods + 1e-10 47 | return likelihoods 48 | 49 | def get_GMM_pmf(self, probs, means, scales): 50 | L = self.samples.size(0) 51 | num_symbol = probs.size(1) 52 | samples = self.samples.unsqueeze(0).repeat(num_symbol, 1) # N 65 53 | scales = scales.unsqueeze(-1).repeat(1, 1, L) 54 | means = means.unsqueeze(-1).repeat(1, 1, L) 55 | probs = probs.unsqueeze(-1).repeat(1, 1, L) 56 | likelihoods_0 = self._likelihood(samples, scales[0], means=means[0]) 57 | likelihoods_1 = self._likelihood(samples, scales[1], means=means[1]) 58 | likelihoods_2 = self._likelihood(samples, scales[2], means=means[2]) 59 | pmf_clip = (0.999 * (probs[0] * likelihoods_0 + probs[1] * likelihoods_1 + probs[2] * likelihoods_2) 60 | + 0.001 * self.pmf_laplace) 61 | return pmf_clip 62 | 63 | def _likelihood(self, inputs, scales, means=None): 64 | half = float(0.5) 65 | 66 | if means is not None: 67 | values = inputs - means 68 | else: 69 | values = inputs 70 | values = torch.abs(values) 71 | upper = self._standardized_cumulative((half - values) / scales) 72 | lower = self._standardized_cumulative((-half - values) / scales) 73 | likelihood = upper - lower 74 | return likelihood 75 | 76 | def _standardized_cumulative(self, inputs): 77 | half = float(0.5) 78 | const = float(-(2 ** -0.5)) 79 | # Using the complementary error function maximizes numerical precision. 80 | return half * torch.erfc(const * inputs) 81 | 82 | def compress(self, symbols, probs, means, scales): 83 | pmf_clip = self.get_GMM_pmf(probs, means, scales) 84 | model_family = constriction.stream.model.Categorical() # note empty `()` 85 | probabilities = pmf_clip.cpu().numpy().astype(np.float64) 86 | symbols = symbols.reshape(-1) 87 | symbols = (symbols + self.minmax).cpu().numpy().astype(np.int32) 88 | encoder = constriction.stream.queue.RangeEncoder() 89 | encoder.encode(symbols, model_family, probabilities) 90 | compressed = encoder.get_compressed() 91 | return compressed 92 | 93 | def decompress(self, compressed, probs, means, scales): 94 | pmf = self.get_GMM_pmf(probs, means, scales).cpu().numpy().astype(np.float64) 95 | model = constriction.stream.model.Categorical() 96 | decoder = constriction.stream.queue.RangeDecoder(compressed) 97 | symbols = decoder.decode(model, pmf) 98 | symbols = torch.from_numpy(symbols).to(probs.device) - self.minmax 99 | symbols = torch.tensor(symbols, dtype=torch.float32) 100 | return symbols 101 | 102 | 103 | @lru_cache() 104 | def get_coding_order(target_shape, context_mode, device, step=12): 105 | if context_mode == 'quincunx': 106 | context_tensor = torch.tensor([[4, 2, 4, 0], [3, 4, 3, 4], [4, 1, 4, 2]]).to(device) 107 | elif context_mode == 'checkerboard2': 108 | context_tensor = torch.tensor([[1, 0], [0, 1]]).to(device) 109 | elif context_mode == 'checkerboard4': 110 | context_tensor = torch.tensor([[0, 2], [3, 1]]).to(device) 111 | elif context_mode == 'qlds': 112 | B, C, H, W = target_shape 113 | 114 | def get_qlds(H, W): 115 | n, m, g = 0, 0, 1.32471795724474602596 116 | a1, a2 = 1.0 / g, 1.0 / g / g 117 | context_tensor = torch.zeros((H, W)).to(device) - 1 118 | while m < H * W: 119 | n += 1 120 | x = int(round(((0.5 + n * a1) % 1) * H)) % H 121 | y = int(round(((0.5 + n * a2) % 1) * W)) % W 122 | if context_tensor[x, y] == -1: 123 | context_tensor[x, y] = m 124 | m += 1 125 | return context_tensor 126 | 127 | context_tensor = torch.tensor(get_qlds(H, W), 
dtype=torch.int) 128 | 129 | def gamma_func(alpha=1.): 130 | return lambda r: r ** alpha 131 | 132 | ratio = 1. * (np.arange(step) + 1) / step 133 | gamma = gamma_func(alpha=2.2) 134 | L = H * W # total number of tokens 135 | mask_ratio = np.clip(np.floor(L * gamma(ratio)), 0, L - 1) 136 | for i in range(step): 137 | context_tensor = torch.where((context_tensor <= mask_ratio[i]) * (context_tensor > i), 138 | torch.ones_like(context_tensor) * i, context_tensor) 139 | return context_tensor 140 | else: 141 | context_tensor = context_mode 142 | B, C, H, W = target_shape 143 | Hp, Wp = context_tensor.size() 144 | coding_order = torch.tile(context_tensor, (H // Hp + 1, W // Wp + 1))[:H, :W] 145 | return coding_order 146 | 147 | 148 | class MaskedImageTransformer(nn.Module): 149 | def __init__(self, latent_dim, dim=768, depth=12, num_heads=12, window_size=24, 150 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 151 | drop_path=0., norm_layer=nn.LayerNorm, transformer='swin'): 152 | super().__init__() 153 | self.dim = dim 154 | self.depth = depth 155 | if transformer == 'swin': 156 | window_size = 4 157 | num_heads = 8 158 | self.blocks = nn.ModuleList([ 159 | SwinTransformerBlock(dim=dim, 160 | num_heads=num_heads, 161 | window_size=window_size, 162 | shift_size=0 if (i % 2 == 0) else window_size // 2, 163 | mlp_ratio=mlp_ratio, 164 | qkv_bias=qkv_bias, qk_scale=qk_scale, 165 | drop=drop, attn_drop=attn_drop, 166 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 167 | norm_layer=norm_layer) 168 | for i in range(depth)]) 169 | else: 170 | self.blocks = nn.ModuleList([ 171 | TransformerBlock(dim=dim, 172 | num_heads=num_heads, window_size=window_size, 173 | mlp_ratio=mlp_ratio, 174 | qkv_bias=qkv_bias, qk_scale=qk_scale, 175 | drop=drop, attn_drop=attn_drop, 176 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 177 | norm_layer=norm_layer) 178 | for i in range(depth)]) 179 | self.delta = 5.0 180 | self.embedding_layer = nn.Linear(latent_dim, dim) 181 | # self.positional_encoding = LearnedPosition(window_size * window_size, dim) 182 | self.entropy_parameters = nn.Sequential( 183 | make_conv(dim, dim * 4, 1, 1), 184 | nn.GELU(), 185 | make_conv(dim * 4, dim * 4, 1, 1), 186 | nn.GELU(), 187 | make_conv(dim * 4, latent_dim * 9, 1, 1), 188 | ) 189 | 190 | self.gmm_model = GaussianMixtureEntropyModel() 191 | self.laplace = torch.distributions.Laplace(0, 1) 192 | self.mask_token = nn.Parameter(torch.zeros(1, 1, latent_dim), requires_grad=True) 193 | 194 | def forward_with_given_mask(self, latent_hat, mask, slice_size=None): 195 | B, C, H, W = latent_hat.size() 196 | input_resolution = (H, W) 197 | x = latent_hat.flatten(2).transpose(1, 2) # B L N 198 | mask_BLN = mask.flatten(2).transpose(1, 2) # B L N 199 | x_masked = x * mask_BLN + self.mask_token * (1 - mask_BLN) 200 | 201 | x_masked = self.embedding_layer(x_masked / self.delta) 202 | # x = self.positional_encoding(x) 203 | for _, blk in enumerate(self.blocks): 204 | x_masked = blk(x_masked, input_resolution, slice_size) 205 | x_out = x_masked.transpose(1, 2).reshape(B, self.dim, H, W) 206 | params = self.entropy_parameters(x_out) 207 | probs, means, scales = params.chunk(3, dim=1) 208 | probs = torch.softmax(probs.reshape(B, 3, C, H, W), dim=1).transpose(0, 1) 209 | means = means.reshape(B, 3, C, H, W).transpose(0, 1) 210 | scales = torch.abs(scales).reshape(B, 3, C, H, W).transpose(0, 1).clamp(1e-10, 1e10) 211 | return probs, means, scales 212 | 213 | def forward_with_random_mask(self, 
latent): 214 | B, C, H, W = latent.size() 215 | half = float(0.5) 216 | noise = torch.empty_like(latent).uniform_(-half, half) 217 | latent_noise = latent + noise 218 | latent_hat = quantize_ste(latent) 219 | 220 | def generate_random_mask(latent, r): 221 | mask_loc = torch.randn(H * W).to(latent.get_device()) 222 | threshold = torch.sort(mask_loc)[0][r] 223 | mask = torch.where(mask_loc > threshold, torch.ones_like(mask_loc), torch.zeros_like(mask_loc)) 224 | mask = mask.reshape(1, 1, H, W).repeat(B, C, 1, 1) 225 | return mask 226 | 227 | r = math.floor(np.random.uniform(0.05, 0.99) * H * W) # drop probability 228 | mask = generate_random_mask(latent_hat, r) 229 | mask_params = mask.unsqueeze(0).repeat(3, 1, 1, 1, 1) 230 | probs, means, scales = self.forward_with_given_mask(latent_hat, mask) 231 | likelihoods_masked = torch.ones_like(latent_hat) 232 | likelihoods = self.gmm_model.get_GMM_likelihood(latent_noise[mask == 0], 233 | probs[mask_params == 0].reshape(3, -1), 234 | means[mask_params == 0].reshape(3, -1), 235 | scales[mask_params == 0].reshape(3, -1)) 236 | likelihoods_masked[mask == 0] = likelihoods 237 | return latent_hat, likelihoods_masked 238 | 239 | def inference(self, latent_hat, context_mode='qlds', slice_size=None): 240 | coding_order = get_coding_order(latent_hat.shape, context_mode, latent_hat.get_device(), step=12) # H W 241 | coding_order = coding_order.reshape(1, 1, *coding_order.shape).repeat(latent_hat.shape[0], latent_hat.shape[1], 242 | 1, 1) 243 | total_steps = int(coding_order.max() + 1) 244 | likelihoods = torch.zeros_like(latent_hat) 245 | for i in range(total_steps): 246 | ctx_locations = (coding_order < i) 247 | mask_i = torch.where(ctx_locations, torch.ones_like(latent_hat), torch.zeros_like(latent_hat)) 248 | probs_i, means_i, scales_i = self.forward_with_given_mask(latent_hat, mask_i, slice_size) 249 | encoding_locations = (coding_order == i) 250 | mask_params_i = encoding_locations.unsqueeze(0).repeat(3, 1, 1, 1, 1) 251 | likelihoods_i = self.gmm_model.get_GMM_likelihood(latent_hat[encoding_locations], 252 | probs_i[mask_params_i].reshape(3, -1), 253 | means_i[mask_params_i].reshape(3, -1), 254 | scales_i[mask_params_i].reshape(3, -1)) 255 | likelihoods[encoding_locations] = likelihoods_i 256 | return likelihoods 257 | 258 | def compress(self, latent, context_mode='qlds'): 259 | B, C, H, W = latent.size() 260 | latent_hat = torch.round(latent) 261 | self.gmm_model.update_minmax(int(latent_hat.max().item())) 262 | coding_order = get_coding_order(latent.shape, context_mode, latent_hat.get_device(), step=12) # H W 263 | coding_order = coding_order.reshape(1, 1, *coding_order.shape).repeat(latent.shape[0], latent.shape[1], 1, 1) 264 | total_steps = int(coding_order.max() + 1) 265 | t0 = time.time() 266 | strings = [] 267 | for i in range(total_steps): 268 | # print('STEP', i) 269 | ctx_locations = (coding_order < i) 270 | encoding_locations = (coding_order == i) 271 | mask_params_i = encoding_locations.unsqueeze(0).repeat(3, 1, 1, 1, 1) 272 | mask_i = torch.where(ctx_locations, torch.ones_like(latent), torch.zeros_like(latent)) 273 | probs_i, means_i, scales_i = self.forward_with_given_mask(latent_hat, mask_i) 274 | string_i = self.gmm_model.compress(latent_hat[encoding_locations], 275 | probs_i[mask_params_i].reshape(3, -1), 276 | means_i[mask_params_i].reshape(3, -1), 277 | scales_i[mask_params_i].reshape(3, -1)) 278 | strings.append(string_i) 279 | print('compress', time.time() - t0) 280 | return strings 281 | 282 | def decompress(self, strings, 
latent_size, device, context_mode='qlds'): 283 | B, C, H, W = latent_size 284 | coding_order = get_coding_order(latent_size, context_mode, device, step=12) # H W 285 | coding_order = coding_order.reshape(1, 1, *coding_order.shape).repeat(B, C, 1, 1) 286 | total_steps = int(coding_order.max() + 1) 287 | t0 = time.time() 288 | latent_hat = torch.zeros(latent_size).to(device) 289 | for i in range(total_steps): 290 | ctx_locations = (coding_order < i) 291 | encoding_locations = (coding_order == i) 292 | mask_params_i = encoding_locations.unsqueeze(0).repeat(3, 1, 1, 1, 1) 293 | mask_i = torch.where(ctx_locations, torch.ones_like(latent_hat), torch.zeros_like(latent_hat)) 294 | probs_i, means_i, scales_i = self.forward_with_given_mask(latent_hat, mask_i) 295 | symbols_i = self.gmm_model.decompress(strings[i], 296 | probs_i[mask_params_i].reshape(3, -1), 297 | means_i[mask_params_i].reshape(3, -1), 298 | scales_i[mask_params_i].reshape(3, -1)) 299 | latent_hat[encoding_locations] = symbols_i 300 | print('decompress', time.time() - t0) 301 | return latent_hat 302 | 303 | @torch.jit.ignore 304 | def no_weight_decay(self): 305 | return {'mask_token'} 306 | 307 | 308 | class MaskedImageModelingTransformer(nn.Module): 309 | def __init__(self): 310 | super().__init__() 311 | self.g_a = ELICAnalysis() 312 | self.g_s = ELICSynthesis() 313 | self.mim = MaskedImageTransformer(192) 314 | 315 | def forward(self, x): 316 | y = self.g_a(x) 317 | y_hat, likelihoods = self.mim.forward_with_random_mask(y) 318 | x_hat = self.g_s(y_hat) 319 | return { 320 | "x_hat": x_hat, 321 | "likelihoods": likelihoods, 322 | } 323 | 324 | def inference(self, x): 325 | # TODO Patch-wise inference for off-the-shelf Transformers 326 | y = self.g_a(x) 327 | y_hat = torch.round(y) 328 | likelihoods = self.mim.inference(y_hat) 329 | x_hat = self.g_s(y_hat) 330 | return { 331 | "x_hat": x_hat, 332 | "likelihoods": likelihoods, 333 | } 334 | 335 | def real_inference(self, x): 336 | num_pixels = x.size(2) * x.size(3) 337 | y = self.g_a(x) 338 | strings = self.mim.compress(y) 339 | y_hat = self.mim.decompress(strings, y.shape, x.get_device()) 340 | x_hat = self.g_s(y_hat) 341 | bpp = sum([string.size * 32 for string in strings]) / num_pixels 342 | return { 343 | "x_hat": x_hat, 344 | "bpp": bpp, 345 | } 346 | -------------------------------------------------------------------------------- /results/mt_rd.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wsxtyrdd/Masked-Transformer-For-Image-Compression/2dc391982c0d4e798acb158a552b5e39b968290a/results/mt_rd.png -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append("/media/D/wangsixian/MT") 5 | 6 | import configargparse 7 | import time 8 | from datetime import datetime 9 | import numpy as np 10 | import random 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.nn.functional as F 15 | import wandb 16 | 17 | torch.backends.cudnn.benchmark = False 18 | from data.datasets import get_loader, get_dataset 19 | from loss.distortion import Distortion 20 | from utils import logger_configuration, load_weights, AverageMeter, save_checkpoint, worker_init_fn_seed 21 | from net.mit import MaskedImageModelingTransformer 22 | 23 | 24 | def train_one_epoch(epoch, net, train_loader, test_loader, optimizer, 
25 | device, logger): 26 | local_rank = torch.distributed.get_rank() if config.multi_gpu else 0 27 | best_loss = float("inf") 28 | mse_loss_wrapper = Distortion("MSE").to(device) 29 | ms_ssim_loss_wrapper = Distortion("MS-SSIM").to(device) 30 | elapsed, losses, psnrs, ms_ssim_dbs, bpps = [AverageMeter() for _ in range(5)] 31 | metrics = [elapsed, losses, psnrs, ms_ssim_dbs, bpps] 32 | global global_step 33 | for batch_idx, input_image in enumerate(train_loader): 34 | net.train() 35 | B, C, H, W = input_image.shape 36 | num_pixels = B * H * W 37 | input_image = input_image.to(device) 38 | optimizer.zero_grad() 39 | global_step += 1 40 | start_time = time.time() 41 | results = net(input_image) 42 | bpp_loss = torch.sum(torch.clamp(-torch.log2(results["likelihoods"]), 0, 50)) / num_pixels 43 | # if config.distortion_metric == "MSE": 44 | mse_loss = mse_loss_wrapper(results["x_hat"], input_image) 45 | tot_loss = config.lambda_value * 255 ** 2 * mse_loss + bpp_loss 46 | # elif config.distortion_metric == "MS-SSIM": 47 | # ssim_loss = ms_ssim_loss_wrapper(results["x_hat"], input_image) 48 | # tot_loss = config.lambda_value * ssim_loss + bpp_loss 49 | tot_loss.backward() 50 | 51 | if config.clip_max_norm > 0: 52 | torch.nn.utils.clip_grad_norm_(net.parameters(), config.clip_max_norm) 53 | optimizer.step() 54 | 55 | elapsed.update(time.time() - start_time) 56 | losses.update(tot_loss.item()) 57 | bpps.update(bpp_loss.item()) 58 | 59 | mse_val = mse_loss_wrapper(results["x_hat"], input_image).detach().item() 60 | psnr = 10 * (np.log(1. / mse_val) / np.log(10)) 61 | psnrs.update(psnr.item()) 62 | 63 | ms_ssim_val = ms_ssim_loss_wrapper(results["x_hat"], input_image).detach().item() 64 | ms_ssim_db = -10 * (np.log(ms_ssim_val) / np.log(10)) 65 | ms_ssim_dbs.update(ms_ssim_db.item()) 66 | 67 | if (global_step % config.print_every) == 0 and local_rank == 0: 68 | process = (global_step % train_loader.__len__()) / (train_loader.__len__()) * 100.0 69 | log_info = [ 70 | f'Step [{global_step % train_loader.__len__()}/{train_loader.__len__()}={process:.2f}%]', 71 | f'Loss ({losses.avg:.3f})', 72 | f'Time {elapsed.avg:.2f}', 73 | f'PSNR {psnrs.val:.2f} ({psnrs.avg:.2f})', 74 | f'BPP {bpps.val:.2f} ({bpps.avg:.2f})', 75 | f'MS-SSIM {ms_ssim_dbs.val:.2f} ({ms_ssim_dbs.avg:.2f})', 76 | f'Epoch {epoch}' 77 | ] 78 | log = (' | '.join(log_info)) 79 | logger.info(log) 80 | if config.wandb: 81 | log_dict = {"PSNR": psnrs.avg, 82 | "MS-SSIM": ms_ssim_dbs.avg, 83 | "BPP": bpps.avg, 84 | "loss": losses.avg, 85 | "Step": global_step, 86 | } 87 | wandb.log(log_dict, step=global_step) 88 | for i in metrics: 89 | i.clear() 90 | 91 | if (global_step + 1) % config.test_every == 0 and local_rank == 0: 92 | loss = test(net, test_loader, device, logger) 93 | is_best = loss < best_loss 94 | best_loss = min(loss, best_loss) 95 | if is_best: 96 | save_checkpoint( 97 | { 98 | "epoch": epoch + 1, 99 | "global_step": global_step, 100 | "state_dict": net.state_dict(), 101 | "optimizer": optimizer.state_dict() 102 | }, 103 | is_best, 104 | workdir 105 | ) 106 | 107 | if (global_step + 1) % config.save_every == 0 and local_rank == 0: 108 | save_checkpoint( 109 | { 110 | "epoch": epoch + 1, 111 | "global_step": global_step, 112 | "state_dict": net.state_dict(), 113 | "optimizer": optimizer.state_dict() 114 | }, 115 | False, 116 | workdir, 117 | filename='EP{}.pth.tar'.format(epoch) 118 | ) 119 | 120 | 121 | def test(net, test_loader, device, logger): 122 | with torch.no_grad(): 123 | mse_loss_wrapper = Distortion("MSE").to(device) 124 | 
ms_ssim_loss_wrapper = Distortion("MS-SSIM").to(device) 125 | elapsed, losses, psnrs, ms_ssim_dbs, bpps = [AverageMeter() for _ in range(5)] 126 | global global_step 127 | for batch_idx, input_image in enumerate(test_loader): 128 | net.eval() 129 | B, C, H, W = input_image.shape 130 | num_pixels = B * H * W 131 | input_image = input_image.to(device) 132 | start_time = time.time() 133 | # crop and pad 134 | p = 64 135 | new_H = (H + p - 1) // p * p 136 | new_W = (W + p - 1) // p * p 137 | padding_left = (new_W - W) // 2 138 | padding_right = new_W - W - padding_left 139 | padding_top = (new_H - H) // 2 140 | padding_bottom = new_H - H - padding_top 141 | input_image_pad = F.pad( 142 | input_image, 143 | (padding_left, padding_right, padding_top, padding_bottom), 144 | mode="constant", 145 | value=0, 146 | ) 147 | results = net.module.inference(input_image_pad) 148 | results["x_hat"] = F.pad( 149 | results["x_hat"], (-padding_left, -padding_right, -padding_top, -padding_bottom) 150 | ) 151 | 152 | mse_loss = mse_loss_wrapper(results["x_hat"], input_image) 153 | bpp_loss = torch.sum(torch.clamp(-torch.log2(results["likelihoods"]), 0, 50)) / num_pixels 154 | tot_loss = config.lambda_value * 255 ** 2 * mse_loss + bpp_loss 155 | elapsed.update(time.time() - start_time) 156 | losses.update(tot_loss.item()) 157 | bpps.update(bpp_loss.item()) 158 | 159 | # results = net.module.real_inference(input_image_pad) 160 | # mse_loss = mse_loss_wrapper(results["x_hat"], input_image) 161 | # bpp_loss = results["bpp"] 162 | # bpps.update(bpp_loss) 163 | 164 | mse_val = mse_loss.item() 165 | psnr = 10 * (np.log(1. / mse_val) / np.log(10)) 166 | psnrs.update(psnr.item()) 167 | 168 | ms_ssim_val = ms_ssim_loss_wrapper(results["x_hat"], input_image).mean().item() 169 | ms_ssim_db = -10 * (np.log(ms_ssim_val) / np.log(10)) 170 | ms_ssim_dbs.update(ms_ssim_db.item()) 171 | 172 | log_info = [ 173 | f'Step [{(batch_idx + 1)}/{test_loader.__len__()}]', 174 | f'Loss ({losses.avg:.3f})', 175 | f'Time {elapsed.val:.2f}', 176 | f'PSNR {psnrs.val:.2f} ({psnrs.avg:.2f})', 177 | f'BPP {bpps.val:.2f} ({bpps.avg:.4f})', 178 | f'MS-SSIM {ms_ssim_dbs.val:.2f} ({ms_ssim_dbs.avg:.2f})', 179 | ] 180 | log = (' | '.join(log_info)) 181 | logger.info(log) 182 | if config.wandb and local_rank == 0: 183 | log_dict = {"[Kodak] PSNR": psnrs.avg, 184 | "[Kodak] MS-SSIM": ms_ssim_dbs.avg, 185 | "[Kodak] BPP": bpps.avg, 186 | "[Kodak] loss": losses.avg, 187 | "[Kodak] Step": global_step, 188 | } 189 | wandb.log(log_dict, step=global_step) 190 | return losses.avg 191 | 192 | 193 | def parse_args(argv): 194 | parser = configargparse.ArgumentParser() 195 | parser.add_argument('--config', is_config_file=True, default='./config/mt.yaml', 196 | help='Path to config file to replace defaults.') 197 | parser.add_argument('--seed', type=int, default=1024, 198 | help='Random seed.') 199 | parser.add_argument('--gpu-id', type=str, default='0', 200 | help='GPU id to use.') 201 | parser.add_argument('--test-only', action='store_true', 202 | help='Test only (and do not run training).') 203 | parser.add_argument('--multi-gpu', action='store_true') 204 | parser.add_argument('--local_rank', type=int, 205 | help='Local rank for distributed training.') 206 | 207 | # logging 208 | parser.add_argument('--exp-name', type=str, default=datetime.now().strftime('%Y-%m-%d %H:%M:%S'), 209 | help='Experiment name, unique id for trainers, logs.') 210 | parser.add_argument('--wandb', action="store_true", 211 | help='Use wandb for logging.') 212 | 
parser.add_argument('--print-every', type=int, default=30, 213 | help='Frequency of logging.') 214 | 215 | # dataset 216 | parser.add_argument('--dataset-path', default=['/media/Dataset/openimages/**/'], 217 | help='Path(s) to the training image directories.') 218 | parser.add_argument('--num-workers', type=int, default=8, 219 | help='Number of workers for dataloader.') 220 | parser.add_argument('--training-img-size', type=tuple, default=(384, 384), 221 | help='Size of the training images.') 222 | parser.add_argument('--eval-dataset-path', type=str, 223 | help='Path to the evaluation dataset') 224 | 225 | # optimization 226 | parser.add_argument('--distortion_metric', type=str, default='MSE', 227 | help='Distortion metric, MSE or MS-SSIM.') 228 | parser.add_argument('--lambda_value', type=float, default=1, 229 | help='Rate-distortion trade-off weight for the distortion term.') 230 | 231 | # Optimizer configuration parameters 232 | parser.add_argument('--optimizer_type', type=str, default='AdamW', 233 | help='The type of optimizer to use') 234 | parser.add_argument('--init_lr', type=float, default=1e-4, 235 | help='The initial learning rate for the learning rate policy') 236 | parser.add_argument('--min_lr', type=float, default=1e-4, 237 | help='The minimum learning rate for the learning rate policy') 238 | parser.add_argument('--max_lr', type=float, default=1e-4, 239 | help='The maximum learning rate for the learning rate policy') 240 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.98), 241 | help='The beta values to use for the optimizer') 242 | parser.add_argument('--warmup_epoch', type=int, default=5, 243 | help='The number of epochs to use for the warmup') 244 | parser.add_argument('--weight_decay', type=float, default=0.03, 245 | help='The weight decay value for the optimizer') 246 | parser.add_argument('--clip-max-norm', type=float, default=1.0, 247 | help='Gradient clipping for stable training.') 248 | 249 | # trainer 250 | parser.add_argument('--warmup', action='store_true') 251 | parser.add_argument('--epochs', type=int, default=10, 252 | help='Number of epochs to run the training.') 253 | parser.add_argument('--batch-size', type=int, default=16, 254 | help='Batch size for the training.') 255 | parser.add_argument("--checkpoint", type=str, default=None, 256 | help="Path to a checkpoint model") 257 | parser.add_argument('--save', action="store_true", 258 | help="Save the model at every epoch (no overwrite).") 259 | parser.add_argument('--save-every', type=int, default=10000, 260 | help='Frequency of saving the model.') 261 | parser.add_argument('--test-every', type=int, default=5000, 262 | help='Frequency of running validation.') 263 | 264 | # model 265 | parser.add_argument('--net', type=str, default='MT', 266 | help='Model architecture.') 267 | args = parser.parse_args(argv) 268 | return args 269 | 270 | 271 | def main(argv): 272 | global config 273 | config = parse_args(argv) 274 | 275 | global local_rank 276 | if config.multi_gpu: 277 | dist.init_process_group(backend='nccl') 278 | local_rank = torch.distributed.get_rank() 279 | torch.cuda.set_device(local_rank) 280 | device = torch.device("cuda", local_rank) 281 | else: 282 | local_rank = 0 283 | device = torch.device("cuda") 284 | os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_id) 285 | 286 | # torch.backends.cudnn.benchmark = True 287 | 288 | if config.seed is not None: 289 | torch.manual_seed(config.seed) 290 | random.seed(config.seed) 291 | 292 | config.device = device 293 | job_type = 'test' if config.test_only else 'train' 294 | exp_name = 
config.net + " " + config.exp_name 295 | global workdir 296 | workdir, logger = logger_configuration(exp_name, job_type, 297 | method=config.net, save_log=(not config.test_only and local_rank == 0)) 298 | net = MaskedImageModelingTransformer().to(device) 299 | if config.multi_gpu: 300 | net = torch.nn.parallel.DistributedDataParallel(net, find_unused_parameters=True) 301 | else: 302 | net = torch.nn.DataParallel(net) 303 | 304 | if config.wandb and local_rank == 0: 305 | print("=============== use wandb ==============") 306 | wandb_init_kwargs = { 307 | 'project': 'ResiComm', 308 | 'name': exp_name, 309 | 'save_code': True, 310 | 'job_type': job_type, 311 | 'config': config.__dict__ 312 | } 313 | wandb.init(**wandb_init_kwargs) 314 | 315 | # shutil.copy(config.config, join(workdir, 'config.yaml')) 316 | 317 | config.logger = logger 318 | logger.info(config.__dict__) 319 | 320 | if config.multi_gpu: 321 | train_dataset, test_dataset = get_dataset(config.dataset_path, config.eval_dataset_path) 322 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 323 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 324 | num_workers=config.num_workers, 325 | pin_memory=True, 326 | batch_size=config.batch_size, 327 | worker_init_fn=worker_init_fn_seed, 328 | sampler=train_sampler) 329 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 330 | batch_size=1, 331 | shuffle=False) 332 | else: 333 | train_loader, test_loader = get_loader(config.dataset_path, config.eval_dataset_path, 334 | config.num_workers, config.batch_size) 335 | 336 | optimizer_cfg = {'lr': config.max_lr, 337 | 'betas': config.betas, 338 | 'weight_decay': config.weight_decay 339 | } 340 | optimizer = getattr(torch.optim, config.optimizer_type)(net.parameters(), **optimizer_cfg) 341 | 342 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(config.epochs * 0.5), gamma=0.1) 343 | 344 | global global_step 345 | global_step = 0 346 | 347 | if config.checkpoint is not None: 348 | load_weights(net, config.checkpoint, device) 349 | else: 350 | logger.info("No pretrained model is loaded.") 351 | 352 | if config.test_only: 353 | test(net, test_loader, device, logger) 354 | else: 355 | steps_epoch = global_step // train_loader.__len__() 356 | init_lambda_value = config.lambda_value 357 | for epoch in range(steps_epoch, config.epochs): 358 | if config.warmup: 359 | # for lambda warmup 360 | if epoch <= int(config.epochs * 0.1): 361 | config.lambda_value = init_lambda_value * 10 362 | else: 363 | config.lambda_value = init_lambda_value 364 | 365 | logger.info('======Current epoch %s ======' % epoch) 366 | logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']}") 367 | logger.info(f"Lambda value: {config.lambda_value}") 368 | train_one_epoch(epoch, net, train_loader, test_loader, optimizer, device, logger) 369 | lr_scheduler.step() 370 | 371 | 372 | if __name__ == '__main__': 373 | main(sys.argv[1:]) 374 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def save_config(config, file_path): 9 | import json 10 | info_json = json.dumps(config, sort_keys=False, indent=4, separators=(',', ': ')) 11 | f = open(file_path, 'w') 12 | f.write(info_json) 13 | 14 | 15 | def logger_configuration(filename, phase, method='', save_log=True): 16 | logger = 
logging.getLogger(" ") 17 | workdir = './history/{}/{}'.format(method, filename) 18 | if phase == 'test': 19 | workdir += '_test' 20 | log = workdir + '/{}.log'.format(filename) 21 | samples = workdir + '/samples' 22 | models = workdir + '/models' 23 | if save_log: 24 | makedirs(workdir) 25 | makedirs(samples) 26 | makedirs(models) 27 | 28 | formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", 29 | "%Y-%m-%d %H:%M:%S") 30 | stdhandler = logging.StreamHandler() 31 | stdhandler.setLevel(logging.INFO) 32 | stdhandler.setFormatter(formatter) 33 | logger.addHandler(stdhandler) 34 | if save_log: 35 | filehandler = logging.FileHandler(log) 36 | filehandler.setLevel(logging.INFO) 37 | filehandler.setFormatter(formatter) 38 | logger.addHandler(filehandler) 39 | logger.setLevel(logging.INFO) 40 | return workdir, logger 41 | 42 | 43 | def makedirs(directory): 44 | if not os.path.exists(directory): 45 | os.makedirs(directory) 46 | 47 | 48 | def save_checkpoint(state, is_best, base_dir, filename="checkpoint.pth.tar"): 49 | if is_best: 50 | torch.save(state, base_dir + "/checkpoint_best_loss.pth.tar") 51 | else: 52 | torch.save(state, base_dir + "/" + filename) 53 | 54 | 55 | class AverageMeter: 56 | """Compute running average.""" 57 | 58 | def __init__(self): 59 | self.val = 0 60 | self.avg = 0 61 | self.sum = 0 62 | self.count = 0 63 | 64 | def update(self, val, n=1): 65 | self.val = val 66 | self.sum += val * n 67 | self.count += n 68 | self.avg = self.sum / self.count 69 | 70 | def clear(self): 71 | self.val = 0 72 | self.avg = 0 73 | self.sum = 0 74 | self.count = 0 75 | 76 | 77 | def load_weights(net, model_path, device, remove_keys=None): 78 | try: 79 | pretrained = torch.load(model_path, map_location=device)['state_dict'] 80 | except: 81 | pretrained = torch.load(model_path, map_location=device) 82 | result_dict = {} 83 | for key, weight in pretrained.items(): 84 | result_key = key 85 | load_flag = True 86 | if 'attn_mask' in key: 87 | load_flag = False 88 | if remove_keys is not None: 89 | for remove_key in remove_keys: 90 | if remove_key in key: 91 | load_flag = False 92 | if load_flag: 93 | result_dict[result_key] = weight 94 | print(net.load_state_dict(result_dict, strict=False)) 95 | del result_dict, pretrained 96 | 97 | 98 | def worker_init_fn_seed(worker_id): 99 | seed = 10 100 | seed += worker_id 101 | np.random.seed(seed) 102 | --------------------------------------------------------------------------------
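A minimal usage sketch (not part of the repository): the snippet below shows one way to drive the actual range-coding path, `MaskedImageModelingTransformer.real_inference`, on a single image. It mirrors the 64-pixel padding performed by `test()` in `train.py`, and wraps the model in `DataParallel` only so the `module.`-prefixed keys saved by `train.py` load cleanly. The checkpoint and image paths are placeholders; treat this as a hedged sketch rather than an official entry point.

```python
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from net.mit import MaskedImageModelingTransformer
from utils import load_weights

device = torch.device('cuda')
# DataParallel wrapper matches how train.py saves checkpoints ('module.' key prefix).
net = torch.nn.DataParallel(MaskedImageModelingTransformer().to(device))
load_weights(net, 'checkpoint_best_loss.pth.tar', device)  # placeholder checkpoint path
net.eval()

# Load one RGB image as a 1x3xHxW tensor in [0, 1].
x = transforms.ToTensor()(Image.open('kodim01.png').convert('RGB')).unsqueeze(0).to(device)

# Pad H and W up to multiples of 64, as test() does, so the analysis transform
# produces an integer-sized latent grid.
p = 64
H, W = x.shape[2:]
new_H, new_W = (H + p - 1) // p * p, (W + p - 1) // p * p
pad_l, pad_t = (new_W - W) // 2, (new_H - H) // 2
pad = (pad_l, new_W - W - pad_l, pad_t, new_H - H - pad_t)
x_pad = F.pad(x, pad, mode='constant', value=0)

with torch.no_grad():
    out = net.module.real_inference(x_pad)  # compress + decompress the latent with constriction

x_hat = F.pad(out['x_hat'], tuple(-v for v in pad))  # crop the reconstruction back to HxW
print(f"bpp (per padded pixel): {out['bpp']:.4f}")
```

Note that `real_inference` divides the bitstream size by the padded resolution; to report bpp at the native resolution, divide the total number of bits by `H * W` of the original image instead.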