├── README.md
├── config
│   └── mt.yaml
├── data
│   └── datasets.py
├── layer
│   ├── layer_utils.py
│   └── transformer.py
├── loss
│   └── distortion.py
├── net
│   ├── elic.py
│   └── mit.py
├── results
│   └── mt_rd.png
├── train.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | This repo is a **PyTorch** implementation of [M2T: Masking Transformers Twice for Faster Decoding](https://openaccess.thecvf.com/content/ICCV2023/html/Mentzer_M2T_Masking_Transformers_Twice_for_Faster_Decoding_ICCV_2023_paper.html).
2 |
3 | ## Install
4 |
5 | The latest code is tested on Ubuntu 18.04 LTS, CUDA 11.7, PyTorch 1.9, and Python 3.7.
6 | The following libraries are required to run the code in this repo: [constriction](https://bamler-lab.github.io/constriction/), [compressai](https://github.com/InterDigitalInc/CompressAI), and [timm](https://timm.fast.ai/).
7 |
8 | ### Train
9 | ```bash
10 | python train.py --config config/mt.yaml # --wandb (if you want to use wandb)
11 | ```
12 | Model checkpoints and logs will be saved in `./history/MT`.
13 |
14 | ### Test
15 | ```bash
16 | python train.py --config config/mt.yaml --test-only --eval-dataset-path 'path_to_kodak'
17 | ```
18 | ### Performance
19 | The red dot is our reproduction with distortion lambda $\lambda=0.0035$.
20 |
21 | ![RD curve](results/mt_rd.png)
22 |
23 |
24 | ### Pretrained models
25 | To be released.
26 |
27 | ### Acknowledgements
28 | We use constriction for actual entropy coding. Thanks to Fabian Mentzer for help clarifying the details of the paper.
--------------------------------------------------------------------------------
/config/mt.yaml:
--------------------------------------------------------------------------------
1 | gpu-id: '0'
2 | multi-gpu: False
3 | test-only: True
4 |
5 | batch-size: 12
6 | eval-dataset-path: '/media/Dataset/kodak'
7 | num-workers: 12
8 | training-img-size: (384, 384)
9 |
10 | checkpoint: '/media/D/wangsixian/MT/history/MT/MT 2023-10-12 16:30:30/checkpoint_best_loss.pth.tar'
11 | init_lr: 1e-4
12 | min_lr: 1e-4
13 | max_lr: 1e-4
14 |
15 |
16 | epochs: 10
17 | warmup: False
18 | print-every: 100
19 | test-every: 2500
20 | distortion_metric: 'MSE'
21 | lambda_value: 0.0035
22 | #beta_value: 0
23 |
--------------------------------------------------------------------------------
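For orientation, the sketch below shows one way a config like `config/mt.yaml` could be loaded. This is an illustrative assumption only; the actual option parsing presumably lives in `train.py`/`utils.py` and may differ.

```python
# Minimal sketch (assumption): read config/mt.yaml with PyYAML.
import yaml

with open('config/mt.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['batch-size'])         # 12
print(cfg['lambda_value'])       # 0.0035
print(cfg['training-img-size'])  # parsed by YAML as the string '(384, 384)', not a tuple
```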
/data/datasets.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from glob import glob
4 |
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 | from torch.utils.data.dataset import Dataset
9 | from torchvision import transforms
10 |
11 | class OpenImages(Dataset):
12 | files = {"train": "train", "test": "test", "val": "validation"}
13 |
14 | def __init__(self, image_dims, data_dir):
15 | self.imgs = []
16 | for dir in data_dir:
17 | self.imgs += glob(os.path.join(dir, '*.jpg'))
18 | self.imgs += glob(os.path.join(dir, '*.png'))
19 | _, self.im_height, self.im_width = image_dims
20 | self.crop_size = self.im_height
21 | self.image_dims = (3, self.im_height, self.im_width)
22 | self.transform = self._transforms()
23 |
24 | def _transforms(self, scale=0, H=0, W=0):
25 | """
26 |         Randomly crop to `im_height` x `im_width` (the optional rescaling transforms are commented out)
27 | """
28 | transforms_list = [
29 | # transforms.ToPILImage(),
30 | # transforms.RandomHorizontalFlip(),
31 | # transforms.Resize((math.ceil(scale * H), math.ceil(scale * W))),
32 | transforms.ToTensor(),
33 | transforms.RandomCrop((self.im_height, self.im_width), pad_if_needed=True),
34 | # transforms.Resize((self.im_height, self.im_width)),
35 | ]
36 |
37 | # if self.normalize is True:
38 | # transforms_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
39 |
40 | return transforms.Compose(transforms_list)
41 |
42 | def __getitem__(self, idx):
43 | img_path = self.imgs[idx]
44 | img = Image.open(img_path)
45 | img = img.convert('RGB')
46 | transformed = self.transform(img)
47 | return transformed
48 |
49 | def __len__(self):
50 | return len(self.imgs)
51 |
52 |
53 | class Datasets(Dataset):
54 | def __init__(self, dataset_path):
55 | self.data_dir = dataset_path
56 | self.imgs = []
57 | self.imgs += glob(os.path.join(self.data_dir, '*.jpg'))
58 | self.imgs += glob(os.path.join(self.data_dir, '*.png'))
59 | self.imgs.sort()
60 | self.transform = transforms.Compose([
61 | # transforms.RandomCrop((384, 384), pad_if_needed=True),
62 | transforms.ToTensor()])
63 |
64 | def __getitem__(self, item):
65 | image_ori = self.imgs[item]
66 | image = Image.open(image_ori).convert('RGB')
67 | img = self.transform(image)
68 | return img
69 |
70 | def __len__(self):
71 | return len(self.imgs)
72 |
73 |
74 | def get_loader(train_dir, test_dir, num_workers, batch_size):
75 | train_dataset = OpenImages((3, 384, 384), train_dir)
76 | test_dataset = Datasets(test_dir)
77 |
78 | def worker_init_fn_seed(worker_id):
79 | seed = 10
80 | seed += worker_id
81 | np.random.seed(seed)
82 |
83 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
84 | num_workers=num_workers,
85 | pin_memory=True,
86 | batch_size=batch_size,
87 | worker_init_fn=worker_init_fn_seed,
88 | shuffle=True)
89 |
90 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
91 | batch_size=1,
92 | shuffle=False)
93 |
94 | return train_loader, test_loader
95 |
96 |
97 | def get_dataset(train_dir, test_dir):
98 | train_dataset = OpenImages((3, 384, 384), train_dir)
99 | test_dataset = Datasets(test_dir)
100 | return train_dataset, test_dataset
101 |
102 |
103 | def get_test_loader(test_dir):
104 | test_dataset = Datasets(test_dir)
105 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
106 | batch_size=1,
107 | shuffle=False)
108 | return test_loader
109 |
--------------------------------------------------------------------------------
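A minimal usage sketch for the loader helpers above; the directory paths are placeholders and the repo root is assumed to be on `PYTHONPATH`. Training images are random 384x384 crops, while the test loader returns full-resolution images with batch size 1.

```python
# Usage sketch for data/datasets.py (paths are placeholders).
from data.datasets import get_loader, get_test_loader

train_loader, test_loader = get_loader(
    train_dir=['/path/to/openimages'],  # OpenImages takes a list of directories
    test_dir='/path/to/kodak',          # Datasets takes a single directory
    num_workers=4,
    batch_size=12,
)

for img in get_test_loader('/path/to/kodak'):
    print(img.shape)  # torch.Size([1, 3, H, W]): full-resolution test image
    break
```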
/layer/layer_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from typing import Optional, Tuple
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch import Tensor
11 |
12 |
13 | def make_conv(
14 | in_channels: int,
15 | out_channels: int,
16 | kernel_size: int,
17 | stride: int = 1,
18 | ) -> nn.Conv2d:
19 | return nn.Conv2d(
20 | in_channels,
21 | out_channels,
22 | kernel_size=kernel_size,
23 | stride=stride,
24 | padding=kernel_size // 2,
25 | )
26 |
27 |
28 | def make_deconv(
29 | in_channels: int,
30 | out_channels: int,
31 | kernel_size: int,
32 | stride: int = 1,
33 | ) -> nn.ConvTranspose2d:
34 | return nn.ConvTranspose2d(
35 | in_channels,
36 | out_channels,
37 | kernel_size=kernel_size,
38 | stride=stride,
39 | output_padding=stride - 1,
40 | padding=kernel_size // 2,
41 | )
42 |
43 |
44 | def make_embedding(input_dim: int, hidden_dim: int) -> nn.Module:
45 | """
46 | Constructs and returns an embedding layer, which is a simple (dense) linear layer
47 |
48 | Args:
49 | input_dim: input dimensions (input to linear layer)
50 | hidden_dim: output size of the linear layer
51 |
52 | Returns:
53 | a linear nn.Module, initialized with random uniform weights and biases
54 | """
55 | scale = 1 / input_dim ** 0.5
56 | linear = torch.nn.Linear(input_dim, hidden_dim, bias=True)
57 |     # initialise weights in the same way as VCT
58 | torch.nn.init.uniform_(linear.weight, -scale, scale)
59 | torch.nn.init.uniform_(linear.bias, -scale, scale)
60 | return linear
61 |
62 |
63 | def init_weights_truncated_normal(m) -> None:
64 | """
65 | Initialise weights with truncated normal.
66 | Weights that fall outside 2 stds are resampled.
67 | See torch.nn.init.trunc_normal_ for details.
68 |
69 | Args:
70 | m: weights
71 |
72 | Examples:
73 | >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
74 | >>> net.apply(init_weights_truncated_normal)
75 | """
76 | std = 0.02
77 | if isinstance(m, nn.Linear):
78 | torch.nn.init.trunc_normal_(m.weight, std=std, a=-2 * std, b=2 * std)
79 | m.bias.data.fill_(0.01)
80 |
81 |
82 | class MLP(nn.Module):
83 | """MLP head for transformer blocks"""
84 |
85 | def __init__(
86 | self,
87 | in_features: int,
88 | mlp_dim: int,
89 | dropout: float,
90 | ) -> None:
91 | """
92 | MLP head for transformer blocks
93 | Args:
94 |             in_features: size of the input (and output) features
95 |             mlp_dim: hidden dimension of the MLP
96 |             dropout: dropout rate, applied after the activation and after
97 |                 the output projection
98 | """
99 | super().__init__()
100 |
101 | # Initialize linear layers with truncated normal
102 | self.fc1 = nn.Linear(in_features=in_features, out_features=mlp_dim)
103 | init_weights_truncated_normal(self.fc1)
104 | self.fc2 = torch.nn.Linear(in_features=mlp_dim, out_features=in_features)
105 | init_weights_truncated_normal(self.fc2)
106 |
107 | self.act = nn.GELU()
108 | self.dropout = nn.Dropout(p=dropout)
109 |
110 | def forward(self, input: Tensor) -> Tensor:
111 | """Forward pass.
112 |
113 | Args:
114 |             input: tensor of shape (batch_size, seq_len, hidden_dim)
115 |
116 | Returns:
117 | tensor of shape (batch_size, seq_len, hidden_dim)
118 | """
119 | input = self.fc1(input)
120 | input = self.act(input)
121 | input = self.dropout(input)
122 | input = self.fc2(input)
123 | return self.dropout(input)
124 |
125 |
126 | class WindowMultiHeadAttention(nn.Module):
127 | def __init__(
128 | self,
129 | hidden_dim: int,
130 | num_heads: int,
131 | attn_drop: float = 0.0,
132 | proj_drop: float = 0.0,
133 | ) -> None:
134 | """Windowed multi-head self-attention
135 |
136 | Args:
137 | hidden_dim: size of the hidden units
138 | num_heads: number of attention heads
139 | attn_drop: dropout rate of the attention layer. Defaults to 0.0.
140 | proj_drop: dropout rate of the projection layer. Defaults to 0.0.
141 |
142 | Raises:
143 | ValueError: if `hidden_dim` is not a multiple of `num_heads`
144 | """
145 | super().__init__()
146 | self.hidden_dim = hidden_dim
147 | self.num_heads = num_heads
148 | if hidden_dim % num_heads != 0:
149 | raise ValueError(
150 | f"Size of hidden units ({hidden_dim}) not divisible by number "
151 | f"of heads ({num_heads})."
152 | )
153 |         # constant to scale attention logits before the softmax
154 | self.attn_scale = (hidden_dim // num_heads) ** (-0.5)
155 | # All linear layers are initialized with truncated normal
156 | self.q_linear = torch.nn.Linear(
157 | in_features=hidden_dim, out_features=hidden_dim, bias=True
158 | )
159 | init_weights_truncated_normal(self.q_linear)
160 | self.k_linear = torch.nn.Linear(
161 | in_features=hidden_dim, out_features=hidden_dim, bias=True
162 | )
163 | init_weights_truncated_normal(self.k_linear)
164 | self.v_linear = torch.nn.Linear(
165 | in_features=hidden_dim, out_features=hidden_dim, bias=True
166 | )
167 | init_weights_truncated_normal(self.v_linear)
168 | self.attn_dropout = nn.Dropout(p=attn_drop)
169 | self.proj = torch.nn.Linear(in_features=hidden_dim, out_features=hidden_dim)
170 | init_weights_truncated_normal(self.proj)
171 | self.proj_dropout = torch.nn.Dropout(p=proj_drop)
172 | self.softmax = torch.nn.Softmax(dim=-1) # -1 is the default in tf
173 |
174 | def forward(
175 | self,
176 | query: Tensor,
177 | key: Tensor,
178 | value: Tensor,
179 | mask: Optional[torch.Tensor] = None,
180 | ) -> Tuple[Tensor, Tensor]:
181 | """Compute the windowed multi-head self-attention.
182 |
183 | Note: seq_length of keys and values (`seq_len_kv`) must be an integer
184 | multiple of the seq_length of the query (`seq_len_q`).
185 |
186 | Args:
187 | query: tensor of shape (b', seq_len_q, hidden_dim),
188 | representing the query
189 | key: tensor of shape (b', seq_len_kv, hidden_dim),
190 | representing the key
191 | value: tensor of shape (b', seq_len_kv, hidden_dim),
192 | representing the value
193 | mask: optional tensor of shape (b', seq_len_q, seq_len_q),
194 | representing the mask to apply to the attention; 1s will be masked
195 | -> b' is an augmented batch size that includes the total number of patches
196 | -> by default, hidden_dim is 768
197 |
198 | Returns:
199 |             tensor of features of shape (b', seq_len_q, hidden_dim), as well as the
200 | attention matrix used, shape (b', num_heads, seq_len_q, seq_len_kv).
201 | """
202 | *b, seq_len_q, c = query.shape
203 | assert c == self.hidden_dim, f"Shape mismatch, {c} != {self.hidden_dim}"
204 | seq_len_kv = value.shape[-2]
205 | assert seq_len_kv % seq_len_q == 0, (
206 | f"seq_length of keys and values = {seq_len_kv} must be an integer multiple "
207 | f"of the seq_length of the query = {seq_len_q}"
208 | )
209 | blowup = seq_len_kv // seq_len_q
210 |
211 | query = self.q_linear(query) # [b', seq_len_q, hidden_dim]
212 | key = self.k_linear(key) # [b', seq_len_kv, hidden_dim]
213 | value = self.v_linear(value) # [b', seq_len_kv, hidden_dim]
214 |
215 | # reshape by splitting channels into num_heads, then permute:
216 | query = (
217 | query.reshape(*b, seq_len_q, self.num_heads, c // self.num_heads)
218 | .permute(0, 2, 1, 3)
219 | .contiguous()
220 | ) # [b', num_heads, seq_len_q, c // num_heads]
221 | key = (
222 | key.reshape(*b, seq_len_kv, self.num_heads, c // self.num_heads)
223 | .permute(0, 2, 1, 3)
224 | .contiguous()
225 | ) # [b', num_heads, seq_len_kv, c // num_heads]
226 | value = (
227 | value.reshape(*b, seq_len_kv, self.num_heads, c // self.num_heads)
228 | .permute(0, 2, 1, 3)
229 | .contiguous()
230 | ) # [b', num_heads, seq_len_kv, c // num_heads]
231 |
232 | # b', num_heads, seq_len_q, seq_len_kv
233 | attn = (
234 | torch.matmul(
235 | query, # [b', num_heads, seq_len_q, c // num_heads]
236 | key.transpose(2, 3), # [b', num_heads, c // num_heads, seq_len_kv]
237 | ) # [b', num_heads,seq_len_q, seq_len_kv]
238 | * self.attn_scale
239 | )
240 |
241 | if mask is not None:
242 | if mask.shape[-2:] != (seq_len_q, seq_len_q):
243 | # Note that we only mask for self-attention in the decoder,
244 | # where the attention matrix has shape (..., seq_len_q, seq_len_q).
245 | raise ValueError(f"Invalid mask shape: {mask.shape}.")
246 |
247 | # Here, we add the mask to the attention with a large negative multiplier,
248 | # as this goes into a softmax it will become 0
249 | tile_pattern = [1] * mask.dim()
250 | tile_pattern[-1] = blowup
251 | attn = attn + torch.tile(mask, tile_pattern) * -1e6
252 | else:
253 | tile_pattern = None
254 |
255 | attn = self.softmax(attn) # [b', num_heads, seq_length_q, seq_length_kv]
256 |
257 | if mask is not None and tile_pattern is not None:
258 |             # We use the mask again to ensure that no masked dimension
259 | # affects the output.
260 | keep = 1 - mask
261 | attn = attn * torch.tile(keep, tile_pattern)
262 |
263 | attn = self.attn_dropout(attn) # [b', num_heads, seq_length_q, seq_length_kv]
264 |
265 | features = torch.matmul(
266 | attn, value
267 | ) # [b', num_heads, seq_len_q, d_model//num_heads], c=d_model
268 | assert features.shape == (*b, self.num_heads, seq_len_q, c // self.num_heads)
269 |
270 | features = (
271 | features.permute(0, 2, 1, 3) # switch num_heads, seq_len_q dimensions.
272 | .contiguous()
273 | .reshape(*b, -1, self.hidden_dim) # flatten num_heads, seq_len_q dimensions
274 | )
275 | features = self.proj(features)
276 | features = self.proj_dropout(features)
277 | assert features.shape == (*b, seq_len_q, c)
278 | return features, attn
279 |
280 |
281 | class StochasticDepth(nn.Module):
282 | """Creates a stochastic depth layer."""
283 |
284 | def __init__(self, stochastic_depth_drop_rate: float) -> None:
285 | """Initializes a stochastic depth layer.
286 |
287 | Args:
288 | stochastic_depth_drop_rate: A `float` of drop rate.
289 |
290 | Returns:
291 | a tensor of the same shape as input.
292 | """
293 | super().__init__()
294 | self._drop_rate = stochastic_depth_drop_rate
295 |
296 | def forward(self, input: Tensor) -> Tensor:
297 | if not self.training or self._drop_rate == 0.0:
298 | return input
299 |
300 | keep_prob = 1.0 - self._drop_rate
301 | batch_size = input.shape[0]
302 | random_tensor = keep_prob
303 | random_tensor = random_tensor + torch.rand(
304 |             [batch_size] + [1] * (input.dim() - 1), dtype=input.dtype, device=input.device
305 | )
306 | binary_tensor = torch.floor(random_tensor)
307 | return input / keep_prob * binary_tensor
308 |
309 |
310 | def quantize_ste(x):
311 | """Differentiable quantization via the Straight-Through-Estimator."""
312 | # STE (straight-through estimator) trick: x_hard - x_soft.detach() + x_soft
313 | return (torch.round(x) - x).detach() + x
314 |
--------------------------------------------------------------------------------
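A toy example exercising `WindowMultiHeadAttention` and `quantize_ste` from the file above; the shapes are chosen purely for illustration (the attention expects `seq_len_kv` to be an integer multiple of `seq_len_q`, and `hidden_dim` must be divisible by `num_heads`).

```python
# Toy shapes to exercise WindowMultiHeadAttention and quantize_ste.
import torch
from layer.layer_utils import WindowMultiHeadAttention, quantize_ste

attn_layer = WindowMultiHeadAttention(hidden_dim=768, num_heads=12)
q = torch.randn(2, 16, 768)   # (b', seq_len_q, hidden_dim)
kv = torch.randn(2, 32, 768)  # seq_len_kv is an integer multiple of seq_len_q
features, attn = attn_layer(q, kv, kv)
print(features.shape)  # torch.Size([2, 16, 768])
print(attn.shape)      # torch.Size([2, 12, 16, 32])

y = torch.randn(4) * 3.0
y_hat = quantize_ste(y)  # rounded in the forward pass, identity gradient in the backward pass
```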
/layer/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5 |
6 |
7 | class Mlp(nn.Module):
8 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
9 | super().__init__()
10 | out_features = out_features or in_features
11 | hidden_features = hidden_features or in_features
12 | self.fc1 = nn.Linear(in_features, hidden_features)
13 | self.act = act_layer()
14 | self.fc2 = nn.Linear(hidden_features, out_features)
15 | self.drop = nn.Dropout(drop)
16 |
17 | def forward(self, x):
18 | x = self.fc1(x)
19 | x = self.act(x)
20 | x = self.drop(x)
21 | x = self.fc2(x)
22 | x = self.drop(x)
23 | return x
24 |
25 |
26 | def window_partition(x, window_size):
27 | """
28 | Args:
29 | x: (B, H, W, C)
30 | window_size (int): window size
31 | Returns:
32 | windows: (num_windows*B, window_size, window_size, C)
33 | """
34 | B, H, W, C = x.shape
35 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
36 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37 | return windows
38 |
39 |
40 | def window_reverse(windows, window_size, H, W):
41 | """
42 | Args:
43 | windows: (num_windows*B, window_size, window_size, C)
44 | window_size (int): Window size
45 | H (int): Height of image
46 | W (int): Width of image
47 | Returns:
48 | x: (B, H, W, C)
49 | """
50 | # print("windows.shape[0]", windows.shape[0], H * W)
51 | B = int(windows.shape[0] / (H * W / window_size / window_size))
52 | # print(B)
53 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
54 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
55 | return x
56 |
57 |
58 | class WindowAttention(nn.Module):
59 | r""" Window based multi-head self attention (W-MSA) module with relative position bias.
60 | It supports both of shifted and non-shifted window.
61 | Args:
62 | dim (int): Number of input channels.
63 | window_size (tuple[int]): The height and width of the window.
64 | num_heads (int): Number of attention heads.
65 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
66 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
67 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
68 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0
69 | """
70 |
71 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
72 |
73 | super().__init__()
74 | self.dim = dim
75 | self.window_size = window_size # Wh, Ww
76 | self.num_heads = num_heads
77 | head_dim = dim // num_heads
78 | self.scale = qk_scale or head_dim ** -0.5
79 |
80 | # define a parameter table of relative position bias
81 | self.relative_position_bias_table = nn.Parameter(
82 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
83 |
84 | # get pair-wise relative position index for each token inside the window
85 | coords_h = torch.arange(self.window_size[0])
86 | coords_w = torch.arange(self.window_size[1])
87 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
88 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
89 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
90 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
91 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
92 | relative_coords[:, :, 1] += self.window_size[1] - 1
93 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
94 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
95 | self.register_buffer("relative_position_index", relative_position_index)
96 |
97 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
98 | self.attn_drop = nn.Dropout(attn_drop)
99 | self.proj = nn.Linear(dim, dim)
100 | self.proj_drop = nn.Dropout(proj_drop)
101 |
102 | trunc_normal_(self.relative_position_bias_table, std=.02)
103 | self.softmax = nn.Softmax(dim=-1)
104 |
105 | def forward(self, x, mask=None):
106 | """
107 | Args:
108 | x: input features with shape of (num_windows*B, N, C)
109 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
110 | """
111 | B_, N, C = x.shape
112 |
113 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
114 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
115 |
116 | q = q * self.scale
117 |         attn = (q @ k.transpose(-2, -1))  # (B_, num_heads, N, N)
118 |
119 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
120 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
121 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
122 |
123 | attn = attn + relative_position_bias.unsqueeze(0)
124 |
125 | if mask is not None:
126 |             mask = mask.to(attn.device)
127 | nW = mask.shape[0]
128 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
129 | attn = attn.view(-1, self.num_heads, N, N)
130 | attn = self.softmax(attn)
131 | else:
132 | attn = self.softmax(attn)
133 |
134 | attn = self.attn_drop(attn)
135 |
136 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
137 | x = self.proj(x)
138 | x = self.proj_drop(x)
139 | return x
140 |
141 | def extra_repr(self) -> str:
142 | return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
143 |
144 | def flops(self, N):
145 | # calculate flops for 1 window with token length of N
146 | flops = 0
147 | # qkv = self.qkv(x)
148 | flops += N * self.dim * 3 * self.dim
149 | # attn = (q @ k.transpose(-2, -1))
150 | flops += self.num_heads * N * (self.dim // self.num_heads) * N
151 | # x = (attn @ v)
152 | flops += self.num_heads * N * N * (self.dim // self.num_heads)
153 | # x = self.proj(x)
154 | flops += N * self.dim * self.dim
155 | return flops
156 |
157 |
158 | class TransformerBlock(nn.Module):
159 | r""" Transformer Block.
160 | Args:
161 | dim (int): Number of input channels.
162 | num_heads (int): Number of attention heads.
163 | window_size (int): Window size.
164 | shift_size (int): Shift size for SW-MSA.
165 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
166 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
167 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
168 | drop (float, optional): Dropout rate. Default: 0.0
169 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
170 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
171 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
172 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
173 | """
174 |
175 | def __init__(self, dim, num_heads, window_size=7,
176 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
177 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
178 | super().__init__()
179 | self.dim = dim
180 | self.num_heads = num_heads
181 | self.window_size = window_size
182 | self.mlp_ratio = mlp_ratio
183 | self.norm1 = norm_layer(dim)
184 | self.attn = WindowAttention(
185 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
186 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
187 |
188 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
189 | self.norm2 = norm_layer(dim)
190 | mlp_hidden_dim = int(dim * mlp_ratio)
191 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
192 |
193 | def forward(self, x, input_resolution):
194 | H0, W0 = input_resolution
195 | B, L, C = x.shape
196 | shortcut = x
197 | x = self.norm1(x)
198 | if H0 % self.window_size != 0 or W0 % self.window_size != 0:
199 | x = x.view(B, H0, W0, C).permute(0, 3, 1, 2)
200 |             # zero-pad feature maps to multiples of the window size
201 | pad_l = pad_t = 0
202 | pad_b = (self.window_size - H0 % self.window_size) % self.window_size
203 | pad_r = (self.window_size - W0 % self.window_size) % self.window_size
204 | x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
205 | # print("padding", pad_l, pad_r, pad_t, pad_b)
206 | x = x.permute(0, 2, 3, 1)
207 | H = H0 + pad_b
208 | W = W0 + pad_r
209 | else:
210 | H = H0
211 | W = W0
212 | x = x.view(B, H0, W0, C)
213 |
214 | # partition windows
215 | x_windows = window_partition(x, self.window_size) # nW*B, window_size, window_size, C
216 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
217 | B_, N, C = x_windows.shape
218 |
219 | attn_windows = self.attn(x_windows, mask=None) # nW*B, window_size*window_size, C
220 |
221 | # merge windows
222 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
223 | x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
224 |
225 | # remove paddings
226 | if L != H * W:
227 | x = x[:, :H0, :W0, :].contiguous()
228 |
229 | x = x.view(B, H0 * W0, C)
230 | # FFN
231 | x = shortcut + self.drop_path(x)
232 | x = x + self.drop_path(self.mlp(self.norm2(x)))
233 |
234 | return x
235 |
236 | def extra_repr(self) -> str:
237 |         return f"dim={self.dim}, num_heads={self.num_heads}, " \
238 |                f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
239 |
240 |
241 | class LearnedPosition(nn.Module):
242 | """
243 |     Learned positional encoding
244 | """
245 |
246 | def __init__(self, seq_length: int, hidden_dim: int) -> None:
247 | """
248 | Learned positional encoding
249 |
250 | Args:
251 | seq_length: sequence length.
252 | hidden_dim: hidden (model) dimension
253 | """
254 | super().__init__()
255 | self._emb = torch.nn.parameter.Parameter(
256 | torch.empty(1, seq_length, hidden_dim).normal_(std=0.02)
257 | ) # [1, seq_length, hidden_dim]
258 | self._seq_len = seq_length
259 | self._hidden_dim = hidden_dim
260 |
261 | def forward(self, x):
262 | """Adds positional encodings to an input
263 |
264 | Args:
265 | x: tensor to which the positional encodings will be added,
266 | expected shape is [B, seq_len, hidden_dim]
267 |
268 | Returns:
269 | the input tensor with the positional encodings added.
270 | """
271 | assert x.dim() == 3 and x.shape[1:] == torch.Size(
272 | [self._seq_len, self._hidden_dim]
273 | ), f"Expected [B, seq_length, hidden_dim] got {x.shape}"
274 | return x + self._emb
275 |
276 |
277 | class SwinTransformerBlock(nn.Module):
278 | r""" Swin Transformer Block.
279 | Args:
280 | dim (int): Number of input channels.
281 | num_heads (int): Number of attention heads.
282 | window_size (int): Window size.
283 | shift_size (int): Shift size for SW-MSA.
284 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
285 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
286 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
287 | drop (float, optional): Dropout rate. Default: 0.0
288 | attn_drop (float, optional): Attention dropout rate. Default: 0.0
289 | drop_path (float, optional): Stochastic depth rate. Default: 0.0
290 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
291 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
292 | """
293 |
294 | def __init__(self, dim, num_heads, window_size=7, shift_size=0,
295 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
296 | act_layer=nn.GELU, norm_layer=nn.LayerNorm):
297 | super().__init__()
298 | self.dim = dim
299 | self.input_resolution = (24, 24)
300 | self.num_heads = num_heads
301 | self.window_size = window_size
302 | self.shift_size = shift_size
303 | self.mlp_ratio = mlp_ratio
304 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
305 |
306 | self.norm1 = norm_layer(dim)
307 | self.attn = WindowAttention(
308 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
309 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
310 |
311 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
312 | self.norm2 = norm_layer(dim)
313 | mlp_hidden_dim = int(dim * mlp_ratio)
314 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
315 |
316 | if self.shift_size > 0:
317 | # calculate attention mask for SW-MSA
318 | H, W = self.input_resolution
319 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
320 | h_slices = (slice(0, -self.window_size),
321 | slice(-self.window_size, -self.shift_size),
322 | slice(-self.shift_size, None))
323 | w_slices = (slice(0, -self.window_size),
324 | slice(-self.window_size, -self.shift_size),
325 | slice(-self.shift_size, None))
326 | cnt = 0
327 | for h in h_slices:
328 | for w in w_slices:
329 | img_mask[:, h, w, :] = cnt
330 | cnt += 1
331 |
332 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
333 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
334 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
335 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
336 | else:
337 | attn_mask = None
338 |
339 | self.register_buffer("attn_mask", attn_mask)
340 |
341 | def forward(self, x, input_resolution, slice_size=None):
342 | H0, W0 = input_resolution
343 |         device = x.device
344 | if input_resolution != self.input_resolution:
345 | self.update_mask(input_resolution, device)
346 | if slice_size is not None:
347 | inter_slice_mask = self.get_inter_slice_mask(input_resolution, slice_size, device)
348 | B, L, C = x.shape
349 | # assert L == H * W, "input feature has wrong size, input size{}x{}x{}, H:{}, W:{}".format(B, L, C, H, W)
350 | shortcut = x
351 | x = self.norm1(x)
352 | if H0 % self.window_size != 0 or W0 % self.window_size != 0:
353 | x = x.view(B, H0, W0, C).permute(0, 3, 1, 2)
354 |             # zero-pad feature maps to multiples of the window size
355 | pad_l = pad_t = 0
356 | pad_b = (self.window_size - H0 % self.window_size) % self.window_size
357 | pad_r = (self.window_size - W0 % self.window_size) % self.window_size
358 | x = F.pad(x, (pad_l, pad_r, pad_t, pad_b))
359 | x = x.permute(0, 2, 3, 1)
360 | H = H0 + pad_b
361 | W = W0 + pad_r
362 | else:
363 | H = H0
364 | W = W0
365 | x = x.view(B, H0, W0, C)
366 |
367 | # cyclic shift
368 | if self.shift_size > 0:
369 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
370 | if slice_size is None:
371 | mask = self.attn_mask
372 | else:
373 | mask = (self.attn_mask + inter_slice_mask).clip(-100.0, 0.0)
374 | else:
375 | shifted_x = x
376 | assert (slice_size is None) or (slice_size[0] % self.window_size == 0)
377 | mask = None
378 |
379 | # partition windows
380 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
381 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
382 | B_, N, C = x_windows.shape
383 |
384 | attn_windows = self.attn(x_windows,
385 | mask=mask) # nW*B, window_size*window_size, C
386 |
387 | # merge windows
388 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
389 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
390 |
391 | # reverse cyclic shift
392 | if self.shift_size > 0:
393 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
394 | else:
395 | x = shifted_x
396 |
397 | # remove paddings
398 | if L != H * W:
399 | x = x[:, :H0, :W0, :].contiguous()
400 |
401 |         x = x.view(B, H0 * W0, C)
403 |
404 | # FFN
405 | x = shortcut + self.drop_path(x)
406 | x = x + self.drop_path(self.mlp(self.norm2(x)))
407 |
408 | return x
409 |
410 | def extra_repr(self) -> str:
411 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
412 | f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
413 |
414 | def flops(self):
415 | flops = 0
416 | H, W = self.input_resolution
417 | # norm1
418 | flops += self.dim * H * W
419 | # W-MSA/SW-MSA
420 | nW = H * W / self.window_size / self.window_size
421 | flops += nW * self.attn.flops(self.window_size * self.window_size)
422 | # mlp
423 | flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
424 | # norm2
425 | flops += self.dim * H * W
426 | return flops
427 |
428 | def update_mask(self, input_resolution, device):
429 | self.input_resolution = input_resolution
430 | if self.shift_size > 0:
431 | # calculate attention mask for SW-MSA
432 | H, W = input_resolution
433 | H = H + (self.window_size - H % self.window_size) % self.window_size
434 | W = W + (self.window_size - W % self.window_size) % self.window_size
435 |
436 | img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
437 | h_slices = (slice(0, -self.window_size),
438 | slice(-self.window_size, -self.shift_size),
439 | slice(-self.shift_size, None))
440 | w_slices = (slice(0, -self.window_size),
441 | slice(-self.window_size, -self.shift_size),
442 | slice(-self.shift_size, None))
443 | cnt = 0
444 | for h in h_slices:
445 | for w in w_slices:
446 | img_mask[:, h, w, :] = cnt
447 | cnt += 1
448 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
449 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
450 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
451 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
452 | self.attn_mask = attn_mask.to(device)
453 | # print(attn_mask[0])
454 | # print(attn_mask[1])
455 | # print(attn_mask.shape)
456 | else:
457 | pass
458 |
459 | def get_inter_slice_mask(self, input_resolution, slice_size, device):
460 | H, W = input_resolution
461 | Hs, Ws = slice_size
462 | nH = H // Hs
463 | nW = W // Ws
464 | H = H + (self.window_size - H % self.window_size) % self.window_size
465 | W = W + (self.window_size - W % self.window_size) % self.window_size
466 | img_mask = torch.arange(nH * nW).reshape(nH, nW).repeat_interleave(Hs, dim=0).repeat_interleave(Ws, dim=1)
467 | img_mask = img_mask.reshape(1, H, W, 1)
468 | if self.shift_size > 0:
469 | img_mask = torch.roll(img_mask, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
470 | img_mask = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
471 | img_mask = img_mask.view(-1, self.window_size * self.window_size)
472 | attn_mask = img_mask.unsqueeze(1) - img_mask.unsqueeze(2)
473 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
474 | inter_slice_attn_mask = attn_mask.to(device)
475 | return inter_slice_attn_mask
476 |
477 |
478 | if __name__ == '__main__':
479 | layer = SwinTransformerBlock(dim=96, num_heads=3, window_size=4, shift_size=2).cuda()
480 | x = torch.randn(1, 64, 96).cuda()
481 | layer(x, (8, 8))
482 |     layer.get_inter_slice_mask((8, 8), (2, 2), x.device)
483 |
--------------------------------------------------------------------------------
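A small sanity check, for illustration only, that `window_partition` and `window_reverse` from the file above are inverses when the spatial dimensions are multiples of the window size:

```python
# window_partition / window_reverse round-trip on a toy tensor (no padding involved).
import torch
from layer.transformer import window_partition, window_reverse

x = torch.randn(2, 8, 8, 96)              # (B, H, W, C)
windows = window_partition(x, 4)          # (2 * 2 * 2, 4, 4, 96) = (8, 4, 4, 96)
x_back = window_reverse(windows, 4, 8, 8)
print(torch.equal(x, x_back))             # True
```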
/loss/distortion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from pytorch_msssim import ms_ssim
4 |
5 |
6 | class Distortion(torch.nn.Module):
7 | def __init__(self, distortion_type):
8 | super(Distortion, self).__init__()
9 | if distortion_type == 'MSE':
10 | self.metric = nn.MSELoss()
11 | elif distortion_type == 'MS-SSIM':
12 | self.metric = ms_ssim
13 | else:
14 |             raise ValueError(f"Unknown distortion type: {distortion_type}")
16 |
17 | def forward(self, X, Y):
18 | if self.metric == ms_ssim:
19 | return 1 - self.metric(X, Y, data_range=1)
20 | else:
21 | return self.metric(X, Y)
--------------------------------------------------------------------------------
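A short usage sketch of the `Distortion` module above; note that for `'MS-SSIM'` it returns `1 - MS-SSIM` so that lower is always better.

```python
# Usage sketch for loss/distortion.py.
import torch
from loss.distortion import Distortion

x = torch.rand(1, 3, 256, 256)
x_hat = torch.rand(1, 3, 256, 256)

mse = Distortion('MSE')(x_hat, x)          # plain MSE
msssim = Distortion('MS-SSIM')(x_hat, x)   # 1 - MS-SSIM (lower is better)
print(mse.item(), msssim.item())
```

In the rate-distortion objective, this distortion term is presumably weighted by `lambda_value` from `config/mt.yaml` and combined with the estimated rate in `train.py` (not shown here).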
/net/elic.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the MIT license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | ## This module contains modules implementing standard synthesis and analysis transforms
7 |
8 | from typing import List, Optional
9 |
10 | import torch
11 | import torch.nn as nn
12 | from compressai.layers import GDN, AttentionBlock
13 | from torch import Tensor
14 |
15 | from layer.layer_utils import make_conv, make_deconv
16 |
17 |
18 | class ConvGDNAnalysis(nn.Module):
19 | def __init__(
20 | self, network_channels: int = 128, compression_channels: int = 192
21 | ) -> None:
22 | """
23 |         Analysis transform from scale hyperprior (https://arxiv.org/abs/1802.01436)
24 |
25 | Encodes each image in a video independently.
26 | """
27 | super().__init__()
28 | self._compression_channels = compression_channels
29 |
30 | self.transforms = nn.Sequential(
31 | make_conv(3, network_channels, kernel_size=5, stride=2),
32 | GDN(network_channels),
33 | make_conv(network_channels, network_channels, kernel_size=5, stride=2),
34 | GDN(network_channels),
35 | make_conv(network_channels, network_channels, kernel_size=5, stride=2),
36 | GDN(network_channels),
37 | make_conv(network_channels, compression_channels, kernel_size=5, stride=2),
38 | )
39 | self._num_down = 4
40 |
41 | @property
42 | def compression_channels(self) -> int:
43 | return self._compression_channels
44 |
45 | @property
46 | def num_downsampling_layers(self) -> int:
47 | return self._num_down
48 |
49 | def forward(self, video_frames: Tensor) -> Tensor:
50 | """
51 | Args:
52 | video_frames: frames of a batch of clips. Expected shape [B, T, C, H, W],
53 |                 which is reshaped to [BxT, C, H, W]; the analysis transform is applied
54 |                 and the output is reshaped back to [B, T, C', h, w].
55 |         Returns:
56 |             embeddings: embeddings of shape [B, T, C', h, w], where C' is the
57 |             number of compression channels.
58 | """
59 | assert (
60 | video_frames.dim() == 5
61 | ), f"Expected [B, T, C, H, W] got {video_frames.shape}"
62 | embeddings = self.transforms(video_frames.reshape(-1, *video_frames.shape[2:]))
63 | return embeddings.reshape(*video_frames.shape[:2], *embeddings.shape[1:])
64 |
65 |
66 | class ConvGDNSynthesis(nn.Module):
67 | def __init__(
68 | self, network_channels: int = 128, compression_channels: int = 192
69 | ) -> None:
70 | """
71 |         Synthesis transform from scale hyperprior (https://arxiv.org/abs/1802.01436)
72 |
73 | Decodes each image in a video independently
74 | """
75 |
76 | super().__init__()
77 |         self._compression_channels = compression_channels
78 |
79 | self.transforms = nn.Sequential(
80 | make_deconv(
81 | compression_channels, network_channels, kernel_size=5, stride=2
82 | ),
83 | GDN(network_channels, inverse=True),
84 | make_deconv(network_channels, network_channels, kernel_size=5, stride=2),
85 | GDN(network_channels, inverse=True),
86 | make_deconv(network_channels, network_channels, kernel_size=5, stride=2),
87 | GDN(network_channels, inverse=True),
88 | make_deconv(network_channels, 3, kernel_size=5, stride=2),
89 | )
90 |
91 | @property
92 | def compression_channels(self) -> int:
93 | return self._compression_channels
94 |
95 | def forward(self, x: Tensor, frames_shape: torch.Size) -> Tensor:
96 | """
97 | Args:
98 |             x: the (reconstructed) latent embeddings to be decoded to images,
99 | expected shape [B, T, C, H, W]
100 | frames_shape: shape of the video clip to be reconstructed.
101 | Returns:
102 | reconstruction: reconstruction of the original video clip with shape
103 | [B, T, C, H, W] = frames_shape.
104 | """
105 | assert x.dim() == 5, f"Expected [B, T, C, H, W] got {x.shape}"
106 | # Treat T as part of the Batch dimension, storing values to reshape back
107 | B, T, *_ = x.shape
108 | x = x.reshape(-1, *x.shape[2:])
109 | assert len(frames_shape) == 5
110 |
111 | x = self.transforms(x) # final reconstruction
112 | x = x.reshape(B, T, *x.shape[1:])
113 | return x[..., : frames_shape[-2], : frames_shape[-1]]
114 |
115 |
116 | class ResidualUnit(nn.Module):
117 | """Simple residual unit"""
118 |
119 | def __init__(self, N: int) -> None:
120 | super().__init__()
121 | self.conv = nn.Sequential(
122 | make_conv(N, N // 2, kernel_size=1),
123 | nn.ReLU(inplace=True),
124 | make_conv(N // 2, N // 2, kernel_size=3),
125 | nn.ReLU(inplace=True),
126 | make_conv(N // 2, N, kernel_size=1),
127 | )
128 | self.activation = nn.ReLU(inplace=True)
129 |
130 | def forward(self, x: Tensor) -> Tensor:
131 | identity = x
132 | out = self.conv(x)
133 | out += identity
134 | out = self.activation(out)
135 | return out
136 |
137 |
138 | def sinusoidal_embedding(values: torch.Tensor, dim=256, max_period=64):
139 | assert values.dim() == 1 and (dim % 2) == 0
140 | exponents = torch.linspace(0, 1, steps=(dim // 2))
141 | freqs = torch.pow(max_period, -1.0 * exponents).to(device=values.device)
142 | args = values.view(-1, 1) * freqs.view(1, dim // 2)
143 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
144 | return embedding
145 |
146 |
147 | class AdaLN(nn.Module):
148 | default_embedding_dim = 256
149 |
150 | def __init__(self, dim, embed_dim=None):
151 | super().__init__()
152 | embed_dim = embed_dim or self.default_embedding_dim
153 | self.embedding_layer = nn.Sequential(
154 | nn.GELU(),
155 | nn.Linear(embed_dim, 2 * dim),
156 | nn.Unflatten(1, unflattened_size=(1, 1, 2 * dim))
157 | )
158 |
159 | def forward(self, x, emb):
160 | # layer norm
161 | x = x.permute(0, 2, 3, 1).contiguous()
162 | # x = self.norm(x)
163 | # AdaLN
164 | embedding = self.embedding_layer(emb)
165 | shift, scale = torch.chunk(embedding, chunks=2, dim=-1)
166 | x = x * (1 + scale) + shift
167 | x = x.permute(0, 3, 1, 2).contiguous()
168 | return x
169 |
170 |
171 | class ConditionalResidualUnit(nn.Module):
172 | def __init__(self, N: int) -> None:
173 | super().__init__()
174 | self.pre_conv = make_conv(N, N // 2, kernel_size=1)
175 | self.adaLN = AdaLN(N // 2)
176 | self.post_conv = nn.Sequential(
177 | nn.ReLU(inplace=True),
178 | make_conv(N // 2, N // 2, kernel_size=3),
179 | # nn.ReLU(inplace=True),
180 | make_conv(N // 2, N, kernel_size=1),
181 | )
182 | self.activation = nn.ReLU(inplace=True)
183 |
184 | def forward(self, x: Tensor, emb: Tensor) -> Tensor:
185 | identity = x
186 | out = self.pre_conv(x)
187 | out = self.adaLN(out, emb)
188 | out = self.post_conv(out)
189 | out += identity
190 | out = self.activation(out)
191 | return out
192 |
193 |
194 | class ELICAnalysis(nn.Module):
195 | def __init__(
196 | self,
197 | num_residual_blocks=3,
198 | channels: List[int] = [256, 256, 256, 192],
199 | compression_channels: Optional[int] = None,
200 | max_frames: Optional[int] = None,
201 | ) -> None:
202 | """Analysis transform from ELIC (https://arxiv.org/abs/2203.10886), which
203 | can be configured to match the one from "Devil's in the Details"
204 | (https://arxiv.org/abs/2203.08450).
205 |
206 | Args:
207 | num_residual_blocks: defaults to 3.
208 |             channels: defaults to [256, 256, 256, 192].
209 | compression_channels: optional, defaults to None. If provided, it must equal
210 | the last element of `channels`.
211 | max_frames: optional, defaults to None. If provided, the input is chunked
212 | into max_frames elements, otherwise the entire batch is processed at
213 | once. This is useful when large sequences are to be processed and can
214 | be used to manage memory a bit better.
215 | """
216 | super().__init__()
217 | if len(channels) != 4:
218 | raise ValueError(f"ELIC uses 4 conv layers (not {len(channels)}).")
219 | if compression_channels is not None and compression_channels != channels[-1]:
220 | raise ValueError(
221 | "output_channels specified but does not match channels: "
222 | f"{compression_channels} vs. {channels}"
223 | )
224 | self._compression_channels = (
225 | compression_channels if compression_channels is not None else channels[-1]
226 | )
227 | self._max_frames = max_frames
228 |
229 | def res_units(N):
230 | return [ResidualUnit(N) for _ in range(num_residual_blocks)]
231 |
232 | channels = [3] + channels
233 |
234 | self.transforms = nn.Sequential(
235 | make_conv(channels[0], channels[1], kernel_size=5, stride=2),
236 | *res_units(channels[1]),
237 | make_conv(channels[1], channels[2], kernel_size=5, stride=2),
238 | *res_units(channels[2]),
239 | AttentionBlock(channels[2]),
240 | make_conv(channels[2], channels[3], kernel_size=5, stride=2),
241 | *res_units(channels[3]),
242 | make_conv(channels[3], channels[4], kernel_size=5, stride=2),
243 | AttentionBlock(channels[4]),
244 | )
245 |
246 | @property
247 | def compression_channels(self) -> int:
248 | return self._compression_channels
249 |
250 | def forward(self, x: Tensor) -> Tensor:
251 | x = self.transforms(x)
252 | return x
253 |
254 |
255 | class ELICSynthesis(nn.Module):
256 | def __init__(
257 | self,
258 | num_residual_blocks=3,
259 | channels: List[int] = [192, 256, 256, 3],
260 | output_channels: Optional[int] = None
261 | ) -> None:
262 | """
263 | Synthesis transform from ELIC (https://arxiv.org/abs/2203.10886).
264 |
265 | Args:
266 | num_residual_blocks: defaults to 3.
267 |             channels: defaults to [192, 256, 256, 3].
268 | output_channels: optional, defaults to None. If provided, it must equal
269 | the last element of `channels`.
270 | """
271 | super().__init__()
272 | if len(channels) != 4:
273 | raise ValueError(f"ELIC uses 4 conv layers (not {channels}).")
274 | if output_channels is not None and output_channels != channels[-1]:
275 | raise ValueError(
276 | "output_channels specified but does not match channels: "
277 | f"{output_channels} vs. {channels}"
278 | )
279 |
280 | self._compression_channels = channels[0]
281 |
282 | def res_units(N: int) -> List:
283 | return [ResidualUnit(N) for _ in range(num_residual_blocks)]
284 |
285 | channels = [channels[0]] + channels
286 | self.transforms = nn.Sequential(
287 | AttentionBlock(channels[0]),
288 | make_deconv(channels[0], out_channels=channels[1], kernel_size=5, stride=2),
289 | *res_units(channels[1]),
290 | make_deconv(channels[1], out_channels=channels[2], kernel_size=5, stride=2),
291 | AttentionBlock(channels[2]),
292 | *res_units(channels[2]),
293 | make_deconv(channels[2], out_channels=channels[3], kernel_size=5, stride=2),
294 | *res_units(channels[3]),
295 | make_deconv(channels[3], out_channels=channels[4], kernel_size=5, stride=2),
296 | )
297 |
298 | @property
299 | def compression_channels(self) -> int:
300 | return self._compression_channels
301 |
302 | def forward(self, x: Tensor) -> Tensor:
303 | x = self.transforms(x)
304 | return x
305 |
306 |
307 | class ConditionalAttentionBlock(nn.Module):
308 | def __init__(self, N: int):
309 | super().__init__()
310 | self.conv_a = nn.Sequential(ConditionalResidualUnit(N),
311 | ConditionalResidualUnit(N),
312 | ConditionalResidualUnit(N))
313 |
314 | self.conv_b1 = nn.Sequential(
315 | ConditionalResidualUnit(N),
316 | ConditionalResidualUnit(N),
317 | ConditionalResidualUnit(N),
318 | )
319 | self.conv_b2 = make_conv(N, N, kernel_size=1)
320 |
321 | def forward(self, x: Tensor, emb: Tensor) -> Tensor:
322 | identity = x
323 | a = x
324 | for conv_a in self.conv_a:
325 | a = conv_a(a, emb)
326 | for conv_b1 in self.conv_b1:
327 | x = conv_b1(x, emb)
328 | b = self.conv_b2(x)
329 | out = a * torch.sigmoid(b)
330 | out += identity
331 | return out
332 |
333 |
334 | class AdaptiveELICSynthesis(nn.Module):
335 | def __init__(
336 | self,
337 | num_residual_blocks=3,
338 | channels: List[int] = [192, 256, 256, 3],
339 | output_channels: Optional[int] = None
340 | ) -> None:
341 | super().__init__()
342 | if len(channels) != 4:
343 | raise ValueError(f"ELIC uses 4 conv layers (not {channels}).")
344 | if output_channels is not None and output_channels != channels[-1]:
345 | raise ValueError(
346 | "output_channels specified but does not match channels: "
347 | f"{output_channels} vs. {channels}"
348 | )
349 |
350 | self._compression_channels = channels[0]
351 |
352 | def cond_res_units(N: int) -> List:
353 | return [ConditionalResidualUnit(N) for _ in range(num_residual_blocks)]
354 |
355 | channels = [channels[0]] + channels
356 | self.transforms_1 = ConditionalAttentionBlock(channels[0])
357 | self.deconv_1 = make_deconv(channels[0], out_channels=channels[1], kernel_size=5, stride=2)
358 | self.transforms_2 = nn.Sequential(*cond_res_units(channels[1]))
359 | self.deconv_2 = make_deconv(channels[1], out_channels=channels[2], kernel_size=5, stride=2)
360 | self.transforms_3 = nn.Sequential(ConditionalAttentionBlock(channels[2]),
361 | *cond_res_units(channels[2]))
362 | self.deconv_3 = make_deconv(channels[2], out_channels=channels[3], kernel_size=5, stride=2)
363 | self.transforms_4 = nn.Sequential(*cond_res_units(channels[3]))
364 | self.deconv_4 = make_deconv(channels[3], out_channels=channels[4], kernel_size=5, stride=2)
365 |
366 | def forward(self, x: Tensor, emb: Tensor) -> Tensor:
367 | x = self.transforms_1(x, emb)
368 | x = self.deconv_1(x)
369 | for transforms_2 in self.transforms_2:
370 | x = transforms_2(x, emb)
371 | x = self.deconv_2(x)
372 | for transforms_3 in self.transforms_3:
373 | x = transforms_3(x, emb)
374 | x = self.deconv_3(x)
375 | for transforms_4 in self.transforms_4:
376 | x = transforms_4(x, emb)
377 | x = self.deconv_4(x)
378 | return x
379 |
380 |
381 | if __name__ == '__main__':
382 | g_s = AdaptiveELICSynthesis()
383 | y = torch.randn(1, 192, 8, 8)
384 | emb = torch.randn(1, 256)
385 | x = g_s(y, emb)
386 | print(x.shape)
387 |
--------------------------------------------------------------------------------
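For reference, a shape sketch of the default ELIC transforms above: the analysis uses four stride-2 convolutions (16x downsampling) and the synthesis mirrors them with four stride-2 transposed convolutions. Input size and batch are chosen purely for illustration.

```python
# Shape sketch for the default ELIC analysis/synthesis transforms.
import torch
from net.elic import ELICAnalysis, ELICSynthesis

g_a = ELICAnalysis()   # default channels [256, 256, 256, 192] -> 192 latent channels
g_s = ELICSynthesis()  # default channels [192, 256, 256, 3]

x = torch.randn(1, 3, 256, 256)
y = g_a(x)
print(y.shape)         # torch.Size([1, 192, 16, 16])
x_hat = g_s(y)
print(x_hat.shape)     # torch.Size([1, 3, 256, 256])
```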
/net/mit.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import time
4 | from functools import lru_cache
5 |
6 | import constriction
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 |
11 | from layer.layer_utils import quantize_ste, make_conv
12 | from layer.transformer import TransformerBlock, SwinTransformerBlock
13 | from net.elic import ELICAnalysis, ELICSynthesis
14 |
15 |
16 | # from compressai.entropy_models import GaussianConditional
17 |
18 |
19 | class GaussianMixtureEntropyModel(nn.Module):
20 | def __init__(
21 | self,
22 | minmax: int = 64
23 | ):
24 | super().__init__()
25 | self.minmax = minmax
26 | self.samples = torch.arange(-minmax, minmax + 1, 1, dtype=torch.float32).cuda()
27 | self.laplace = torch.distributions.Laplace(0, 1)
28 | self.pmf_laplace = self.laplace.cdf(self.samples + 0.5) - self.laplace.cdf(self.samples - 0.5)
29 | # self.gaussian_conditional = GaussianConditional(None)
30 |
31 | def update_minmax(self, minmax):
32 | self.minmax = minmax
33 | self.samples = torch.arange(-minmax, minmax + 1, 1, dtype=torch.float32).cuda()
34 | self.pmf_laplace = self.laplace.cdf(self.samples + 0.5) - self.laplace.cdf(self.samples - 0.5)
35 |
36 | def get_GMM_likelihood(self, latent_hat, probs, means, scales):
37 | gaussian1 = torch.distributions.Normal(means[0], scales[0])
38 | gaussian2 = torch.distributions.Normal(means[1], scales[1])
39 | gaussian3 = torch.distributions.Normal(means[2], scales[2])
40 | likelihoods_0 = gaussian1.cdf(latent_hat + 0.5) - gaussian1.cdf(latent_hat - 0.5)
41 | likelihoods_1 = gaussian2.cdf(latent_hat + 0.5) - gaussian2.cdf(latent_hat - 0.5)
42 | likelihoods_2 = gaussian3.cdf(latent_hat + 0.5) - gaussian3.cdf(latent_hat - 0.5)
43 |
44 |         likelihoods = (0.999 * (probs[0] * likelihoods_0 + probs[1] * likelihoods_1 + probs[2] * likelihoods_2)
45 |                        + 0.001 * (self.laplace.cdf(latent_hat + 0.5) - self.laplace.cdf(latent_hat - 0.5)))
46 | likelihoods = likelihoods + 1e-10
47 | return likelihoods
48 |
49 | def get_GMM_pmf(self, probs, means, scales):
50 | L = self.samples.size(0)
51 | num_symbol = probs.size(1)
52 |         samples = self.samples.unsqueeze(0).repeat(num_symbol, 1)  # [num_symbol, L]
53 | scales = scales.unsqueeze(-1).repeat(1, 1, L)
54 | means = means.unsqueeze(-1).repeat(1, 1, L)
55 | probs = probs.unsqueeze(-1).repeat(1, 1, L)
56 | likelihoods_0 = self._likelihood(samples, scales[0], means=means[0])
57 | likelihoods_1 = self._likelihood(samples, scales[1], means=means[1])
58 | likelihoods_2 = self._likelihood(samples, scales[2], means=means[2])
59 | pmf_clip = (0.999 * (probs[0] * likelihoods_0 + probs[1] * likelihoods_1 + probs[2] * likelihoods_2)
60 | + 0.001 * self.pmf_laplace)
61 | return pmf_clip
62 |
63 | def _likelihood(self, inputs, scales, means=None):
64 | half = float(0.5)
65 |
66 | if means is not None:
67 | values = inputs - means
68 | else:
69 | values = inputs
70 | values = torch.abs(values)
71 | upper = self._standardized_cumulative((half - values) / scales)
72 | lower = self._standardized_cumulative((-half - values) / scales)
73 | likelihood = upper - lower
74 | return likelihood
75 |
76 | def _standardized_cumulative(self, inputs):
77 | half = float(0.5)
78 | const = float(-(2 ** -0.5))
79 | # Using the complementary error function maximizes numerical precision.
80 | return half * torch.erfc(const * inputs)
81 |
82 | def compress(self, symbols, probs, means, scales):
83 | pmf_clip = self.get_GMM_pmf(probs, means, scales)
84 | model_family = constriction.stream.model.Categorical() # note empty `()`
85 | probabilities = pmf_clip.cpu().numpy().astype(np.float64)
86 | symbols = symbols.reshape(-1)
87 | symbols = (symbols + self.minmax).cpu().numpy().astype(np.int32)
88 | encoder = constriction.stream.queue.RangeEncoder()
89 | encoder.encode(symbols, model_family, probabilities)
90 | compressed = encoder.get_compressed()
91 | return compressed
92 |
93 | def decompress(self, compressed, probs, means, scales):
94 | pmf = self.get_GMM_pmf(probs, means, scales).cpu().numpy().astype(np.float64)
95 | model = constriction.stream.model.Categorical()
96 | decoder = constriction.stream.queue.RangeDecoder(compressed)
97 | symbols = decoder.decode(model, pmf)
98 | symbols = torch.from_numpy(symbols).to(probs.device) - self.minmax
99 |         symbols = symbols.to(torch.float32)
100 | return symbols
101 |
102 |
103 | @lru_cache()
104 | def get_coding_order(target_shape, context_mode, device, step=12):
105 | if context_mode == 'quincunx':
106 | context_tensor = torch.tensor([[4, 2, 4, 0], [3, 4, 3, 4], [4, 1, 4, 2]]).to(device)
107 | elif context_mode == 'checkerboard2':
108 | context_tensor = torch.tensor([[1, 0], [0, 1]]).to(device)
109 | elif context_mode == 'checkerboard4':
110 | context_tensor = torch.tensor([[0, 2], [3, 1]]).to(device)
111 | elif context_mode == 'qlds':
112 | B, C, H, W = target_shape
113 |
114 | def get_qlds(H, W):
115 | n, m, g = 0, 0, 1.32471795724474602596
116 | a1, a2 = 1.0 / g, 1.0 / g / g
117 | context_tensor = torch.zeros((H, W)).to(device) - 1
118 | while m < H * W:
119 | n += 1
120 | x = int(round(((0.5 + n * a1) % 1) * H)) % H
121 | y = int(round(((0.5 + n * a2) % 1) * W)) % W
122 | if context_tensor[x, y] == -1:
123 | context_tensor[x, y] = m
124 | m += 1
125 | return context_tensor
126 |
127 |         context_tensor = get_qlds(H, W).to(torch.int)
128 |
129 | def gamma_func(alpha=1.):
130 | return lambda r: r ** alpha
131 |
132 | ratio = 1. * (np.arange(step) + 1) / step
133 | gamma = gamma_func(alpha=2.2)
134 | L = H * W # total number of tokens
135 | mask_ratio = np.clip(np.floor(L * gamma(ratio)), 0, L - 1)
136 | for i in range(step):
137 | context_tensor = torch.where((context_tensor <= mask_ratio[i]) * (context_tensor > i),
138 | torch.ones_like(context_tensor) * i, context_tensor)
139 | return context_tensor
140 | else:
141 | context_tensor = context_mode
142 | B, C, H, W = target_shape
143 | Hp, Wp = context_tensor.size()
144 | coding_order = torch.tile(context_tensor, (H // Hp + 1, W // Wp + 1))[:H, :W]
145 | return coding_order
146 |
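# Example: for context_mode='checkerboard4' and a 4x4 latent, the tiled coding
# order is
#   [[0, 2, 0, 2],
#    [3, 1, 3, 1],
#    [0, 2, 0, 2],
#    [3, 1, 3, 1]]
# so positions labelled 0 are coded first (without spatial context), then 1, 2
# and 3, each step conditioning on everything coded before it.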
147 |
148 | class MaskedImageTransformer(nn.Module):
149 | def __init__(self, latent_dim, dim=768, depth=12, num_heads=12, window_size=24,
150 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
151 | drop_path=0., norm_layer=nn.LayerNorm, transformer='swin'):
152 | super().__init__()
153 | self.dim = dim
154 | self.depth = depth
155 | if transformer == 'swin':
156 | window_size = 4
157 | num_heads = 8
158 | self.blocks = nn.ModuleList([
159 | SwinTransformerBlock(dim=dim,
160 | num_heads=num_heads,
161 | window_size=window_size,
162 | shift_size=0 if (i % 2 == 0) else window_size // 2,
163 | mlp_ratio=mlp_ratio,
164 | qkv_bias=qkv_bias, qk_scale=qk_scale,
165 | drop=drop, attn_drop=attn_drop,
166 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
167 | norm_layer=norm_layer)
168 | for i in range(depth)])
169 | else:
170 | self.blocks = nn.ModuleList([
171 | TransformerBlock(dim=dim,
172 | num_heads=num_heads, window_size=window_size,
173 | mlp_ratio=mlp_ratio,
174 | qkv_bias=qkv_bias, qk_scale=qk_scale,
175 | drop=drop, attn_drop=attn_drop,
176 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
177 | norm_layer=norm_layer)
178 | for i in range(depth)])
179 | self.delta = 5.0
180 | self.embedding_layer = nn.Linear(latent_dim, dim)
181 | # self.positional_encoding = LearnedPosition(window_size * window_size, dim)
182 | self.entropy_parameters = nn.Sequential(
183 | make_conv(dim, dim * 4, 1, 1),
184 | nn.GELU(),
185 | make_conv(dim * 4, dim * 4, 1, 1),
186 | nn.GELU(),
187 | make_conv(dim * 4, latent_dim * 9, 1, 1),
188 | )
189 |
190 | self.gmm_model = GaussianMixtureEntropyModel()
191 | self.laplace = torch.distributions.Laplace(0, 1)
192 | self.mask_token = nn.Parameter(torch.zeros(1, 1, latent_dim), requires_grad=True)
193 |
194 | def forward_with_given_mask(self, latent_hat, mask, slice_size=None):
195 | B, C, H, W = latent_hat.size()
196 | input_resolution = (H, W)
197 | x = latent_hat.flatten(2).transpose(1, 2) # B L N
198 | mask_BLN = mask.flatten(2).transpose(1, 2) # B L N
199 | x_masked = x * mask_BLN + self.mask_token * (1 - mask_BLN)
200 |
201 | x_masked = self.embedding_layer(x_masked / self.delta)
202 | # x = self.positional_encoding(x)
203 | for _, blk in enumerate(self.blocks):
204 | x_masked = blk(x_masked, input_resolution, slice_size)
205 | x_out = x_masked.transpose(1, 2).reshape(B, self.dim, H, W)
206 | params = self.entropy_parameters(x_out)
207 | probs, means, scales = params.chunk(3, dim=1)
208 | probs = torch.softmax(probs.reshape(B, 3, C, H, W), dim=1).transpose(0, 1)
209 | means = means.reshape(B, 3, C, H, W).transpose(0, 1)
210 | scales = torch.abs(scales).reshape(B, 3, C, H, W).transpose(0, 1).clamp(1e-10, 1e10)
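        # probs / means / scales each have shape (3, B, C, H, W): one set of
        # Gaussian parameters per mixture component, with the mixture weights
        # normalized over the 3 components by the softmax above.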
211 | return probs, means, scales
212 |
213 | def forward_with_random_mask(self, latent):
214 | B, C, H, W = latent.size()
215 | half = float(0.5)
216 | noise = torch.empty_like(latent).uniform_(-half, half)
217 | latent_noise = latent + noise
218 | latent_hat = quantize_ste(latent)
219 |
220 | def generate_random_mask(latent, r):
221 |             mask_loc = torch.randn(H * W, device=latent.device)
222 | threshold = torch.sort(mask_loc)[0][r]
223 | mask = torch.where(mask_loc > threshold, torch.ones_like(mask_loc), torch.zeros_like(mask_loc))
224 | mask = mask.reshape(1, 1, H, W).repeat(B, C, 1, 1)
225 | return mask
226 |
227 |         r = math.floor(np.random.uniform(0.05, 0.99) * H * W)  # number of latent positions to mask (and predict)
228 | mask = generate_random_mask(latent_hat, r)
229 | mask_params = mask.unsqueeze(0).repeat(3, 1, 1, 1, 1)
230 | probs, means, scales = self.forward_with_given_mask(latent_hat, mask)
231 | likelihoods_masked = torch.ones_like(latent_hat)
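        # Positions visible to the model (mask == 1) keep a likelihood of 1,
        # i.e. zero estimated bits; only the masked positions contribute to the
        # rate term computed in the training loop.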
232 | likelihoods = self.gmm_model.get_GMM_likelihood(latent_noise[mask == 0],
233 | probs[mask_params == 0].reshape(3, -1),
234 | means[mask_params == 0].reshape(3, -1),
235 | scales[mask_params == 0].reshape(3, -1))
236 | likelihoods_masked[mask == 0] = likelihoods
237 | return latent_hat, likelihoods_masked
238 |
239 | def inference(self, latent_hat, context_mode='qlds', slice_size=None):
240 | coding_order = get_coding_order(latent_hat.shape, context_mode, latent_hat.get_device(), step=12) # H W
241 | coding_order = coding_order.reshape(1, 1, *coding_order.shape).repeat(latent_hat.shape[0], latent_hat.shape[1],
242 | 1, 1)
243 | total_steps = int(coding_order.max() + 1)
244 | likelihoods = torch.zeros_like(latent_hat)
245 | for i in range(total_steps):
246 | ctx_locations = (coding_order < i)
247 | mask_i = torch.where(ctx_locations, torch.ones_like(latent_hat), torch.zeros_like(latent_hat))
248 | probs_i, means_i, scales_i = self.forward_with_given_mask(latent_hat, mask_i, slice_size)
249 | encoding_locations = (coding_order == i)
250 | mask_params_i = encoding_locations.unsqueeze(0).repeat(3, 1, 1, 1, 1)
251 | likelihoods_i = self.gmm_model.get_GMM_likelihood(latent_hat[encoding_locations],
252 | probs_i[mask_params_i].reshape(3, -1),
253 | means_i[mask_params_i].reshape(3, -1),
254 | scales_i[mask_params_i].reshape(3, -1))
255 | likelihoods[encoding_locations] = likelihoods_i
256 | return likelihoods
257 |
258 | def compress(self, latent, context_mode='qlds'):
259 | B, C, H, W = latent.size()
260 | latent_hat = torch.round(latent)
261 | self.gmm_model.update_minmax(int(latent_hat.max().item()))
262 | coding_order = get_coding_order(latent.shape, context_mode, latent_hat.get_device(), step=12) # H W
263 | coding_order = coding_order.reshape(1, 1, *coding_order.shape).repeat(latent.shape[0], latent.shape[1], 1, 1)
264 | total_steps = int(coding_order.max() + 1)
265 | t0 = time.time()
266 | strings = []
267 | for i in range(total_steps):
268 | # print('STEP', i)
269 | ctx_locations = (coding_order < i)
270 | encoding_locations = (coding_order == i)
271 | mask_params_i = encoding_locations.unsqueeze(0).repeat(3, 1, 1, 1, 1)
272 | mask_i = torch.where(ctx_locations, torch.ones_like(latent), torch.zeros_like(latent))
273 | probs_i, means_i, scales_i = self.forward_with_given_mask(latent_hat, mask_i)
274 | string_i = self.gmm_model.compress(latent_hat[encoding_locations],
275 | probs_i[mask_params_i].reshape(3, -1),
276 | means_i[mask_params_i].reshape(3, -1),
277 | scales_i[mask_params_i].reshape(3, -1))
278 | strings.append(string_i)
279 | print('compress', time.time() - t0)
280 | return strings
281 |
282 | def decompress(self, strings, latent_size, device, context_mode='qlds'):
283 | B, C, H, W = latent_size
284 | coding_order = get_coding_order(latent_size, context_mode, device, step=12) # H W
285 | coding_order = coding_order.reshape(1, 1, *coding_order.shape).repeat(B, C, 1, 1)
286 | total_steps = int(coding_order.max() + 1)
287 | t0 = time.time()
288 | latent_hat = torch.zeros(latent_size).to(device)
289 | for i in range(total_steps):
290 | ctx_locations = (coding_order < i)
291 | encoding_locations = (coding_order == i)
292 | mask_params_i = encoding_locations.unsqueeze(0).repeat(3, 1, 1, 1, 1)
293 | mask_i = torch.where(ctx_locations, torch.ones_like(latent_hat), torch.zeros_like(latent_hat))
294 | probs_i, means_i, scales_i = self.forward_with_given_mask(latent_hat, mask_i)
295 | symbols_i = self.gmm_model.decompress(strings[i],
296 | probs_i[mask_params_i].reshape(3, -1),
297 | means_i[mask_params_i].reshape(3, -1),
298 | scales_i[mask_params_i].reshape(3, -1))
299 | latent_hat[encoding_locations] = symbols_i
300 | print('decompress', time.time() - t0)
301 | return latent_hat
302 |
303 | @torch.jit.ignore
304 | def no_weight_decay(self):
305 | return {'mask_token'}
306 |
307 |
308 | class MaskedImageModelingTransformer(nn.Module):
309 | def __init__(self):
310 | super().__init__()
311 | self.g_a = ELICAnalysis()
312 | self.g_s = ELICSynthesis()
313 | self.mim = MaskedImageTransformer(192)
314 |
315 | def forward(self, x):
316 | y = self.g_a(x)
317 | y_hat, likelihoods = self.mim.forward_with_random_mask(y)
318 | x_hat = self.g_s(y_hat)
319 | return {
320 | "x_hat": x_hat,
321 | "likelihoods": likelihoods,
322 | }
323 |
324 | def inference(self, x):
325 | # TODO Patch-wise inference for off-the-shelf Transformers
326 | y = self.g_a(x)
327 | y_hat = torch.round(y)
328 | likelihoods = self.mim.inference(y_hat)
329 | x_hat = self.g_s(y_hat)
330 | return {
331 | "x_hat": x_hat,
332 | "likelihoods": likelihoods,
333 | }
334 |
335 | def real_inference(self, x):
336 | num_pixels = x.size(2) * x.size(3)
337 | y = self.g_a(x)
338 | strings = self.mim.compress(y)
339 | y_hat = self.mim.decompress(strings, y.shape, x.get_device())
340 | x_hat = self.g_s(y_hat)
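        # Each entry of `strings` is the array of 32-bit words returned by
        # constriction's get_compressed(), so `string.size * 32` is its size in bits.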
341 | bpp = sum([string.size * 32 for string in strings]) / num_pixels
342 | return {
343 | "x_hat": x_hat,
344 | "bpp": bpp,
345 | }
346 |
--------------------------------------------------------------------------------
/results/mt_rd.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wsxtyrdd/Masked-Transformer-For-Image-Compression/2dc391982c0d4e798acb158a552b5e39b968290a/results/mt_rd.png
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | sys.path.append("/media/D/wangsixian/MT")
5 |
6 | import configargparse
7 | import time
8 | from datetime import datetime
9 | import numpy as np
10 | import random
11 | import torch
12 | import torch.backends.cudnn as cudnn
13 | import torch.distributed as dist
14 | import torch.nn.functional as F
15 | import wandb
16 |
17 | torch.backends.cudnn.benchmark = False
18 | from data.datasets import get_loader, get_dataset
19 | from loss.distortion import Distortion
20 | from utils import logger_configuration, load_weights, AverageMeter, save_checkpoint, worker_init_fn_seed
21 | from net.mit import MaskedImageModelingTransformer
22 |
23 |
24 | def train_one_epoch(epoch, net, train_loader, test_loader, optimizer,
25 | device, logger):
26 | local_rank = torch.distributed.get_rank() if config.multi_gpu else 0
27 | best_loss = float("inf")
28 | mse_loss_wrapper = Distortion("MSE").to(device)
29 | ms_ssim_loss_wrapper = Distortion("MS-SSIM").to(device)
30 | elapsed, losses, psnrs, ms_ssim_dbs, bpps = [AverageMeter() for _ in range(5)]
31 | metrics = [elapsed, losses, psnrs, ms_ssim_dbs, bpps]
32 | global global_step
33 | for batch_idx, input_image in enumerate(train_loader):
34 | net.train()
35 | B, C, H, W = input_image.shape
36 | num_pixels = B * H * W
37 | input_image = input_image.to(device)
38 | optimizer.zero_grad()
39 | global_step += 1
40 | start_time = time.time()
41 | results = net(input_image)
42 | bpp_loss = torch.sum(torch.clamp(-torch.log2(results["likelihoods"]), 0, 50)) / num_pixels
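        # Rate-distortion objective: estimated bits per pixel (clamped -log2
        # likelihoods summed over the latent) plus lambda * 255^2 * MSE, where
        # 255^2 rescales the MSE of [0, 1] images to the 8-bit intensity range.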
43 | # if config.distortion_metric == "MSE":
44 | mse_loss = mse_loss_wrapper(results["x_hat"], input_image)
45 | tot_loss = config.lambda_value * 255 ** 2 * mse_loss + bpp_loss
46 | # elif config.distortion_metric == "MS-SSIM":
47 | # ssim_loss = ms_ssim_loss_wrapper(results["x_hat"], input_image)
48 | # tot_loss = config.lambda_value * ssim_loss + bpp_loss
49 | tot_loss.backward()
50 |
51 | if config.clip_max_norm > 0:
52 | torch.nn.utils.clip_grad_norm_(net.parameters(), config.clip_max_norm)
53 | optimizer.step()
54 |
55 | elapsed.update(time.time() - start_time)
56 | losses.update(tot_loss.item())
57 | bpps.update(bpp_loss.item())
58 |
59 | mse_val = mse_loss_wrapper(results["x_hat"], input_image).detach().item()
60 | psnr = 10 * (np.log(1. / mse_val) / np.log(10))
61 | psnrs.update(psnr.item())
62 |
63 | ms_ssim_val = ms_ssim_loss_wrapper(results["x_hat"], input_image).detach().item()
64 | ms_ssim_db = -10 * (np.log(ms_ssim_val) / np.log(10))
65 | ms_ssim_dbs.update(ms_ssim_db.item())
66 |
67 | if (global_step % config.print_every) == 0 and local_rank == 0:
68 |                 process = (global_step % len(train_loader)) / len(train_loader) * 100.0
69 |                 log_info = [
70 |                     f'Step [{global_step % len(train_loader)}/{len(train_loader)}={process:.2f}%]',
71 | f'Loss ({losses.avg:.3f})',
72 | f'Time {elapsed.avg:.2f}',
73 | f'PSNR {psnrs.val:.2f} ({psnrs.avg:.2f})',
74 | f'BPP {bpps.val:.2f} ({bpps.avg:.2f})',
75 | f'MS-SSIM {ms_ssim_dbs.val:.2f} ({ms_ssim_dbs.avg:.2f})',
76 | f'Epoch {epoch}'
77 | ]
78 | log = (' | '.join(log_info))
79 | logger.info(log)
80 | if config.wandb:
81 | log_dict = {"PSNR": psnrs.avg,
82 | "MS-SSIM": ms_ssim_dbs.avg,
83 | "BPP": bpps.avg,
84 | "loss": losses.avg,
85 | "Step": global_step,
86 | }
87 | wandb.log(log_dict, step=global_step)
88 | for i in metrics:
89 | i.clear()
90 |
91 | if (global_step + 1) % config.test_every == 0 and local_rank == 0:
92 | loss = test(net, test_loader, device, logger)
93 | is_best = loss < best_loss
94 | best_loss = min(loss, best_loss)
95 | if is_best:
96 | save_checkpoint(
97 | {
98 | "epoch": epoch + 1,
99 | "global_step": global_step,
100 | "state_dict": net.state_dict(),
101 | "optimizer": optimizer.state_dict()
102 | },
103 | is_best,
104 | workdir
105 | )
106 |
107 | if (global_step + 1) % config.save_every == 0 and local_rank == 0:
108 | save_checkpoint(
109 | {
110 | "epoch": epoch + 1,
111 | "global_step": global_step,
112 | "state_dict": net.state_dict(),
113 | "optimizer": optimizer.state_dict()
114 | },
115 | False,
116 | workdir,
117 | filename='EP{}.pth.tar'.format(epoch)
118 | )
119 |
120 |
121 | def test(net, test_loader, device, logger):
122 | with torch.no_grad():
123 | mse_loss_wrapper = Distortion("MSE").to(device)
124 | ms_ssim_loss_wrapper = Distortion("MS-SSIM").to(device)
125 | elapsed, losses, psnrs, ms_ssim_dbs, bpps = [AverageMeter() for _ in range(5)]
126 | global global_step
127 | for batch_idx, input_image in enumerate(test_loader):
128 | net.eval()
129 | B, C, H, W = input_image.shape
130 | num_pixels = B * H * W
131 | input_image = input_image.to(device)
132 | start_time = time.time()
133 | # crop and pad
134 | p = 64
135 | new_H = (H + p - 1) // p * p
136 | new_W = (W + p - 1) // p * p
137 | padding_left = (new_W - W) // 2
138 | padding_right = new_W - W - padding_left
139 | padding_top = (new_H - H) // 2
140 | padding_bottom = new_H - H - padding_top
141 | input_image_pad = F.pad(
142 | input_image,
143 | (padding_left, padding_right, padding_top, padding_bottom),
144 | mode="constant",
145 | value=0,
146 | )
147 | results = net.module.inference(input_image_pad)
148 | results["x_hat"] = F.pad(
149 | results["x_hat"], (-padding_left, -padding_right, -padding_top, -padding_bottom)
150 | )
151 |
152 | mse_loss = mse_loss_wrapper(results["x_hat"], input_image)
153 | bpp_loss = torch.sum(torch.clamp(-torch.log2(results["likelihoods"]), 0, 50)) / num_pixels
154 | tot_loss = config.lambda_value * 255 ** 2 * mse_loss + bpp_loss
155 | elapsed.update(time.time() - start_time)
156 | losses.update(tot_loss.item())
157 | bpps.update(bpp_loss.item())
158 |
159 | # results = net.module.real_inference(input_image_pad)
160 | # mse_loss = mse_loss_wrapper(results["x_hat"], input_image)
161 | # bpp_loss = results["bpp"]
162 | # bpps.update(bpp_loss)
163 |
164 | mse_val = mse_loss.item()
165 | psnr = 10 * (np.log(1. / mse_val) / np.log(10))
166 | psnrs.update(psnr.item())
167 |
168 | ms_ssim_val = ms_ssim_loss_wrapper(results["x_hat"], input_image).mean().item()
169 | ms_ssim_db = -10 * (np.log(ms_ssim_val) / np.log(10))
170 | ms_ssim_dbs.update(ms_ssim_db.item())
171 |
172 | log_info = [
173 | f'Step [{(batch_idx + 1)}/{test_loader.__len__()}]',
174 | f'Loss ({losses.avg:.3f})',
175 | f'Time {elapsed.val:.2f}',
176 | f'PSNR {psnrs.val:.2f} ({psnrs.avg:.2f})',
177 | f'BPP {bpps.val:.2f} ({bpps.avg:.4f})',
178 | f'MS-SSIM {ms_ssim_dbs.val:.2f} ({ms_ssim_dbs.avg:.2f})',
179 | ]
180 | log = (' | '.join(log_info))
181 | logger.info(log)
182 | if config.wandb and local_rank == 0:
183 | log_dict = {"[Kodak] PSNR": psnrs.avg,
184 | "[Kodak] MS-SSIM": ms_ssim_dbs.avg,
185 | "[Kodak] BPP": bpps.avg,
186 | "[Kodak] loss": losses.avg,
187 | "[Kodak] Step": global_step,
188 | }
189 | wandb.log(log_dict, step=global_step)
190 | return losses.avg
191 |
192 |
193 | def parse_args(argv):
194 | parser = configargparse.ArgumentParser()
195 | parser.add_argument('--config', is_config_file=True, default='./config/mt.yaml',
196 | help='Path to config file to replace defaults.')
197 | parser.add_argument('--seed', type=int, default=1024,
198 | help='Random seed.')
199 | parser.add_argument('--gpu-id', type=str, default='0',
200 | help='GPU id to use.')
201 | parser.add_argument('--test-only', action='store_true',
202 | help='Test only (and do not run training).')
203 | parser.add_argument('--multi-gpu', action='store_true')
204 | parser.add_argument('--local_rank', type=int,
205 | help='Local rank for distributed training.')
206 |
207 | # logging
208 | parser.add_argument('--exp-name', type=str, default=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
209 | help='Experiment name, unique id for trainers, logs.')
210 | parser.add_argument('--wandb', action="store_true",
211 | help='Use wandb for logging.')
212 | parser.add_argument('--print-every', type=int, default=30,
213 | help='Frequency of logging.')
214 |
215 | # dataset
216 | parser.add_argument('--dataset-path', default=['/media/Dataset/openimages/**/'],
217 | help='Path to the dataset')
218 | parser.add_argument('--num-workers', type=int, default=8,
219 | help='Number of workers for dataloader.')
220 | parser.add_argument('--training-img-size', type=tuple, default=(384, 384),
221 | help='Size of the training images.')
222 | parser.add_argument('--eval-dataset-path', type=str,
223 | help='Path to the evaluation dataset')
224 |
225 | # optimization
226 |     parser.add_argument('--distortion_metric', type=str, default='MSE',
227 |                         help='Distortion type, MSE or MS-SSIM.')
228 |     parser.add_argument('--lambda_value', type=float, default=1,
229 |                         help='Weight for the distortion term in the rate-distortion loss.')
230 |
231 | # Optimizer configuration parameters
232 | parser.add_argument('--optimizer_type', type=str, default='AdamW',
233 | help='The type of optimizer to use')
234 |     parser.add_argument('--init_lr', type=float, default=1e-4,
235 |                         help='The initial learning rate for the learning rate policy')
236 | parser.add_argument('--min_lr', type=float, default=1e-4,
237 | help='The minimum learning rate for the learning rate policy')
238 | parser.add_argument('--max_lr', type=float, default=1e-4,
239 | help='The maximum learning rate for the learning rate policy')
240 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.98),
241 | help='The beta values to use for the optimizer')
242 | parser.add_argument('--warmup_epoch', type=int, default=5,
243 | help='The number of epochs to use for the warmup')
244 | parser.add_argument('--weight_decay', type=float, default=0.03,
245 | help='The weight decay value for the optimizer')
246 | parser.add_argument('--clip-max-norm', type=float, default=1.0,
247 | help='Gradient clipping for stable training.')
248 |
249 | # trainer
250 | parser.add_argument('--warmup', action='store_true')
251 | parser.add_argument('--epochs', type=int, default=10,
252 | help='Number of epochs to run the training.')
253 | parser.add_argument('--batch-size', type=int, default=16,
254 | help='Batch size for the training.')
255 | parser.add_argument("--checkpoint", type=str, default=None,
256 | help="Path to a checkpoint model")
257 | parser.add_argument('--save', action="store_true",
258 | help="Save the model at every epoch (no overwrite).")
259 | parser.add_argument('--save-every', type=int, default=10000,
260 | help='Frequency of saving the model.')
261 | parser.add_argument('--test-every', type=int, default=5000,
262 | help='Frequency of running validation.')
263 |
264 | # model
265 | parser.add_argument('--net', type=str, default='MT',
266 | help='Model architecture.')
267 | args = parser.parse_args(argv)
268 | return args
269 |
270 |
271 | def main(argv):
272 | global config
273 | config = parse_args(argv)
274 |
275 | global local_rank
276 | if config.multi_gpu:
277 | dist.init_process_group(backend='nccl')
278 | local_rank = torch.distributed.get_rank()
279 | torch.cuda.set_device(local_rank)
280 | device = torch.device("cuda", local_rank)
282 |     else:
283 |         local_rank = 0
284 |         os.environ['CUDA_VISIBLE_DEVICES'] = str(config.gpu_id)  # set before any CUDA context is created
285 |         device = torch.device("cuda")
285 |
286 | # torch.backends.cudnn.benchmark = True
287 |
288 | if config.seed is not None:
289 | torch.manual_seed(config.seed)
290 | random.seed(config.seed)
291 |
292 | config.device = device
293 | job_type = 'test' if config.test_only else 'train'
294 | exp_name = config.net + " " + config.exp_name
295 | global workdir
296 | workdir, logger = logger_configuration(exp_name, job_type,
297 | method=config.net, save_log=(not config.test_only and local_rank == 0))
298 | net = MaskedImageModelingTransformer().to(device)
299 | if config.multi_gpu:
300 | net = torch.nn.parallel.DistributedDataParallel(net, find_unused_parameters=True)
301 | else:
302 | net = torch.nn.DataParallel(net)
303 |
304 | if config.wandb and local_rank == 0:
305 | print("=============== use wandb ==============")
306 | wandb_init_kwargs = {
307 | 'project': 'ResiComm',
308 | 'name': exp_name,
309 | 'save_code': True,
310 | 'job_type': job_type,
311 | 'config': config.__dict__
312 | }
313 | wandb.init(**wandb_init_kwargs)
314 |
315 | # shutil.copy(config.config, join(workdir, 'config.yaml'))
316 |
317 | config.logger = logger
318 | logger.info(config.__dict__)
319 |
320 | if config.multi_gpu:
321 | train_dataset, test_dataset = get_dataset(config.dataset_path, config.eval_dataset_path)
322 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
323 | train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
324 | num_workers=config.num_workers,
325 | pin_memory=True,
326 | batch_size=config.batch_size,
327 | worker_init_fn=worker_init_fn_seed,
328 | sampler=train_sampler)
329 | test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
330 | batch_size=1,
331 | shuffle=False)
332 | else:
333 | train_loader, test_loader = get_loader(config.dataset_path, config.eval_dataset_path,
334 | config.num_workers, config.batch_size)
335 |
336 | optimizer_cfg = {'lr': config.max_lr,
337 | 'betas': config.betas,
338 | 'weight_decay': config.weight_decay
339 | }
340 | optimizer = getattr(torch.optim, config.optimizer_type)(net.parameters(), **optimizer_cfg)
341 |
342 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(config.epochs * 0.5), gamma=0.1)
343 |
344 | global global_step
345 | global_step = 0
346 |
347 | if config.checkpoint is not None:
348 | load_weights(net, config.checkpoint, device)
349 | else:
350 | logger.info("No pretrained model is loaded.")
351 |
352 | if config.test_only:
353 | test(net, test_loader, device, logger)
354 | else:
355 | steps_epoch = global_step // train_loader.__len__()
356 | init_lambda_value = config.lambda_value
357 | for epoch in range(steps_epoch, config.epochs):
358 | if config.warmup:
359 | # for lambda warmup
360 | if epoch <= int(config.epochs * 0.1):
361 | config.lambda_value = init_lambda_value * 10
362 | else:
363 | config.lambda_value = init_lambda_value
364 |
365 | logger.info('======Current epoch %s ======' % epoch)
366 | logger.info(f"Learning rate: {optimizer.param_groups[0]['lr']}")
367 | logger.info(f"Lambda value: {config.lambda_value}")
368 | train_one_epoch(epoch, net, train_loader, test_loader, optimizer, device, logger)
369 | lr_scheduler.step()
370 |
371 |
372 | if __name__ == '__main__':
373 | main(sys.argv[1:])
374 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
8 | def save_config(config, file_path):
9 | import json
10 | info_json = json.dumps(config, sort_keys=False, indent=4, separators=(',', ': '))
11 |     with open(file_path, 'w') as f:
12 |         f.write(info_json)
13 |
14 |
15 | def logger_configuration(filename, phase, method='', save_log=True):
16 | logger = logging.getLogger(" ")
17 | workdir = './history/{}/{}'.format(method, filename)
18 | if phase == 'test':
19 | workdir += '_test'
20 | log = workdir + '/{}.log'.format(filename)
21 | samples = workdir + '/samples'
22 | models = workdir + '/models'
23 | if save_log:
24 | makedirs(workdir)
25 | makedirs(samples)
26 | makedirs(models)
27 |
28 | formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s",
29 | "%Y-%m-%d %H:%M:%S")
30 | stdhandler = logging.StreamHandler()
31 | stdhandler.setLevel(logging.INFO)
32 | stdhandler.setFormatter(formatter)
33 | logger.addHandler(stdhandler)
34 | if save_log:
35 | filehandler = logging.FileHandler(log)
36 | filehandler.setLevel(logging.INFO)
37 | filehandler.setFormatter(formatter)
38 | logger.addHandler(filehandler)
39 | logger.setLevel(logging.INFO)
40 | return workdir, logger
41 |
42 |
43 | def makedirs(directory):
44 | if not os.path.exists(directory):
45 | os.makedirs(directory)
46 |
47 |
48 | def save_checkpoint(state, is_best, base_dir, filename="checkpoint.pth.tar"):
49 | if is_best:
50 | torch.save(state, base_dir + "/checkpoint_best_loss.pth.tar")
51 | else:
52 | torch.save(state, base_dir + "/" + filename)
53 |
54 |
55 | class AverageMeter:
56 | """Compute running average."""
57 |
58 | def __init__(self):
59 | self.val = 0
60 | self.avg = 0
61 | self.sum = 0
62 | self.count = 0
63 |
64 | def update(self, val, n=1):
65 | self.val = val
66 | self.sum += val * n
67 | self.count += n
68 | self.avg = self.sum / self.count
69 |
70 | def clear(self):
71 | self.val = 0
72 | self.avg = 0
73 | self.sum = 0
74 | self.count = 0
75 |
76 |
77 | def load_weights(net, model_path, device, remove_keys=None):
78 | try:
79 | pretrained = torch.load(model_path, map_location=device)['state_dict']
80 |     except (KeyError, TypeError):
81 | pretrained = torch.load(model_path, map_location=device)
82 | result_dict = {}
83 | for key, weight in pretrained.items():
84 | result_key = key
85 | load_flag = True
86 | if 'attn_mask' in key:
87 | load_flag = False
88 | if remove_keys is not None:
89 | for remove_key in remove_keys:
90 | if remove_key in key:
91 | load_flag = False
92 | if load_flag:
93 | result_dict[result_key] = weight
94 | print(net.load_state_dict(result_dict, strict=False))
95 | del result_dict, pretrained
96 |
97 |
98 | def worker_init_fn_seed(worker_id):
99 | seed = 10
100 | seed += worker_id
101 | np.random.seed(seed)
102 |
--------------------------------------------------------------------------------