├── clip
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── __pycache__
│   │   ├── clip.cpython-37.pyc
│   │   ├── clip.cpython-38.pyc
│   │   ├── Closs.cpython-37.pyc
│   │   ├── Closs.cpython-38.pyc
│   │   ├── model.cpython-37.pyc
│   │   ├── model.cpython-38.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── simple_tokenizer.cpython-37.pyc
│   │   └── simple_tokenizer.cpython-38.pyc
│   ├── simple_tokenizer.py
│   ├── clip.py
│   ├── Closs.py
│   └── model.py
├── loss
│   ├── __init__.py
│   └── rd_loss.py
├── models
│   ├── __init__.py
│   ├── fourier_cond.py
│   ├── mlicpp.py
│   └── ug_mlicpp.py
├── img
│   ├── image1.png
│   ├── image2.png
│   └── image3.png
├── modules
│   ├── layers
│   │   ├── __init__.py
│   │   ├── conv.py
│   │   ├── attention.py
│   │   └── res_blk.py
│   └── transform
│       ├── __init__.py
│       ├── quantization.py
│       ├── entropy.py
│       ├── analysis.py
│       ├── synthesis.py
│       └── context.py
├── playground
│   ├── warmup.sh
│   ├── test.py
│   ├── test_condi.py
│   ├── train.py
│   ├── warmup.py
│   ├── train_mask.py
│   └── train_condi.py
├── tests
│   ├── test.sh
│   └── test.py
├── config
│   ├── config.py
│   └── args.py
├── utils
│   ├── logger.py
│   ├── metrics.py
│   ├── optimizers.py
│   ├── utils.py
│   ├── func.py
│   ├── newtrain_data.py
│   ├── ckbd.py
│   ├── train_data.py
│   └── testing.py
├── LICENSE
├── classification
│   └── test.py
├── instance segmantation
│   └── test.py
└── README.md
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .rd_loss import *
2 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .ug_mlicpp import UG_MLICPlusPlus
2 |
--------------------------------------------------------------------------------
/img/image1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/img/image1.png
--------------------------------------------------------------------------------
/img/image2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/img/image2.png
--------------------------------------------------------------------------------
/img/image3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/img/image3.png
--------------------------------------------------------------------------------
/modules/layers/__init__.py:
--------------------------------------------------------------------------------
1 | from .attention import *
2 | from .conv import *
3 | from .res_blk import *
4 |
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/__pycache__/clip.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/clip.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/clip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/clip.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/Closs.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/Closs.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/Closs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/Closs.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/simple_tokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/simple_tokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/simple_tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/simple_tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/modules/transform/__init__.py:
--------------------------------------------------------------------------------
1 | from .analysis import *
2 | from .synthesis import *
3 | from .context import *
4 | from .quantization import *
5 | from .entropy import *
6 |
--------------------------------------------------------------------------------
/playground/warmup.sh:
--------------------------------------------------------------------------------
1 | work_path=$(dirname $0)
2 | export PYTHONPATH=..:$PYTHONPATH
3 | nohup python warmup.py --metrics mse --exp mlicpp_mse_q1 --gpu_id 0 --lambda 0.0018 -lr 1e-4 --clip_max_norm 1.0 --seed 1984 --batch-size 8 > 0018v2.txt 2>&1 &
4 |
--------------------------------------------------------------------------------
/tests/test.sh:
--------------------------------------------------------------------------------
1 | work_path=$(dirname $0)
2 | export PYTHONPATH=..:$PYTHONPATH
3 | CUDA_VISIBLE_DEVICES='0' python test.py -exp test_human --gpu_id 0 --beta 0 -c /path/to/checkpoint -d /path/to/dataset
4 | CUDA_VISIBLE_DEVICES='0' python test.py -exp test_machine --gpu_id 0 --beta 1 -c /path/to/checkpoint -d /path/to/dataset
5 |
6 |
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from utils.utils import Config
3 |
4 | def model_config():
5 | config = Config({
6 | # MLIC and MLIC+
7 | "N": 192,
8 | "M": 320,
9 | "slice_num": 10,
10 | "context_window": 5,
11 | "act": nn.GELU,
12 | })
13 |
14 | return config
15 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from datetime import datetime
4 |
5 |
6 | def get_timestamp():
7 | return datetime.now().strftime('%y%m%d-%H%M%S')
8 |
9 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
10 | '''set up logger'''
11 | lg = logging.getLogger(logger_name)
12 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
13 | datefmt='%y-%m-%d %H:%M:%S')
14 | lg.setLevel(level)
15 | if tofile:
16 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
17 | fh = logging.FileHandler(log_file, mode='w')
18 | fh.setFormatter(formatter)
19 | lg.addHandler(fh)
20 | if screen:
21 | sh = logging.StreamHandler()
22 | sh.setFormatter(formatter)
23 | lg.addHandler(sh)
24 |
--------------------------------------------------------------------------------
/modules/layers/conv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from compressai.layers import conv3x3
6 |
7 |
8 | def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module:
9 | """1x1 convolution."""
10 | return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride)
11 |
12 |
13 | def conv(in_channels, out_channels, kernel_size=5, stride=2):
14 | return nn.Conv2d(
15 | in_channels,
16 | out_channels,
17 | kernel_size=kernel_size,
18 | stride=stride,
19 | padding=kernel_size // 2,
20 | )
21 |
22 |
23 | def deconv(in_channels, out_channels, kernel_size=5, stride=2):
24 | return nn.ConvTranspose2d(
25 | in_channels,
26 | out_channels,
27 | kernel_size=kernel_size,
28 | stride=stride,
29 | output_padding=stride - 1,
30 | padding=kernel_size // 2,
31 | )
32 |
--------------------------------------------------------------------------------
/modules/transform/quantization.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from compressai.layers import subpel_conv3x3
5 | from modules.layers.conv import conv1x1, conv3x3, conv, deconv
6 | from modules.layers.res_blk import *
7 |
8 |
9 | class LatentResidualPrediction(nn.Module):
10 | def __init__(self, in_dim, out_dim, act=nn.GELU):
11 | super().__init__()
12 | diff = abs(out_dim - in_dim)
13 | # such setting leads to much more parameters, you'd better use the setting in Minnen'20 ICIP paper.
14 | self.lrp_transform = nn.Sequential(
15 | conv3x3(in_dim, in_dim - diff // 4),
16 | act(),
17 | conv3x3(in_dim - diff // 4, in_dim - diff // 2),
18 | act(),
19 | conv3x3(in_dim - diff // 2, in_dim - diff * 3 // 4),
20 | act(),
21 | conv3x3(in_dim - diff * 3 // 4, out_dim),
22 | )
23 |
24 | def forward(self, x):
25 | x = self.lrp_transform(x)
26 | x = 0.5 * torch.tanh(x)
27 | return x
28 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 YinKangSheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/loss/rd_loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from pytorch_msssim import ms_ssim
5 |
6 |
7 | class RateDistortionLoss(nn.Module):
8 | """Custom rate distortion loss with a Lagrangian parameter."""
9 |
10 | def __init__(self, lmbda=1e-2, metrics='mse'):
11 | super().__init__()
12 | self.mse = nn.MSELoss()
13 | self.lmbda = lmbda
14 | self.metrics = metrics
15 |
16 | def set_lmbda(self, lmbda):
17 | self.lmbda = lmbda
18 |
19 | def forward(self, output, target):
20 | N, _, H, W = target.size()
21 | out = {}
22 | num_pixels = N * H * W
23 |
24 | out["bpp_loss"] = sum(
25 | (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels))
26 | for likelihoods in output["likelihoods"].values()
27 | )
28 | if self.metrics == 'mse':
29 | out["mse_loss"] = self.mse(output["x_hat"], target)
30 | out["ms_ssim_loss"] = None
31 | out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"]
32 | elif self.metrics == 'ms-ssim':
33 | out["mse_loss"] = None
34 | out["ms_ssim_loss"] = 1 - ms_ssim(output["x_hat"], target, data_range=1.0)
35 | out["loss"] = self.lmbda * out["ms_ssim_loss"] + out["bpp_loss"]
36 |
37 | return out
38 |
--------------------------------------------------------------------------------
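For orientation, a minimal sketch of how `RateDistortionLoss` is typically driven in a CompressAI-style training step; `net`, the two optimizers, and the `aux_loss()` call are assumptions about the surrounding training loop (the actual loops live in `playground/train*.py`), not repository code. The `lmbda * 255 ** 2` factor in the loss simply rescales MSE computed on [0, 1] images to the conventional 8-bit range.

```python
from loss.rd_loss import RateDistortionLoss


def train_step(net, optimizer, aux_optimizer, x, lmbda=0.0018):
    # Hypothetical single step; assumes `net` returns a dict with "x_hat" and
    # "likelihoods" and exposes aux_loss(), as CompressAI models do.
    criterion = RateDistortionLoss(lmbda=lmbda, metrics="mse")

    optimizer.zero_grad()
    aux_optimizer.zero_grad()

    out = net(x)                   # {"x_hat": ..., "likelihoods": {...}}
    losses = criterion(out, x)     # {"loss", "bpp_loss", "mse_loss", ...}
    losses["loss"].backward()
    optimizer.step()

    aux_loss = net.aux_loss()      # entropy-bottleneck quantile loss
    aux_loss.backward()
    aux_optimizer.step()
    return losses
```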
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import PIL.Image as Image
4 | from typing import Dict, List, Optional, Tuple, Union
5 | from pytorch_msssim import ms_ssim
6 | import math
7 | from skimage.metrics import structural_similarity as ssim
8 | def PSNR(img1, img2):
9 | mse = np.mean((img1 / 255. - img2 / 255.) ** 2)
10 | if mse < 1.0e-10:
11 | return 100
12 | PIXEL_MAX = 1
13 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse))
14 |
15 |
16 | def SSIM(img1, img2):
17 |
18 | return ssim(img1,img2,multichannel=True)
19 |
20 | def compute_metrics(
21 | a: Union[np.array, Image.Image],
22 | b: Union[np.array, Image.Image],
23 | max_val: float = 255.0,
24 | ) -> Tuple[float, float]:
25 | """Returns PSNR and MS-SSIM between images `a` and `b`. """
26 | if isinstance(a, Image.Image):
27 | a = np.asarray(a)
28 | if isinstance(b, Image.Image):
29 | b = np.asarray(b)
30 |
31 | a = torch.from_numpy(a.copy()).float().unsqueeze(0)
32 | if a.size(3) == 3:
33 | a = a.permute(0, 3, 1, 2)
34 | b = torch.from_numpy(b.copy()).float().unsqueeze(0)
35 | if b.size(3) == 3:
36 | b = b.permute(0, 3, 1, 2)
37 |
38 | mse = torch.mean((a - b) ** 2).item()
39 | p = 20 * np.log10(max_val) - 10 * np.log10(mse)
40 | m = ms_ssim(a, b, data_range=max_val).item()
41 | return p, m
42 |
43 | def compute_metrics2(
44 | a: Union[np.array, Image.Image],
45 | b: Union[np.array, Image.Image],
46 | max_val: float = 255.0,
47 | ) -> Tuple[float, float]:
48 |
49 |
50 | p = PSNR(np.array(b), np.array(a))
51 | m = SSIM(np.array(b), np.array(a))
52 | return p, m
--------------------------------------------------------------------------------
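A small usage sketch of `compute_metrics` on dummy data; the arrays and values here are illustrative only. Note that `ms_ssim` with its default five scales requires inputs larger than 160 pixels per side, hence the 192x192 images.

```python
import numpy as np
from PIL import Image
from utils.metrics import compute_metrics

# Two dummy 8-bit RGB images that differ slightly.
rng = np.random.default_rng(0)
a = rng.integers(0, 256, size=(192, 192, 3), dtype=np.uint8)
noise = rng.integers(-5, 6, size=a.shape)
b = np.clip(a.astype(np.int16) + noise, 0, 255).astype(np.uint8)

psnr, msssim = compute_metrics(Image.fromarray(a), Image.fromarray(b), max_val=255.0)
print(f"PSNR: {psnr:.2f} dB, MS-SSIM: {msssim:.4f}")
```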
/modules/layers/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.init import trunc_normal_
4 | from torch.nn import functional as F
5 | from timm.models.layers import to_2tuple
6 | from modules.layers.res_blk import ResidualBottleneck
7 |
8 | class MLP(nn.Module):
9 |
10 | def __init__(self, in_dim, hidden_dim=None, out_dim=None, act_layer=nn.GELU, drop=0.):
11 | super().__init__()
12 | out_dim = out_dim or in_dim
13 | hidden_dim = hidden_dim or in_dim
14 | self.fc1 = nn.Linear(in_dim, hidden_dim)
15 | self.act = act_layer()
16 | self.fc2 = nn.Linear(hidden_dim, out_dim)
17 | self.drop = nn.Dropout(drop)
18 |
19 | def forward(self, x):
20 | x = self.fc1(x)
21 | x = self.act(x)
22 | x = self.drop(x)
23 | x = self.fc2(x)
24 | x = self.drop(x)
25 | return x
26 |
27 |
28 | def build_position_index(window_size):
29 | coords_h = torch.arange(window_size[0])
30 | coords_w = torch.arange(window_size[1])
31 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
32 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
33 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
34 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
35 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
36 | relative_coords[:, :, 1] += window_size[1] - 1
37 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1
38 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
39 | return relative_position_index
--------------------------------------------------------------------------------
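A quick sanity check of `build_position_index` (a Swin-style relative-position lookup table); the printed values below follow from the arithmetic in the function and are given for illustration, not taken from repository output.

```python
from modules.layers.attention import build_position_index

# For a 2x2 window there are 4 tokens, so the lookup index is a 4x4 matrix
# whose entries cover the (2*2-1) * (2*2-1) = 9 possible relative offsets.
idx = build_position_index((2, 2))
print(idx.shape)   # torch.Size([4, 4])
print(idx)         # diagonal entries are 4 (zero offset); values lie in [0, 8]
```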
/config/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | def test_options():
5 | parser = argparse.ArgumentParser(description="Testing script.")
6 | parser.add_argument(
7 | "-exp",
8 | "--experiment",
9 | default="test",
10 | type=str,
11 | required=False,
12 | help="Experiment name"
13 | )
14 | parser.add_argument(
15 | "-d",
16 | "--dataset",
17 | default="/home/npr/dataset/",
18 | type=str,
19 | required=False,
20 | help="Training dataset"
21 | )
22 | parser.add_argument(
23 | "-n",
24 | "--num-workers",
25 | type=int,
26 | default=1,
27 | help="Dataloaders threads (default: %(default)s)",
28 | )
29 | parser.add_argument(
30 | "--metrics",
31 | type=str,
32 | default="mse",
33 | help="Optimized for (default: %(default)s)",
34 | )
35 | parser.add_argument(
36 | "--test-batch-size",
37 | type=int,
38 | default=1,
39 | help="Test batch size (default: %(default)s)",
40 | )
41 | parser.add_argument(
42 | "--gpu_id",
43 | type=int,
44 | default=3,
45 | help="GPU ID"
46 | )
47 | parser.add_argument(
48 | "--cuda",
49 | default=True,
50 | help="Use cuda"
51 | )
52 | parser.add_argument(
53 | "--save",
54 | default=True,
55 | help="Save model to disk"
56 | )
57 | parser.add_argument(
58 | "-c",
59 | "--checkpoint",
60 | default=None,
61 | type=str,
62 | help="pretrained model path"
63 | )
64 | args = parser.parse_args()
65 | return args
66 |
--------------------------------------------------------------------------------
/classification/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.transforms as transforms
3 | import torchvision.datasets as datasets
4 | import torchvision.models as models
5 | from torchvision.models import ResNet101_Weights
6 | from torch.utils.data import DataLoader
7 | import torch.nn.functional as F
8 | import sys
9 | import os
10 | from tqdm import tqdm
11 | # Pretrained ResNet101
12 | net = models.resnet101(weights=None, progress=True)
13 | checkpoint_path = 'path/to/resnet101-cd907fc2.pth'
14 | checkpoint = torch.load(checkpoint_path)
15 |
16 | net.load_state_dict(checkpoint)
17 |
18 |
19 | net.eval()
20 |
21 | # Data
22 | preprocess = transforms.Compose([
23 | transforms.Resize(256),
24 | transforms.CenterCrop(224),
25 | transforms.ToTensor(),
26 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
27 | ])
28 |
29 | # load ImageNet val
30 | imagenet_val_dataset = datasets.ImageFolder(root='path/to/val', transform=preprocess)
31 | val_loader = DataLoader(imagenet_val_dataset, batch_size=32, shuffle=False, num_workers=4)
32 |
33 |
34 | def evaluate_model(model, data_loader):
35 | model.eval()
36 | correct = 0
37 | total = 0
38 | with torch.no_grad():
39 | for images, labels in tqdm(data_loader):
40 | images = images.to('cuda')
41 | labels = labels.to('cuda')
42 | outputs = model(images)
43 | _, predicted = torch.max(outputs.data, 1)
44 | total += labels.size(0)
45 | correct += (predicted == labels).sum().item()
46 | # if predicted!=labels:
47 | # print(total%25)
48 | # print(labels)
49 | accuracy = 100 * correct / total
50 | return accuracy
51 |
52 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
53 | net.to(device)
54 |
55 | accuracy = evaluate_model(net, val_loader)
56 | print(f'Accuracy of the model on the ImageNet validation set: {accuracy:.2f}%')
57 |
--------------------------------------------------------------------------------
/utils/optimizers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 |
5 |
6 | def configure_optimizers(net, args):
7 | """Separate parameters for the main optimizer and the auxiliary optimizer.
8 | Return two optimizers"""
9 |
10 | parameters = [
11 | p for n, p in net.named_parameters() if not n.endswith(".quantiles")
12 | ]
13 | aux_parameters = [
14 | p for n, p in net.named_parameters() if n.endswith(".quantiles")
15 | ]
16 | # Make sure we don't have an intersection of parameters
17 | params_dict = dict(net.named_parameters())
18 | inter_params = set(parameters) & set(aux_parameters)
19 | union_params = set(parameters) | set(aux_parameters)
20 |
21 | assert len(inter_params) == 0
22 | assert len(union_params) - len(params_dict.keys()) == 0
23 |
24 | optimizer = optim.Adam(
25 | (p for p in parameters if p.requires_grad),
26 | lr=args.learning_rate,
27 | )
28 | aux_optimizer = optim.Adam(
29 | (p for p in aux_parameters if p.requires_grad),
30 | lr=args.aux_learning_rate,
31 | )
32 | return optimizer, aux_optimizer
33 | def configure_optimizer(net, args):
34 | """Separate parameters for the main optimizer and the auxiliary optimizer.
35 | Return two optimizers"""
36 |
37 | parameters = [
38 | p for n, p in net.named_parameters() if not n.endswith(".quantiles")
39 | ]
40 | aux_parameters = [
41 | p for n, p in net.named_parameters() if n.endswith(".quantiles")
42 | ]
43 | # Make sure we don't have an intersection of parameters
44 | params_dict = dict(net.named_parameters())
45 | inter_params = set(parameters) & set(aux_parameters)
46 | union_params = set(parameters) | set(aux_parameters)
47 |
48 | assert len(inter_params) == 0
49 | assert len(union_params) - len(params_dict.keys()) == 0
50 |
51 | optimizer = optim.Adam(
52 | (p for p in parameters if p.requires_grad),
53 | lr=args.learning_rate,
54 | )
55 |
56 | return optimizer
57 |
--------------------------------------------------------------------------------
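A minimal sketch of `configure_optimizers`; the toy modules and the argparse namespace are hypothetical stand-ins for a CompressAI-style model (whose entropy bottleneck registers `.quantiles` parameters) and the parsed training arguments.

```python
import argparse

import torch
import torch.nn as nn

from utils.optimizers import configure_optimizers


class ToyBottleneck(nn.Module):
    def __init__(self):
        super().__init__()
        # mimics CompressAI's EntropyBottleneck, which registers `.quantiles`
        self.quantiles = nn.Parameter(torch.zeros(8, 1, 3))


class ToyCodec(nn.Module):
    def __init__(self):
        super().__init__()
        self.g_a = nn.Conv2d(3, 8, 3)
        self.entropy_bottleneck = ToyBottleneck()


args = argparse.Namespace(learning_rate=1e-4, aux_learning_rate=1e-3)
optimizer, aux_optimizer = configure_optimizers(ToyCodec(), args)
print(len(optimizer.param_groups[0]["params"]))      # 2: conv weight and bias
print(len(aux_optimizer.param_groups[0]["params"]))  # 1: the quantiles tensor
```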
/instance segmantation/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import random
5 | import detectron2
6 | from detectron2.engine import DefaultPredictor, DefaultTrainer
7 | from detectron2.config import get_cfg
8 | from detectron2.data import DatasetCatalog, MetadataCatalog
9 | from detectron2.data.datasets import register_pascal_voc
10 | from detectron2.utils.visualizer import Visualizer
11 | from detectron2.evaluation import COCOEvaluator, DatasetEvaluators, inference_on_dataset
12 | from detectron2.data import build_detection_test_loader
13 | from detectron2.utils.logger import setup_logger
14 | from detectron2.data.datasets import register_coco_instances
15 | # Initialize the logger
16 | # from detectron2.model_zoo import get_config_file,
17 | from detectron2 import model_zoo
18 | from PIL import ImageFile
19 | import numpy as np
20 | # ImageFile.LOAD_TRUNCATED_IMAGES = True
21 |
22 | setup_logger()
23 |
24 | dataset_name = "coco_8k_hy"
25 |
26 | json_file = "Coco/coco_2017.json"
27 |
28 | image_dir = "path/to/images"
29 |
30 | register_coco_instances(dataset_name, {}, json_file, image_dir)
31 |
32 | # Get dataset metadata and dataset dictionaries
33 | metadata = MetadataCatalog.get(dataset_name)
34 | dataset_dicts = DatasetCatalog.get(dataset_name)
35 |
36 | # Configure and load the pre-trained mask_rcnn model
37 | cfg = get_cfg()
38 |
39 | print("===============load model==============")
40 | cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml"))
41 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")
42 | print("==========complete==================")
43 |
44 | cfg.MODEL.DEVICE = "cuda" #
45 | cfg.DATASETS.TEST = (dataset_name, )
46 | cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(metadata.thing_classes)
47 |
48 | predictor = DefaultPredictor(cfg)
49 |
50 | evaluator = COCOEvaluator(dataset_name, cfg, False, output_dir="./mask/")
51 | val_loader = build_detection_test_loader(cfg, dataset_name)
52 |
53 | output_dir = "./mask/"
54 | os.makedirs(output_dir, exist_ok=True)
55 |
56 | # Perform inference and evaluation
57 | print("Running inference...")
58 | results = inference_on_dataset(predictor.model, val_loader, evaluator)
59 |
--------------------------------------------------------------------------------
/modules/transform/entropy.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from modules.layers.attention import MLP
5 |
6 |
7 | class EntropyParameters(nn.Module):
8 | def __init__(self, in_dim, out_dim, act=nn.GELU) -> None:
9 | super().__init__()
10 | self.fusion = nn.Sequential(
11 | nn.Conv2d(in_dim, 320, kernel_size=1, stride=1, padding=0),
12 | act(),
13 | nn.Conv2d(320, 256, kernel_size=1, stride=1, padding=0),
14 | act(),
15 | nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0),
16 | act(),
17 | nn.Conv2d(128, out_dim, kernel_size=1, stride=1, padding=0),
18 | )
19 |
20 | def forward(self, params):
21 | """
22 | Args:
23 | params(Tensor): [B, C * K, H, W]
24 | return:
25 | gaussian_params(Tensor): [B, C * 2, H, W]
26 | """
27 | gaussian_params = self.fusion(params)
28 |
29 | return gaussian_params
30 |
31 |
32 | class EntropyParametersEX(nn.Module):
33 | def __init__(self, in_dim, out_dim, act=nn.GELU) -> None:
34 | super().__init__()
35 | self.fusion = nn.Sequential(
36 | nn.Conv2d(in_dim, out_dim * 5 // 3, 1),
37 | act(),
38 | nn.Conv2d(out_dim * 5 // 3, out_dim * 4 // 3, 1),
39 | act(),
40 | nn.Conv2d(out_dim * 4 // 3, out_dim, 1),
41 | )
42 |
43 | def forward(self, params):
44 | """
45 | Args:
46 | params(Tensor): [B, C * K, H, W]
47 | return:
48 | gaussian_params(Tensor): [B, C * 2, H, W]
49 | """
50 | gaussian_params = self.fusion(params)
51 |
52 | return gaussian_params
53 |
54 |
55 | class ChannelWiseEntropyParameters(nn.Module):
56 | def __init__(self, in_channels=192, out_channels=192):
57 | super().__init__()
58 | diff = (in_channels - out_channels) // 3
59 | self.layers = nn.Sequential(
60 | nn.Conv2d(in_channels, in_channels - diff, 1),
61 | nn.LeakyReLU(inplace=True),
62 | nn.Conv2d(in_channels - diff, in_channels - 2 * diff, 1),
63 | nn.LeakyReLU(inplace=True),
64 | nn.Conv2d(in_channels - 2 * diff, out_channels, 1),
65 | )
66 |
67 | def forward(self, x):
68 | x = self.layers(x)
69 | return x
70 |
--------------------------------------------------------------------------------
/playground/test.py:
--------------------------------------------------------------------------------
1 | from re import T
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import os
6 | import logging
7 | import sys
8 |
9 |
10 |
11 | sys.path.append("..")
12 | from utils.train_data import TestData, testDataset_collate, NewTestData, TestImagenet
13 | from config.args import test_options
14 | from config.config import model_config
15 | from compressai.datasets import ImageFolder
16 | from torchvision import transforms
17 | from torch.utils.data import DataLoader
18 | from PIL import ImageFile, Image
19 | from models import *
20 | from utils.testing import test_model
21 | from utils.logger import setup_logger
22 | from utils.newtrain_data import TrainData
23 |
24 | def main():
25 | ImageFile.LOAD_TRUNCATED_IMAGES = True
26 | Image.MAX_IMAGE_PIXELS = None
27 |
28 | args = test_options()
29 | config = model_config()
30 |
31 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
32 |
33 | torch.backends.cudnn.deterministic = True
34 |
35 | if not os.path.exists(os.path.join('./experiments', args.experiment)):
36 | os.makedirs(os.path.join('./experiments', args.experiment))
37 | setup_logger('test', os.path.join('./experiments', args.experiment), 'test_' + args.experiment, level=logging.INFO,
38 | screen=True, tofile=True)
39 | logger_test = logging.getLogger('test')
40 |
41 | #
42 | # test_dataset = TestData("../voc_norm_test.txt")
43 | # test_dataset = TestImagenet("../../data/liuquan/Imagenet_val25k/val_set.txt")
44 | # test_dataset = NewTestData("../../data")
45 | test_dataset = NewTestData("../../data/Kodak24")
46 | # test_dataset = NewTestData("../../yolov3/coco/val2017")
47 | test_dataloader = DataLoader(
48 | test_dataset,
49 | batch_size=args.test_batch_size,
50 | num_workers=args.num_workers,
51 | shuffle=False,
52 | pin_memory=True,
53 | collate_fn=testDataset_collate,
54 | )
55 |
56 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
57 |
58 | net = MLICPlusPlus(config=config)
59 | net = net.to(device)
60 | checkpoint = torch.load(args.checkpoint)
61 | # new_ckpt = modify_checkpoint(checkpoint['state_dict'])
62 | net.load_state_dict(checkpoint['state_dict'])
63 | epoch = checkpoint["epoch"]
64 | logger_test.info(f"Start testing!" )
65 | save_dir = os.path.join('./experiments', args.experiment)
66 | if not os.path.exists(save_dir):
67 | os.makedirs(save_dir)
68 | test_model(net=net, test_dataloader=test_dataloader, logger_test=logger_test, save_dir=save_dir, epoch=epoch)
69 |
70 |
71 | if __name__ == '__main__':
72 | main()
73 |
74 |
--------------------------------------------------------------------------------
/tests/test.py:
--------------------------------------------------------------------------------
1 | from re import T
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import os
6 | import logging
7 | import sys
8 |
9 |
10 |
11 | sys.path.append("..")
12 | from utils.utils import save_checkpoint
13 | from utils.train_data import TestData, testDataset_collate
14 |
15 | from config.args import test_options
16 | from config.config import model_config
17 |
18 | from torch.utils.data import DataLoader
19 | from PIL import ImageFile, Image
20 | from models import *
21 | from utils.testing import test_model
22 | from utils.logger import setup_logger
23 |
24 |
25 | def main():
26 | ImageFile.LOAD_TRUNCATED_IMAGES = True
27 | Image.MAX_IMAGE_PIXELS = None
28 |
29 | args = test_options()
30 | config = model_config()
31 |
32 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
33 |
34 | torch.backends.cudnn.deterministic = True
35 |
36 | if not os.path.exists(os.path.join('./experiments', args.experiment)):
37 | os.makedirs(os.path.join('./experiments', args.experiment))
38 | setup_logger('test', os.path.join('./experiments', args.experiment), 'test_' + args.experiment, level=logging.INFO,
39 | screen=True, tofile=True)
40 | logger_test = logging.getLogger('test')
41 |
42 |
43 | test_dataset = TestData(args.dataset)
44 | test_dataloader = DataLoader(
45 | test_dataset,
46 | batch_size=args.test_batch_size,
47 | num_workers=args.num_workers,
48 | shuffle=False,
49 | pin_memory=True,
50 | collate_fn=testDataset_collate,
51 | )
52 |
53 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
54 |
55 | net = UG_MLICPlusPlus(config=config)
56 | net = net.to(device)
57 | checkpoint = torch.load(args.checkpoint)
58 | net.load_state_dict(checkpoint['state_dict'])
59 | logger_test.info(f"Start testing!" )
60 | save_dir = os.path.join('./experiments', args.experiment, "bitstream")
61 | save_dir_human = os.path.join('./experiments', args.experiment, "human")
62 | save_dir_machine = os.path.join('./experiments', args.experiment, "machine")
63 | if not os.path.exists(save_dir):
64 | os.makedirs(save_dir)
65 | if not os.path.exists(save_dir_human):
66 | os.makedirs(save_dir_human)
67 | if not os.path.exists(save_dir_machine):
68 | os.makedirs(save_dir_machine)
69 | test_model(net=net, test_dataloader=test_dataloader, logger_test=logger_test, save_dir=save_dir ,save_dir_human=save_dir_human ,save_dir_machine=save_dir_machine)
70 |
71 |
72 | if __name__ == '__main__':
73 | main()
74 |
75 |
--------------------------------------------------------------------------------
/playground/test_condi.py:
--------------------------------------------------------------------------------
1 | from re import T
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import os
6 | import logging
7 | import sys
8 |
9 |
10 |
11 | sys.path.append("..")
12 | from utils.train_data import TestData, testDataset_collate, NewTestData, TestImagenet
13 | from models.mlicpp_con import ConditionalMLICPlusPlus
14 | from config.args import test_options
15 | from config.config import model_config
16 | from compressai.datasets import ImageFolder
17 | from torchvision import transforms
18 | from torch.utils.data import DataLoader
19 | from PIL import ImageFile, Image
20 | from models import *
21 | from utils.testing import test_model
22 | from utils.logger import setup_logger
23 | from utils.newtrain_data import TrainData
24 |
25 | def main():
26 | ImageFile.LOAD_TRUNCATED_IMAGES = True
27 | Image.MAX_IMAGE_PIXELS = None
28 |
29 | args = test_options()
30 | config = model_config()
31 |
32 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
33 |
34 | torch.backends.cudnn.deterministic = True
35 |
36 | if not os.path.exists(os.path.join('./experiments', args.experiment)):
37 | os.makedirs(os.path.join('./experiments', args.experiment))
38 | setup_logger('test', os.path.join('./experiments', args.experiment), 'test_' + args.experiment, level=logging.INFO,
39 | screen=True, tofile=True)
40 | logger_test = logging.getLogger('test')
41 |
42 |
43 | # test_dataset = TestData("../voc_norm_test.txt")
44 | test_dataset = NewTestData("../../data/segment")
45 | # test_dataset = TestImagenet("../../data/liuquan/Imagenet_val25k/val_set.txt")
46 | # test_dataset = NewTestData("../../yolov3/coco/val2017")
47 | # test_dataset = NewTestData("../../data/Kodak24")
48 | test_dataloader = DataLoader(
49 | test_dataset,
50 | batch_size=args.test_batch_size,
51 | num_workers=args.num_workers,
52 | shuffle=False,
53 | pin_memory=True,
54 | collate_fn=testDataset_collate,
55 | )
56 |
57 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
58 |
59 | net = ConditionalMLICPlusPlus(config=config)
60 | net.beta= args.beta
61 | net = net.to(device)
62 | checkpoint = torch.load(args.checkpoint)
63 | # new_ckpt = modify_checkpoint(checkpoint['state_dict'])
64 | net.load_state_dict(checkpoint['state_dict'])
65 | epoch = checkpoint["epoch"]
66 | logger_test.info(f"Start testing!" )
67 | save_dir = os.path.join('./experiments', args.experiment)
68 | if not os.path.exists(save_dir):
69 | os.makedirs(save_dir)
70 | test_model(net=net, test_dataloader=test_dataloader, logger_test=logger_test, save_dir=save_dir, epoch=epoch)
71 |
72 |
73 | if __name__ == '__main__':
74 | main()
75 |
76 |
--------------------------------------------------------------------------------
/modules/transform/analysis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from compressai.layers import subpel_conv3x3
5 | from modules.layers.conv import conv1x1, conv3x3, conv, deconv
6 | from modules.layers.res_blk import *
7 |
8 |
9 | class AnalysisTransform(nn.Module):
10 | def __init__(self, N, M):
11 | super().__init__()
12 | self.analysis_transform = nn.Sequential(
13 | ResidualBlockWithStride(3, N, stride=2),
14 | ResidualBlock(N, N),
15 | ResidualBlockWithStride(N, N, stride=2),
16 | ResidualBlock(N, N),
17 | ResidualBlockWithStride(N, N, stride=2),
18 | ResidualBlock(N, N),
19 | conv3x3(N, M, stride=2)
20 | )
21 |
22 | def forward(self, x):
23 | x = self.analysis_transform(x)
24 |
25 | return x
26 |
27 |
28 | class HyperAnalysis(nn.Module):
29 | """
30 | Local reference
31 | """
32 | def __init__(self, M=192, N=192):
33 | super().__init__()
34 | self.M = M
35 | self.N = N
36 | self.reduction = nn.Sequential(
37 | conv3x3(M, N),
38 | nn.GELU(),
39 | conv3x3(N, N),
40 | nn.GELU(),
41 | conv3x3(N, N, stride=2),
42 | nn.GELU(),
43 | conv3x3(N, N),
44 | nn.GELU(),
45 | conv3x3(N, N, stride=2),
46 | )
47 |
48 | def forward(self, x):
49 | x = self.reduction(x)
50 |
51 | return x
52 |
53 | class AnalysisTransformEX(nn.Module):
54 | def __init__(self, N, M, act=nn.GELU):
55 | super().__init__()
56 | self.analysis_transform = nn.Sequential(
57 | conv(3, N),
58 | ResidualBottleneck(N, act=act, groups=N * 2),
59 | ResidualBottleneck(N, act=act, groups=N * 2),
60 | ResidualBottleneck(N, act=act, groups=N * 2),
61 | conv(N, N),
62 | ResidualBottleneck(N, act=act, groups=N * 2),
63 | ResidualBottleneck(N, act=act, groups=N * 2),
64 | ResidualBottleneck(N, act=act, groups=N * 2),
65 | AttentionBlock(N),
66 | conv(N, N),
67 | ResidualBottleneck(N, act=act, groups=N * 2),
68 | ResidualBottleneck(N, act=act, groups=N * 2),
69 | ResidualBottleneck(N, act=act, groups=N * 2),
70 | conv(N, M),
71 | AttentionBlock(M)
72 | )
73 |
74 | def forward(self, x):
75 | x = self.analysis_transform(x)
76 | return x
77 |
78 |
79 | class HyperAnalysisEX(nn.Module):
80 | def __init__(self, N, M, act=nn.GELU) -> None:
81 | super().__init__()
82 | self.M = M
83 | self.N = N
84 | self.reduction = nn.Sequential(
85 | conv3x3(M, N),
86 | act(),
87 | conv(N, N),
88 | act(),
89 | conv(N, N)
90 | )
91 |
92 | def forward(self, x):
93 | x = self.reduction(x)
94 | return x
--------------------------------------------------------------------------------
/models/fourier_cond.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class BetaMlp(nn.Module):
6 | def __init__(self, channels=512, act_layer=nn.ReLU):
7 | super(BetaMlp, self).__init__()
8 | self.output_features = channels
9 | self.fc1 = nn.Linear(21, self.output_features)
10 | self.act = act_layer()
11 | self.fc2 = nn.Linear(self.output_features, self.output_features)
12 |
13 | def forward(self, x):
14 |
15 | x = self.fc1(x)
16 | x = self.act(x)
17 | x = self.fc2(x)
18 | x = self.act(x)
19 | return x
20 |
21 | class Embedder:
22 | def __init__(self, **kwargs):
23 | self.kwargs = kwargs
24 | self.create_embedding_fn()
25 |
26 | def create_embedding_fn(self):
27 | embed_fns = []
28 | d = self.kwargs['input_dims']
29 | out_dim = 0
30 | if self.kwargs['include_input']:
31 | embed_fns.append(lambda x: x)
32 | out_dim += d
33 |
34 | max_freq = self.kwargs['max_freq_log2']
35 | N_freqs = self.kwargs['num_freqs']
36 |
37 | if self.kwargs['log_sampling']:
38 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
39 | else:
40 | freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs)
41 |
42 | for freq in freq_bands:
43 | for p_fn in self.kwargs['periodic_fns']:
44 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
45 | out_dim += d
46 |
47 | self.embed_fns = embed_fns
48 | self.out_dim = out_dim
49 |
50 | def embed(self, inputs):
51 | # Apply each function in self.embed_fns to inputs and stack the results
52 | stacked_outputs = torch.stack([fn(inputs) for fn in self.embed_fns])
53 |
54 | # Transpose the stacked tensor
55 | transposed_outputs = stacked_outputs.permute(1, 0)
56 | return transposed_outputs
57 |
58 | def get_embedder(multires, i=0):
59 | if i == -1:
60 | return lambda x: x, 3
61 |
62 | embed_kwargs = {
63 | 'include_input': True,
64 | 'input_dims': 1,
65 | 'max_freq_log2': multires - 1,
66 | 'num_freqs': multires,
67 | 'log_sampling': True,
68 | 'periodic_fns': [torch.sin, torch.cos],
69 | }
70 |
71 | embedder_obj = Embedder(**embed_kwargs)
72 |
73 | def embed(x, eo=embedder_obj): return eo.embed(x)
74 |
75 | return embed, embedder_obj.out_dim
76 |
77 | # def build_global_conditioning():
78 | # beta = torch.randn(1) # Dummy input, replace with actual input tensor
79 | #
80 | # embed_fn, input_ch = get_embedder(multires=10, i=0)
81 | # beta_mlp = BetaMlp()
82 | # fourier_features = embed_fn(beta)
83 | # fourier_features_mlp = beta_mlp(fourier_features)
84 | #
85 | # print(fourier_features_mlp)
86 | #
87 | # return beta_mlp # Return the model instance
88 |
89 | class GlobalConditioning(nn.Module):
90 | def __init__(self):
91 | super(GlobalConditioning, self).__init__()
92 | self.beta_mlp = BetaMlp()
93 | self.embed_fn, self.input_ch = get_embedder(multires=10, i=0) # Assuming get_embedder is defined elsewhere
94 |
95 | def forward(self, beta):
96 | fourier_features = self.embed_fn(beta)
97 | fourier_features_mlp = self.beta_mlp(fourier_features)
98 | return fourier_features_mlp
--------------------------------------------------------------------------------
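For clarity, the Fourier embedding with `multires=10` on a scalar input keeps the raw value and adds sin/cos at 10 frequencies, i.e. 1 + 10 x 2 = 21 features, which matches the `nn.Linear(21, channels)` input of `BetaMlp`. A small shape-check sketch (the `beta` value is an arbitrary example):

```python
import torch

from models.fourier_cond import GlobalConditioning, get_embedder

embed_fn, out_dim = get_embedder(multires=10, i=0)
print(out_dim)                       # 21

beta = torch.tensor([0.5])           # arbitrary example conditioning value
cond = GlobalConditioning()
print(embed_fn(beta).shape)          # torch.Size([1, 21])
print(cond(beta).shape)              # torch.Size([1, 512])
```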
/utils/utils.py:
--------------------------------------------------------------------------------
1 | import PIL.Image as Image
2 | import shutil
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import struct
7 | from pathlib import Path
8 | from torchvision.transforms import ToPILImage
9 | import json
10 |
11 |
12 | """ configuration json """
13 | class Config(dict):
14 | __getattr__ = dict.__getitem__
15 | __setattr__ = dict.__setitem__
16 |
17 | @classmethod
18 | def load(cls, file):
19 | with open(file, 'r') as f:
20 | config = json.loads(f.read())
21 | return Config(config)
22 |
23 |
24 | def write_uchars(fd, values, fmt=">{:d}B"):
25 | fd.write(struct.pack(fmt.format(len(values)), *values))
26 | return len(values) * 1
27 |
28 |
29 | def write_uints(fd, values, fmt=">{:d}I"):
30 | fd.write(struct.pack(fmt.format(len(values)), *values))
31 | return len(values) * 4
32 |
33 |
34 | def read_uints(fd, n, fmt=">{:d}I"):
35 | sz = struct.calcsize("I")
36 | return struct.unpack(fmt.format(n), fd.read(n * sz))
37 |
38 |
39 | def read_uchars(fd, n, fmt=">{:d}B"):
40 | sz = struct.calcsize("B")
41 | return struct.unpack(fmt.format(n), fd.read(n * sz))
42 |
43 |
44 | def write_bytes(fd, values, fmt=">{:d}s"):
45 | if len(values) == 0:
46 | return
47 | fd.write(struct.pack(fmt.format(len(values)), values))
48 | return len(values) * 1
49 |
50 |
51 | def read_bytes(fd, n, fmt=">{:d}s"):
52 | sz = struct.calcsize("s")
53 | return struct.unpack(fmt.format(n), fd.read(n * sz))[0]
54 |
55 |
56 | def read_body(fd):
57 | lstrings = []
58 | shape = read_uints(fd, 2)
59 | n_strings = read_uints(fd, 1)[0]
60 | for _ in range(n_strings):
61 | s = read_bytes(fd, read_uints(fd, 1)[0])
62 | lstrings.append([s])
63 |
64 | return lstrings, shape
65 |
66 |
67 | def write_body(fd, shape, out_strings):
68 | bytes_cnt = 0
69 | bytes_cnt = write_uints(fd, (shape[0], shape[1], len(out_strings)))
70 | for s in out_strings:
71 | bytes_cnt += write_uints(fd, (len(s[0]),))
72 | bytes_cnt += write_bytes(fd, s[0])
73 | return bytes_cnt
74 |
75 |
76 | def filesize(filepath: str) -> int:
77 | if not Path(filepath).is_file():
78 | raise ValueError(f'Invalid file "{filepath}".')
79 | return Path(filepath).stat().st_size
80 |
81 |
82 | def torch2img(x: torch.Tensor) -> Image.Image:
83 | return ToPILImage()(x.clamp_(0, 1).squeeze())
84 |
85 |
86 | class AverageMeter:
87 | """Compute running average."""
88 |
89 | def __init__(self):
90 | self.val = 0
91 | self.avg = 0
92 | self.sum = 0
93 | self.count = 0
94 |
95 | def update(self, val, n=1):
96 | self.val = val
97 | self.sum += val * n
98 | self.count += n
99 | self.avg = self.sum / self.count
100 |
101 |
102 | class CustomDataParallel(nn.DataParallel):
103 | """Custom DataParallel to access the module methods."""
104 |
105 | def __getattr__(self, key):
106 | try:
107 | return super().__getattr__(key)
108 | except AttributeError:
109 | return getattr(self.module, key)
110 |
111 |
112 | def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"):
113 | torch.save(state, filename)
114 | if is_best:
115 | best_filename = filename.replace(filename.split('/')[-1], "checkpoint_best_loss.pth.tar")
116 | shutil.copyfile(filename, best_filename)
117 |
118 |
119 | def split_data(source_path, destination_path, train_file):
120 |     # Iterate over the listed file names and move each image to the destination.
121 |     with open(train_file, encoding='utf-8') as f:
122 |         for line in f:
123 |             line = line.strip('\n')
124 |             print(line)
125 |             img_path = source_path + line
126 |             shutil.move(img_path, destination_path + line)
127 | 
--------------------------------------------------------------------------------
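The bitstream helpers above store the latent shape and the number of strings as big-endian uint32s, then each entropy-coded string prefixed by its byte length. A round-trip sketch with dummy byte strings (not real compressed data):

```python
import io

from utils.utils import read_body, write_body

buf = io.BytesIO()
n = write_body(buf, shape=(16, 24), out_strings=[[b"\x01\x02\x03"], [b"\xff"]])
print(n)            # 3 * 4 + (4 + 3) + (4 + 1) = 24 bytes

buf.seek(0)
strings, shape = read_body(buf)
print(shape)        # (16, 24)
print(strings)      # [[b'\x01\x02\x03'], [b'\xff']]
```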
/modules/transform/synthesis.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from compressai.layers import subpel_conv3x3, AttentionBlock
5 | from modules.layers.conv import conv1x1, conv3x3, conv, deconv
6 | from modules.layers.res_blk import *
7 |
8 |
9 | class HyperSynthesis(nn.Module):
10 | """
11 | Local Reference
12 | """
13 | def __init__(self, M=192, N=192) -> None:
14 | super().__init__()
15 | self.M = M
16 | self.N = N
17 |
18 | self.increase = nn.Sequential(
19 | conv3x3(N, M),
20 | nn.GELU(),
21 | subpel_conv3x3(M, M, 2),
22 | nn.GELU(),
23 | conv3x3(M, M * 3 // 2),
24 | nn.GELU(),
25 | subpel_conv3x3(M * 3 // 2, M * 3 // 2, 2),
26 | nn.GELU(),
27 | conv3x3(M * 3 // 2, M * 2),
28 | )
29 |
30 | def forward(self, x):
31 | x = self.increase(x)
32 |
33 | return x
34 |
35 |
36 | class SynthesisTransform(nn.Module):
37 | def __init__(self, N, M):
38 | super().__init__()
39 | self.synthesis_transform = nn.Sequential(
40 | ResidualBlock(M, N),
41 | ResidualBlockUpsample(N, N, 2),
42 | ResidualBlock(N, N),
43 | ResidualBlockUpsample(N, N, 2),
44 | ResidualBlock(N, N),
45 | ResidualBlockUpsample(N, N, 2),
46 | ResidualBlock(N, N),
47 | subpel_conv3x3(N, 3, 2),
48 | )
49 |
50 | def forward(self, x):
51 | x = self.synthesis_transform(x)
52 |
53 | return x
54 |
55 | class ConditionalSynthesisTransform(nn.Module):
56 | def __init__(self, N, M):
57 | super().__init__()
58 | self.synthesis_transform = nn.Sequential(
59 | PCDM(M, N),
60 | ResidualBlockUpsample(N, N, 2),
61 | PCDM(N, N),
62 | ResidualBlockUpsample(N, N, 2),
63 | PCDM(N, N),
64 | ResidualBlockUpsample(N, N, 2),
65 | PCDM(N, N),
66 | subpel_conv3x3(N, 3, 2),
67 | )
68 |
69 | def forward(self, x):
70 | x, fourier_features_mlp = x
71 | for block in self.synthesis_transform:
72 | if isinstance(block, PCDM):
73 | x = block((x, fourier_features_mlp))
74 | else:
75 | x = block(x)
76 |
77 | return x
78 |
79 |
80 | class SynthesisTransformEX(nn.Module):
81 | def __init__(self, N, M, act=nn.GELU) -> None:
82 | super().__init__()
83 | self.synthesis_transform = nn.Sequential(
84 | AttentionBlock(M),
85 | deconv(M, N),
86 | ResidualBottleneck(N, act=act, groups=N * 2),
87 | ResidualBottleneck(N, act=act, groups=N * 2),
88 | ResidualBottleneck(N, act=act, groups=N * 2),
89 | deconv(N, N),
90 | AttentionBlock(N),
91 | ResidualBottleneck(N, act=act, groups=N * 2),
92 | ResidualBottleneck(N, act=act, groups=N * 2),
93 | ResidualBottleneck(N, act=act),
94 | deconv(N, N),
95 | ResidualBottleneck(N, act=act, groups=N * 2),
96 | ResidualBottleneck(N, act=act, groups=N * 2),
97 | ResidualBottleneck(N, act=act, groups=N * 2),
98 | deconv(N, 3)
99 | )
100 |
101 | def forward(self, x):
102 | x = self.synthesis_transform(x)
103 | return x
104 |
105 |
106 | class HyperSynthesisEX(nn.Module):
107 | def __init__(self, N, M, act=nn.GELU) -> None:
108 | super().__init__()
109 | self.increase = nn.Sequential(
110 | deconv(N, M),
111 | act(),
112 | deconv(M, M * 3 // 2),
113 | act(),
114 | deconv(M * 3 // 2, M * 2, kernel_size=3, stride=1),
115 | )
116 |
117 | def forward(self, x):
118 | x = self.increase(x)
119 | return x
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Unified Coding for Both Human Perception and Generalized Machine Analytics with CLIP Supervision (UG-ICM)
2 | This repo contains the official PyTorch implementation for the paper “Unified Coding for Both Human Perception and Generalized Machine Analytics with CLIP Supervision”.
3 |
4 | ## Updates
5 |
6 | #### 2025/11/17
7 | The training script has been open-sourced.
8 |
9 | #### 2024/12/13
10 | The training script is about to be open-sourced.
11 |
12 | #### 2024/12/10
13 | *Unified Coding for Both Human Perception and Generalized Machine Analytics with CLIP Supervision (UG-ICM)* has been accepted at **AAAI 2025**!
14 |
15 |
16 | ## Abstract
17 | Image compression models have long struggled with adaptability and generalization, as the decoded bitstream typically serves only human or machine needs and fails to preserve information for unseen visual tasks. Therefore, this paper introduces supervision obtained from multimodal pre-training models and incorporates adaptive multi-objective optimization to support both human visual perception and machine vision simultaneously with a single bitstream, denoted as Unified and Generalized Image Coding for Machine (UG-ICM). Specifically, to remove the reliance of compression models on downstream task supervision, we introduce Contrastive Language-Image Pre-training (CLIP) models into the training constraint for improved generalization. Global-to-instance-wise CLIP supervision is applied to obtain hierarchical semantics that make the model more generalizable for tasks relying on information of different granularity. Furthermore, to support both human and machine vision with a single unified bitstream, we incorporate a conditional decoding strategy that takes human or machine preferences as conditions, enabling the bitstream to be decoded into different versions for the corresponding preference. As such, the proposed UG-ICM is trained in a fully self-supervised manner, i.e., without awareness of any specific downstream models or tasks. Extensive experiments show that the proposed UG-ICM achieves remarkable improvements on various unseen machine analytics tasks while simultaneously providing perceptually satisfying images.
18 |
19 | 
20 |
21 | 
22 |
23 | ## Environment
24 |
25 | Python 3.7.16
26 | 
27 | CompressAI 1.2.0b3
28 | 
29 | PyTorch 1.13.0
30 |
31 | ## Weights
32 |
33 |
34 |
35 |
36 | | | Link |
37 | |:--------:|:--------:|
38 | | UG-ICM| [BaiDu Drive](https://pan.baidu.com/s/1bEcDDbiSIPj67yVrUt6tGQ?pwd=5doz) |
39 |
40 |
41 |
42 | ## Training:
43 | Stage One
44 | ```bash
45 | cd ./playground && python train_condi.py --metrics mse --exp mlicpp_condi_q1 --gpu_id 0 --lambda 0.0022 --lambda_beta1 0 --lambda_clip 22 -lr 1e-5 --seed 2000 --batch-size 128 --test-batch-size 128 --tune -e 20
46 | ```
47 | Stage Two
48 | ```bash
49 | cd ./playground && python train_mask.py --metrics mse --exp mlicpp_condi_q4_tune --gpu_id 0 --lambda 0.001 --lambda_beta1 1 --lambda_clip 1 -lr 1e-5 --seed 2000 --batch-size 128 --test-batch-size 128 --tune -e 20 -c experiments/mlicpp_condi_q4_2/checkpoints/checkpoint_best_loss.pth.tar
50 | ```
51 | ## Testing:
52 |
53 | ### Compress:
54 | Encode and compress the test images, then decode the unified bitstreams into both the human-oriented and the machine-oriented reconstructions (`tests/test.sh` shows the same script invoked with `--beta 0` for the human-oriented decode and `--beta 1` for the machine-oriented one).
55 | ```bash
56 | cd tests
57 |
58 | python test.py -exp test --gpu_id 0 -c /path/to/checkpoint -d /path/to/dataset
59 | ```
60 |
61 | ### Classification:
62 | #### Dataset: ImageNet-1k
63 |
64 | Evaluate the decoded images on the classification task (modify the decoded image path in `test.py`).
65 | ```bash
66 | cd classification
67 |
68 | python test.py
69 | ```
70 |
71 | ### Instance segmentation:
72 | #### Dataset: COCO2017_val
73 |
74 | Evaluate the decoded images on the instance segmentation task (modify the decoded image path in `test.py`).
75 | ```bash
76 | cd "instance segmantation"
77 |
78 | python test.py
79 | ```
80 |
81 | ## License
82 |
83 | [MIT License](https://opensource.org/licenses/MIT)
84 |
85 | ## Acknowledgments
86 | Thanks to [CompressAI](https://github.com/InterDigitalInc/CompressAI), [MLIC](https://github.com/JiangWeibeta/MLIC), [CLIP](https://github.com/openai/CLIP), and [TransTIC](https://github.com/NYCU-MAPL/TransTIC) for their public code and released models.
87 |
88 |
89 |
--------------------------------------------------------------------------------
/utils/func.py:
--------------------------------------------------------------------------------
1 | # From CompressAI
2 | # Modified by Wei Jiang
3 |
4 | import torch
5 | import torch.nn as nn
6 | import math
7 | import einops
8 |
9 | def calc_params(model):
10 | """
11 | Calculate the number of parameters in a model
12 | """
13 | return sum(p.numel() for p in model.parameters())
14 |
15 |
16 | def get_scale_table(
17 | min=0.11, max=256, levels=64
18 | ): # pylint: disable=W0622
19 | return torch.exp(torch.linspace(math.log(min), math.log(max), levels))  # Why take ln first and then exponentiate? Is it for higher precision?
20 |
21 |
22 | def find_named_module(module, query):
23 | """Helper function to find a named module. Returns a `nn.Module` or `None`
24 |
25 | Args:
26 | module (nn.Module): the root module
27 | query (str): the module name to find
28 |
29 | Returns:
30 | nn.Module or None
31 | """
32 |
33 | return next((m for n, m in module.named_modules() if n == query), None)
34 |
35 |
36 | def find_named_buffer(module, query):
37 | """Helper function to find a named buffer. Returns a `torch.Tensor` or `None`
38 |
39 | Args:
40 | module (nn.Module): the root module
41 | query (str): the buffer name to find
42 |
43 | Returns:
44 | torch.Tensor or None
45 | """
46 | return next((b for n, b in module.named_buffers() if n == query), None)
47 |
48 |
49 | def _update_registered_buffer(
50 | module,
51 | buffer_name,
52 | state_dict_key,
53 | state_dict,
54 | policy="resize_if_empty",
55 | dtype=torch.int,
56 | ):
57 | new_size = state_dict[state_dict_key].size()
58 | registered_buf = find_named_buffer(module, buffer_name)
59 |
60 | if policy in ("resize_if_empty", "resize"):
61 | if registered_buf is None:
62 | raise RuntimeError(f'buffer "{buffer_name}" was not registered')
63 |
64 | if policy == "resize" or registered_buf.numel() == 0:
65 | registered_buf.resize_(new_size)
66 |
67 | elif policy == "register":
68 | if registered_buf is not None:
69 | raise RuntimeError(f'buffer "{buffer_name}" was already registered')
70 |
71 | module.register_buffer(buffer_name, torch.empty(new_size, dtype=dtype).fill_(0))
72 |
73 | else:
74 | raise ValueError(f'Invalid policy "{policy}"')
75 |
76 |
77 | def update_registered_buffers(
78 | module,
79 | module_name,
80 | buffer_names,
81 | state_dict,
82 | policy="resize_if_empty",
83 | dtype=torch.int,
84 | ):
85 | """Update the registered buffers in a module according to the tensors sized
86 | in a state_dict.
87 |
88 | (There's no way in torch to directly load a buffer with a dynamic size)
89 |
90 | Args:
91 | module (nn.Module): the module
92 | module_name (str): module name in the state dict
93 | buffer_names (list(str)): list of the buffer names to resize in the module
94 | state_dict (dict): the state dict
95 | policy (str): Update policy, choose from
96 | ('resize_if_empty', 'resize', 'register')
97 | dtype (dtype): Type of buffer to be registered (when policy is 'register')
98 | """
99 | valid_buffer_names = [n for n, _ in module.named_buffers()]
100 | for buffer_name in buffer_names:
101 | if buffer_name not in valid_buffer_names:
102 | raise ValueError(f'Invalid buffer name "{buffer_name}"')
103 |
104 | for buffer_name in buffer_names:
105 | _update_registered_buffer(
106 | module,
107 | buffer_name,
108 | f"{module_name}.{buffer_name}",
109 | state_dict,
110 | policy,
111 | dtype,
112 | )
113 |
114 | def cal_params(model):
115 | """
116 | Calculate the number of parameters in a model
117 | """
118 | return sum(p.numel() for p in model.parameters())
119 |
120 |
121 |
122 | def image2patch(x, patch_size):
123 | """Image to patches."""
124 | batch, channels, height, width = x.shape
125 | grid_height = height // patch_size[0]
126 | grid_width = width // patch_size[1]
127 | x = einops.rearrange(
128 | x, "n c (gh fh) (gw fw) -> n c (gh gw) (fh fw)",
129 | gh=grid_height, gw=grid_width, fh=patch_size[0], fw=patch_size[1])
130 | return x
131 |
132 |
133 | def patch2image(x, grid_size, patch_size):
134 | """patches to images."""
135 | x = einops.rearrange(
136 | x, "n c (gh gw) (fh fw) -> n c (gh fh) (gw fw)",
137 | gh=grid_size[0], gw=grid_size[1], fh=patch_size[0], fw=patch_size[1])
138 | return x
139 |
--------------------------------------------------------------------------------
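A shape-check sketch for `image2patch`/`patch2image`; the tensor sizes are arbitrary examples chosen so the image divides evenly into patches.

```python
import torch

from utils.func import image2patch, patch2image

# A 1x3x64x48 image split into 16x16 patches yields a 4x3 grid of 12 patches,
# each flattened to 256 values; patch2image inverts the rearrangement exactly.
x = torch.randn(1, 3, 64, 48)
patches = image2patch(x, patch_size=(16, 16))
print(patches.shape)                                   # torch.Size([1, 3, 12, 256])

x_rec = patch2image(patches, grid_size=(4, 3), patch_size=(16, 16))
print(torch.equal(x, x_rec))                           # True
```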
/utils/newtrain_data.py:
--------------------------------------------------------------------------------
1 | """
2 | paper: GridDehazeNet: Attention-Based Multi-Scale Network for Image Dehazing
3 | file: train_data.py
4 | about: build the training dataset
5 | author: Xiaohong Liu
6 | date: 01/08/19
7 | """
8 |
9 | # --- Imports --- #
10 | import gzip
11 | import random
12 |
13 | import numpy as np
14 | import torch
15 | import torch.utils.data as data
16 | from PIL import Image
17 | import torch.nn.functional as F
18 |
19 | from torchvision.transforms import Compose, ToTensor, Resize, RandomCrop
20 |
21 |
22 | def compute_padding(in_h: int, in_w: int, *, out_h=None, out_w=None, min_div=1):
23 | """Returns tuples for padding and unpadding.
24 |
25 | Args:
26 | in_h: Input height.
27 | in_w: Input width.
28 | out_h: Output height.
29 | out_w: Output width.
30 | min_div: Length that output dimensions should be divisible by.
31 | """
32 | if out_h is None:
33 | out_h = (in_h + min_div - 1) // min_div * min_div
34 | if out_w is None:
35 | out_w = (in_w + min_div - 1) // min_div * min_div
36 |
37 | if out_h % min_div != 0 or out_w % min_div != 0:
38 | raise ValueError(
39 | f"Padded output height and width are not divisible by min_div={min_div}."
40 | )
41 |
42 | left = (out_w - in_w) // 2
43 | right = out_w - in_w - left
44 | top = (out_h - in_h) // 2
45 | bottom = out_h - in_h - top
46 |
47 | pad = (left, right, top, bottom)
48 | unpad = (-left, -right, -top, -bottom)
49 |
50 | return pad, unpad
51 |
52 | def pad(x, p=2**6):
53 | h, w = x.size(1), x.size(2)
54 | pad, _ = compute_padding(h, w, min_div=p)
55 | return F.pad(x, pad, mode="constant", value=0)
56 |
57 |
58 | # --- Training dataset --- #
59 | class TrainData(data.Dataset):
60 | def __init__(self, train_data_dir):
61 | super().__init__()
62 | train_list = train_data_dir
63 | with open(train_list) as f:
64 | txt = f.readlines()
65 | annotations = [line.strip() for line in txt]
66 |
67 | self.annotations = annotations
68 |
69 | def get_images(self, index):
70 | seed = torch.random.seed()
71 |
72 |
73 | img_name = self.annotations[index]
74 |
75 |
76 | image = Image.open("../../../coco2voc/coco2voc/" + 'JPEGImages/' + str(img_name) + '.jpg')
77 | # --- Transform to tensor --- #
78 | transform1 = Compose([RandomCrop((256, 256)), ToTensor()])
79 | transform2 = Compose([Resize((256, 256)), ToTensor()])
80 | image_width, image_height = image.size
81 | if image_width >= 256 and image_height >= 256:
82 | torch.random.manual_seed(seed)
83 | image = transform1(image)
84 | else:
85 | image = transform2(image)
86 |
87 | return image
88 |
89 | def __getitem__(self, index):
90 | res = self.get_images(index)
91 | return res
92 |
93 | def __len__(self):
94 | return len(self.annotations)
95 |
96 | class TrainDataWithMask(data.Dataset):
97 | def __init__(self, train_data_dir):
98 | super().__init__()
99 | train_list = train_data_dir
100 | with open(train_list) as f:
101 | txt = f.readlines()
102 | annotations = [line.strip() for line in txt]
103 |
104 | self.annotations = annotations
105 |
106 | def get_images(self, index):
107 | seed = torch.random.seed()
108 |
109 | img_name = self.annotations[index]
110 |
111 | image1 = Image.open("../../../coco2voc/coco2voc/" + 'JPEGImages/' + str(img_name) + '.jpg')
112 | image2 = image1.copy()
113 | # --- Transform to tensor --- #
114 | transform1 = Compose([RandomCrop((256, 256)), ToTensor()])
115 | transform2 = Compose([Resize((256, 256)), ToTensor()])
116 | image_width, image_height = image1.size
117 | if image_width >= 256 and image_height >= 256:
118 | torch.random.manual_seed(seed)
119 | image1 = transform1(image1)
120 | else:
121 | image1 = transform2(image1)
122 | image2 = transform2(image2)
123 | with gzip.open("../../cocomask/cocomask/" + str(img_name) + '.pth.gz', 'rb') as f:
124 | MaskList = torch.load(f)
125 |
126 | return image1, image2, MaskList
127 |
128 |
129 |
130 | def __getitem__(self, index):
131 | res = self.get_images(index)
132 | return res
133 |
134 | def __len__(self):
135 | return len(self.annotations)
136 |
137 |
138 | # Used as the collate_fn for the DataLoader
139 | def dataset_collate(batch):
140 |
141 |     imagesets = []
142 |     for imageset in batch:
143 |         imagesets.append(imageset)
144 |     imagesets = torch.from_numpy(np.array([item.numpy() for item in imagesets])).type(torch.FloatTensor)
145 |     # print(imagesets.shape)
146 |     return imagesets
147 |
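148 |
149 | # --- Usage sketch (not part of the original module) --- #
150 | # A minimal example of compute_padding / pad on a dummy tensor: the image is
151 | # centre-padded up to a multiple of 64, and the returned `unpad` offsets crop it back.
152 | if __name__ == "__main__":
153 |     _img = torch.rand(3, 500, 333)
154 |     _padded = pad(_img, p=2 ** 6)                      # -> (3, 512, 384)
155 |     _, _unpad = compute_padding(500, 333, min_div=2 ** 6)
156 |     _restored = F.pad(_padded, _unpad)                 # negative padding crops
157 |     assert _restored.shape == _img.shape and torch.equal(_restored, _img)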
--------------------------------------------------------------------------------
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
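134 |
135 | # --- Usage sketch (not part of the original module) --- #
136 | # A minimal encode/decode round trip with the BPE tokenizer above; the vocab file
137 | # path defaults to bpe_simple_vocab_16e6.txt.gz next to this file.
138 | if __name__ == "__main__":
139 |     _tokenizer = SimpleTokenizer()
140 |     _ids = _tokenizer.encode("a photo of a cat")
141 |     print(_ids)                              # list of BPE token ids
142 |     print(_tokenizer.decode(_ids).strip())   # "a photo of a cat"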
--------------------------------------------------------------------------------
/utils/ckbd.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from compressai.entropy_models import GaussianConditional, EntropyModel
4 |
5 |
6 | def ckbd_split(y):
7 | """
8 | Split y to anchor and non-anchor
9 | anchor :
10 | 0 1 0 1 0
11 | 1 0 1 0 1
12 | 0 1 0 1 0
13 | 1 0 1 0 1
14 | 0 1 0 1 0
15 | non-anchor:
16 | 1 0 1 0 1
17 | 0 1 0 1 0
18 | 1 0 1 0 1
19 | 0 1 0 1 0
20 | 1 0 1 0 1
21 | """
22 | anchor = ckbd_anchor(y)
23 | nonanchor = ckbd_nonanchor(y)
24 | return anchor, nonanchor
25 |
26 | def ckbd_merge(anchor, nonanchor):
27 | # out = torch.zeros_like(anchor).to(anchor.device)
28 | # out[:, :, 0::2, 0::2] = non_anchor[:, :, 0::2, 0::2]
29 | # out[:, :, 1::2, 1::2] = non_anchor[:, :, 1::2, 1::2]
30 | # out[:, :, 0::2, 1::2] = anchor[:, :, 0::2, 1::2]
31 | # out[:, :, 1::2, 0::2] = anchor[:, :, 1::2, 0::2]
32 |
33 | return anchor + nonanchor
34 |
35 | def ckbd_anchor(y):
36 | anchor = torch.zeros_like(y).to(y.device)
37 | anchor[:, :, 0::2, 1::2] = y[:, :, 0::2, 1::2]
38 | anchor[:, :, 1::2, 0::2] = y[:, :, 1::2, 0::2]
39 | return anchor
40 |
41 | def ckbd_nonanchor(y):
42 | nonanchor = torch.zeros_like(y).to(y.device)
43 | nonanchor[:, :, 0::2, 0::2] = y[:, :, 0::2, 0::2]
44 | nonanchor[:, :, 1::2, 1::2] = y[:, :, 1::2, 1::2]
45 | return nonanchor
46 |
47 | def ckbd_anchor_sequeeze(y):
48 | B, C, H, W = y.shape
49 | anchor = torch.zeros([B, C, H, W // 2]).to(y.device)
50 | anchor[:, :, 0::2, :] = y[:, :, 0::2, 1::2]
51 | anchor[:, :, 1::2, :] = y[:, :, 1::2, 0::2]
52 | return anchor
53 |
54 | def ckbd_nonanchor_sequeeze(y):
55 | B, C, H, W = y.shape
56 | nonanchor = torch.zeros([B, C, H, W // 2]).to(y.device)
57 | nonanchor[:, :, 0::2, :] = y[:, :, 0::2, 0::2]
58 | nonanchor[:, :, 1::2, :] = y[:, :, 1::2, 1::2]
59 | return nonanchor
60 |
61 | def ckbd_anchor_unsequeeze(anchor):
62 | B, C, H, W = anchor.shape
63 | y_anchor = torch.zeros([B, C, H, W * 2]).to(anchor.device)
64 | y_anchor[:, :, 0::2, 1::2] = anchor[:, :, 0::2, :]
65 | y_anchor[:, :, 1::2, 0::2] = anchor[:, :, 1::2, :]
66 | return y_anchor
67 |
68 | def ckbd_nonanchor_unsequeeze(nonanchor):
69 | B, C, H, W = nonanchor.shape
70 | y_nonanchor = torch.zeros([B, C, H, W * 2]).to(nonanchor.device)
71 | y_nonanchor[:, :, 0::2, 0::2] = nonanchor[:, :, 0::2, :]
72 | y_nonanchor[:, :, 1::2, 1::2] = nonanchor[:, :, 1::2, :]
73 | return y_nonanchor
74 |
75 |
76 | def compress_anchor(gaussian_conditional:EntropyModel, anchor, scales_anchor, means_anchor, symbols_list, indexes_list):
77 | # squeeze anchor to avoid non-anchor symbols
78 | anchor_squeeze = ckbd_anchor_sequeeze(anchor)
79 | scales_anchor_squeeze = ckbd_anchor_sequeeze(scales_anchor)
80 | means_anchor_squeeze = ckbd_anchor_sequeeze(means_anchor)
81 | indexes = gaussian_conditional.build_indexes(scales_anchor_squeeze)
82 | anchor_hat = gaussian_conditional.quantize(anchor_squeeze, "symbols", means_anchor_squeeze)
83 | symbols_list.extend(anchor_hat.reshape(-1).tolist())
84 | indexes_list.extend(indexes.reshape(-1).tolist())
85 | anchor_hat = ckbd_anchor_unsequeeze(anchor_hat + means_anchor_squeeze)
86 | return anchor_hat
87 |
88 | def compress_nonanchor(gaussian_conditional:EntropyModel, nonanchor, scales_nonanchor, means_nonanchor, symbols_list, indexes_list):
89 | nonanchor_squeeze = ckbd_nonanchor_sequeeze(nonanchor)
90 | scales_nonanchor_squeeze = ckbd_nonanchor_sequeeze(scales_nonanchor)
91 | means_nonanchor_squeeze = ckbd_nonanchor_sequeeze(means_nonanchor)
92 | indexes = gaussian_conditional.build_indexes(scales_nonanchor_squeeze)
93 | nonanchor_hat = gaussian_conditional.quantize(nonanchor_squeeze, "symbols", means_nonanchor_squeeze)
94 | symbols_list.extend(nonanchor_hat.reshape(-1).tolist())
95 | indexes_list.extend(indexes.reshape(-1).tolist())
96 | nonanchor_hat = ckbd_nonanchor_unsequeeze(nonanchor_hat + means_nonanchor_squeeze)
97 | return nonanchor_hat
98 |
99 | def decompress_anchor(gaussian_conditional:EntropyModel, scales_anchor, means_anchor, decoder, cdf, cdf_lengths, offsets):
100 | scales_anchor_squeeze = ckbd_anchor_sequeeze(scales_anchor)
101 | means_anchor_squeeze = ckbd_anchor_sequeeze(means_anchor)
102 | indexes = gaussian_conditional.build_indexes(scales_anchor_squeeze)
103 | anchor_hat = decoder.decode_stream(indexes.reshape(-1).tolist(), cdf, cdf_lengths, offsets)
104 | anchor_hat = torch.Tensor(anchor_hat).reshape(scales_anchor_squeeze.shape).to(scales_anchor.device) + means_anchor_squeeze
105 | anchor_hat = ckbd_anchor_unsequeeze(anchor_hat)
106 | return anchor_hat
107 |
108 | def decompress_nonanchor(gaussian_conditional:EntropyModel, scales_nonanchor, means_nonanchor, decoder, cdf, cdf_lengths, offsets):
109 | scales_nonanchor_squeeze = ckbd_nonanchor_sequeeze(scales_nonanchor)
110 | means_nonanchor_squeeze = ckbd_nonanchor_sequeeze(means_nonanchor)
111 | indexes = gaussian_conditional.build_indexes(scales_nonanchor_squeeze)
112 | nonanchor_hat = decoder.decode_stream(indexes.reshape(-1).tolist(), cdf, cdf_lengths, offsets)
113 | nonanchor_hat = torch.Tensor(nonanchor_hat).reshape(scales_nonanchor_squeeze.shape).to(scales_nonanchor.device) + means_nonanchor_squeeze
114 | nonanchor_hat = ckbd_nonanchor_unsequeeze(nonanchor_hat)
115 | return nonanchor_hat
116 |
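117 |
118 | # --- Usage sketch (not part of the original module) --- #
119 | # A minimal check of the checkerboard split/merge helpers above on a dummy latent:
120 | # anchor and non-anchor occupy complementary positions, so merging is just a sum.
121 | if __name__ == "__main__":
122 |     _y = torch.randn(1, 192, 16, 16)
123 |     _anchor, _nonanchor = ckbd_split(_y)
124 |     assert torch.equal(ckbd_merge(_anchor, _nonanchor), _y)
125 |     _squeezed = ckbd_anchor_sequeeze(_anchor)            # (1, 192, 16, 8)
126 |     assert torch.equal(ckbd_anchor_unsequeeze(_squeezed), _anchor)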
--------------------------------------------------------------------------------
/playground/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import logging
4 | import sys
5 | from PIL import ImageFile, Image
6 | import math
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch.utils.data import DataLoader
13 | from torchvision import transforms
14 | from compressai.datasets import ImageFolder
15 | sys.path.append("..")
16 | from utils.logger import setup_logger
17 | from utils.utils import CustomDataParallel, save_checkpoint
18 | from utils.optimizers import configure_optimizers
19 | from utils.training import train_one_epoch
20 | from utils.testing import test_one_epoch
21 | from loss.rd_loss import RateDistortionLoss
22 | from config.args import train_options
23 | from config.config import model_config
24 | from models import *
25 | from utils.newtrain_data import TrainData,dataset_collate
26 | import random
27 |
28 |
29 | def main():
30 | torch.backends.cudnn.benchmark = True
31 | ImageFile.LOAD_TRUNCATED_IMAGES = True
32 | Image.MAX_IMAGE_PIXELS = None
33 |
34 | args = train_options()
35 | config = model_config()
36 |
37 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
38 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
39 |
40 | if args.seed is not None:
41 | seed = args.seed
42 | else:
43 | seed = 100 * random.random()
44 | torch.manual_seed(seed)
45 | random.seed(seed)
46 |
47 | if not os.path.exists(os.path.join('./experiments', args.experiment)):
48 | os.makedirs(os.path.join('./experiments', args.experiment))
49 |
50 | setup_logger('train', os.path.join('./experiments', args.experiment), 'train_' + args.experiment, level=logging.INFO,
51 | screen=True, tofile=True)
52 | setup_logger('val', os.path.join('./experiments', args.experiment), 'val_' + args.experiment, level=logging.INFO,
53 | screen=True, tofile=True)
54 |
55 | logger_train = logging.getLogger('train')
56 | logger_val = logging.getLogger('val')
57 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + args.experiment)
58 |
59 | if not os.path.exists(os.path.join('./experiments', args.experiment, 'checkpoints')):
60 | os.makedirs(os.path.join('./experiments', args.experiment, 'checkpoints'))
61 |
62 |
63 | train_dataset = TrainData("../new_trainvallist.txt")
64 | val_dataset = TrainData("../new_testlist.txt")
65 |
66 | train_dataloader = DataLoader(
67 | train_dataset,
68 | batch_size=args.batch_size,
69 | num_workers=args.num_workers,
70 | shuffle=True,
71 | pin_memory=(device == "cuda"),
72 | collate_fn=dataset_collate
73 | )
74 |
75 | test_dataloader = DataLoader(
76 | val_dataset,
77 | batch_size=args.test_batch_size,
78 | num_workers=args.num_workers,
79 | shuffle=False,
80 | pin_memory=(device == "cuda"),
81 | collate_fn=dataset_collate
82 | )
83 |
84 | net = MLICPlusPlus(config=config)
85 | if args.cuda and torch.cuda.device_count() > 1:
86 | net = CustomDataParallel(net)
87 | net = net.to(device)
88 | optimizer, aux_optimizer = configure_optimizers(net, args)
89 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 100], gamma=0.1)
90 | criterion = RateDistortionLoss(lmbda=args.lmbda, metrics=args.metrics)
91 |
92 |     if args.checkpoint is not None:
93 | checkpoint = torch.load(args.checkpoint)
94 | # new_ckpt = modify_checkpoint(checkpoint['state_dict'])
95 | net.load_state_dict(checkpoint['state_dict'])
96 | start_epoch = 0
97 | best_loss = 1e10
98 | current_step = 0
99 | # optimizer.load_state_dict(checkpoint['optimizer'])
100 | # aux_optimizer.load_state_dict(checkpoint['aux_optimizer'])
101 | # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
102 | # lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[450,550], gamma=0.1)
103 | # lr_scheduler._step_count = checkpoint['lr_scheduler']['_step_count']
104 | # lr_scheduler.last_epoch = checkpoint['lr_scheduler']['last_epoch']
105 | # print(lr_scheduler.state_dict())
106 | # start_epoch = checkpoint['epoch']
107 | # best_loss = checkpoint['loss']
108 | # current_step = start_epoch * math.ceil(len(train_dataloader.dataset) / args.batch_size)
109 | checkpoint = None
110 | else:
111 | start_epoch = 0
112 | best_loss = 1e10
113 | current_step = 0
114 |
115 | # start_epoch = 0
116 | # best_loss = 1e10
117 | # current_step = 0
118 |
119 | logger_train.info(args)
120 | logger_train.info(config)
121 | logger_train.info(net)
122 | logger_train.info(optimizer)
123 | optimizer.param_groups[0]['lr'] = args.learning_rate
124 | for epoch in range(start_epoch, args.epochs):
125 | logger_train.info(f"Learning rate: {optimizer.param_groups[0]['lr']}")
126 | current_step = train_one_epoch(
127 | net,
128 | criterion,
129 | train_dataloader,
130 | optimizer,
131 | aux_optimizer,
132 | epoch,
133 | args.clip_max_norm,
134 | logger_train,
135 | tb_logger,
136 | current_step
137 | )
138 |
139 | save_dir = os.path.join('./experiments', args.experiment, 'val_images', '%03d' % (epoch + 1))
140 | loss = test_one_epoch(epoch, test_dataloader, net, criterion, save_dir, logger_val, tb_logger)
141 |
142 | lr_scheduler.step()
143 | is_best = loss < best_loss
144 | best_loss = min(loss, best_loss)
145 |
146 | net.update(force=True)
147 | if args.save:
148 | save_checkpoint(
149 | {
150 | "epoch": epoch + 1,
151 | "state_dict": net.state_dict(),
152 | "loss": loss,
153 | "optimizer": optimizer.state_dict(),
154 | "aux_optimizer": aux_optimizer.state_dict(),
155 | "lr_scheduler": lr_scheduler.state_dict(),
156 | },
157 | is_best,
158 | os.path.join('./experiments', args.experiment, 'checkpoints', "checkpoint_%03d.pth.tar" % (epoch + 1))
159 | )
160 | if is_best:
161 | logger_val.info('best checkpoint saved.')
162 |
163 | if __name__ == '__main__':
164 | main()
165 |
--------------------------------------------------------------------------------
/playground/warmup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import logging
4 | from PIL import ImageFile, Image
5 | import math
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | from torch.utils.tensorboard import SummaryWriter
11 | from torch.utils.data import DataLoader
12 | from torchvision import transforms
13 | from transformers import get_linear_schedule_with_warmup
14 | from compressai.datasets import ImageFolder
15 | from utils.logger import setup_logger
16 | from utils.utils import CustomDataParallel, save_checkpoint
17 | from utils.optimizers import configure_optimizers
18 | from utils.training import warmup_one_epoch
19 | from utils.testing import test_one_epoch
20 | from loss.rd_loss import RateDistortionLoss
21 | from config.args import train_options
22 | from config.config import model_config
23 | from models import *
24 | import random
25 |
26 |
27 | def main():
28 | torch.backends.cudnn.benchmark = True
29 | ImageFile.LOAD_TRUNCATED_IMAGES = True
30 | Image.MAX_IMAGE_PIXELS = None
31 |
32 | args = train_options()
33 | config = model_config()
34 |
35 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
36 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
37 |
38 | if args.seed is not None:
39 | # seed = 100 * random.random()
40 | seed = args.seed
41 | torch.manual_seed(seed)
42 | random.seed(seed)
43 |
44 | if not os.path.exists(os.path.join('./experiments', args.experiment)):
45 | os.makedirs(os.path.join('./experiments', args.experiment))
46 |
47 | setup_logger('train', os.path.join('./experiments', args.experiment), 'train_' + args.experiment, level=logging.INFO,
48 | screen=True, tofile=True)
49 | setup_logger('val', os.path.join('./experiments', args.experiment), 'val_' + args.experiment, level=logging.INFO,
50 | screen=True, tofile=True)
51 |
52 | logger_train = logging.getLogger('train')
53 | logger_val = logging.getLogger('val')
54 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + args.experiment)
55 |
56 | if not os.path.exists(os.path.join('./experiments', args.experiment, 'checkpoints')):
57 | os.makedirs(os.path.join('./experiments', args.experiment, 'checkpoints'))
58 |
59 | train_transforms = transforms.Compose(
60 | [transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
61 | )
62 | test_transforms = transforms.Compose(
63 | [transforms.ToTensor()]
64 | )
65 |
66 | train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms)
67 | test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms)
68 |
69 | train_dataloader = DataLoader(
70 | train_dataset,
71 | batch_size=args.batch_size,
72 | num_workers=args.num_workers,
73 | shuffle=True,
74 | pin_memory=(device == "cuda"),
75 | )
76 |
77 | test_dataloader = DataLoader(
78 | test_dataset,
79 | batch_size=args.test_batch_size,
80 | num_workers=args.num_workers,
81 | shuffle=False,
82 | pin_memory=(device == "cuda"),
83 | )
84 |
85 | net = MLICPlusPlus(config=config)
86 | net = torch.compile(net)
87 | if args.cuda and torch.cuda.device_count() > 1:
88 | net = CustomDataParallel(net)
89 | net = net.to(device)
90 | optimizer, aux_optimizer = configure_optimizers(net, args)
91 | # lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min")
92 | warmup_steps = len(train_dataloader) * 1
93 | total_steps = len(train_dataloader) * 150
94 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)
95 | criterion = RateDistortionLoss(lmbda=args.lmbda, metrics=args.metrics)
96 |
97 |     if args.checkpoint is not None:
98 | checkpoint = torch.load(args.checkpoint)
99 | net.load_state_dict(checkpoint["state_dict"])
100 | optimizer.load_state_dict(checkpoint['optimizer'])
101 | aux_optimizer.load_state_dict(checkpoint['aux_optimizer'])
102 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
103 | # lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[450,550], gamma=0.1)
104 | # lr_scheduler._step_count = checkpoint['lr_scheduler']['_step_count']
105 | # lr_scheduler.last_epoch = checkpoint['lr_scheduler']['last_epoch']
106 | # print(lr_scheduler.state_dict())
107 | start_epoch = checkpoint['epoch']
108 | best_loss = checkpoint['loss']
109 | current_step = start_epoch * math.ceil(len(train_dataloader.dataset) / args.batch_size)
110 | checkpoint = None
111 | else:
112 | start_epoch = 0
113 | best_loss = 1e10
114 | current_step = 0
115 |
116 | # start_epoch = 0
117 | # best_loss = 1e10
118 | # current_step = 0
119 |
120 | logger_train.info(args)
121 | logger_train.info(config)
122 | logger_train.info(net)
123 | logger_train.info(optimizer)
124 | for epoch in range(start_epoch, args.epochs):
125 | logger_train.info(f"Learning rate: {optimizer.param_groups[0]['lr']}")
126 | current_step = warmup_one_epoch(
127 | net,
128 | criterion,
129 | train_dataloader,
130 | optimizer,
131 | aux_optimizer,
132 | epoch,
133 | args.clip_max_norm,
134 | logger_train,
135 | tb_logger,
136 | current_step,
137 | lr_scheduler
138 | )
139 |
140 | save_dir = os.path.join('./experiments', args.experiment, 'val_images', '%03d' % (epoch + 1))
141 | loss = test_one_epoch(epoch, test_dataloader, net, criterion, save_dir, logger_val, tb_logger)
142 |
143 | is_best = loss < best_loss
144 | best_loss = min(loss, best_loss)
145 |
146 | net.update(force=True)
147 | if args.save:
148 | save_checkpoint(
149 | {
150 | "epoch": epoch + 1,
151 | "state_dict": net.state_dict(),
152 | "loss": loss,
153 | "optimizer": optimizer.state_dict(),
154 | "aux_optimizer": aux_optimizer.state_dict(),
155 | "lr_scheduler": lr_scheduler.state_dict(),
156 | },
157 | is_best,
158 | os.path.join('./experiments', args.experiment, 'checkpoints', "checkpoint_%03d.pth.tar" % (epoch + 1))
159 | )
160 | if is_best:
161 | logger_val.info('best checkpoint saved.')
162 |
163 | if __name__ == '__main__':
164 | main()
165 |
--------------------------------------------------------------------------------
/modules/layers/res_blk.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 | from compressai.layers import GDN, subpel_conv3x3
5 | from modules.layers.conv import conv1x1, conv3x3
6 |
7 |
8 | class AttentionBlock(nn.Module):
9 | """Self attention block.
10 |
11 |     Simplified variant from `"Learned Image Compression with
12 |     Discretized Gaussian Mixture Likelihoods and Attention Modules"
13 |     <https://arxiv.org/abs/2001.01568>`_, by Zhengxue Cheng, Heming Sun, Masaru
14 |     Takeuchi, Jiro Katto.
15 |
16 |     Args:
17 |         N (int): Number of channels
18 | """
19 |
20 | def __init__(self, N: int):
21 | super().__init__()
22 |
23 | class ResidualUnit(nn.Module):
24 | """Simple residual unit."""
25 |
26 | def __init__(self):
27 | super().__init__()
28 | self.conv = nn.Sequential(
29 | conv1x1(N, N // 2),
30 | nn.GELU(),
31 | conv3x3(N // 2, N // 2),
32 | nn.GELU(),
33 | conv1x1(N // 2, N),
34 | )
35 | self.relu = nn.GELU()
36 |
37 | def forward(self, x: Tensor) -> Tensor:
38 | identity = x
39 | out = self.conv(x)
40 | out += identity
41 | out = self.relu(out)
42 | return out
43 |
44 | self.conv_a = nn.Sequential(ResidualUnit(), ResidualUnit(), ResidualUnit())
45 |
46 | self.conv_b = nn.Sequential(
47 | ResidualUnit(),
48 | ResidualUnit(),
49 | ResidualUnit(),
50 | conv1x1(N, N),
51 | )
52 |
53 | def forward(self, x: Tensor) -> Tensor:
54 | identity = x
55 | a = self.conv_a(x)
56 | b = self.conv_b(x)
57 | out = a * torch.sigmoid(b)
58 | out += identity
59 | return out
60 |
61 |
62 | class ResidualBlockWithStride(nn.Module):
63 | """Residual block with a stride on the first convolution.
64 |
65 | Args:
66 | in_ch (int): number of input channels
67 | out_ch (int): number of output channels
68 | stride (int): stride value (default: 2)
69 | """
70 |
71 | def __init__(self, in_ch: int, out_ch: int, stride: int = 2):
72 | super().__init__()
73 | self.conv1 = conv3x3(in_ch, out_ch, stride=stride)
74 | self.act = nn.GELU()
75 | self.conv2 = conv3x3(out_ch, out_ch)
76 | self.gdn = GDN(out_ch)
77 | if stride != 1 or in_ch != out_ch:
78 | self.skip = conv1x1(in_ch, out_ch, stride=stride)
79 | else:
80 | self.skip = None
81 |
82 | def forward(self, x):
83 | identity = x
84 | out = self.conv1(x)
85 | out = self.act(out)
86 | out = self.conv2(out)
87 | out = self.gdn(out)
88 |
89 | if self.skip is not None:
90 | identity = self.skip(x)
91 |
92 | out += identity
93 | return out
94 |
95 |
96 | class ResidualBlockUpsample(nn.Module):
97 | """Residual block with sub-pixel upsampling on the last convolution.
98 |
99 | Args:
100 | in_ch (int): number of input channels
101 | out_ch (int): number of output channels
102 | upsample (int): upsampling factor (default: 2)
103 | """
104 |
105 | def __init__(self, in_ch: int, out_ch: int, upsample: int = 2):
106 | super().__init__()
107 | self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample)
108 | self.act = nn.GELU()
109 | self.conv = conv3x3(out_ch, out_ch)
110 | self.igdn = GDN(out_ch, inverse=True)
111 | self.upsample = subpel_conv3x3(in_ch, out_ch, upsample)
112 |
113 | def forward(self, x):
114 | identity = x
115 | out = self.subpel_conv(x)
116 | out = self.act(out)
117 | out = self.conv(out)
118 | out = self.igdn(out)
119 | identity = self.upsample(x)
120 | out += identity
121 | return out
122 |
123 |
124 | class ResidualBlock(nn.Module):
125 | """Simple residual block with two 3x3 convolutions.
126 |
127 | Args:
128 | in_ch (int): number of input channels
129 | out_ch (int): number of output channels
130 | """
131 |
132 | def __init__(self, in_ch: int, out_ch: int):
133 | super().__init__()
134 | self.conv1 = conv3x3(in_ch, out_ch)
135 | self.act = nn.GELU()
136 | self.conv2 = conv3x3(out_ch, out_ch)
137 | if in_ch != out_ch:
138 | self.skip = conv1x1(in_ch, out_ch)
139 | else:
140 | self.skip = None
141 |
142 | def forward(self, x):
143 | identity = x
144 |
145 | out = self.conv1(x)
146 | out = self.act(out)
147 | out = self.conv2(out)
148 | out = self.act(out)
149 |
150 | if self.skip is not None:
151 | identity = self.skip(x)
152 |
153 | out = out + identity
154 | return out
155 |
156 |
157 | class ResidualBottleneck(nn.Module):
158 | def __init__(self, N=192, act=nn.GELU, groups=1) -> None:
159 | super().__init__()
160 | self.branch = nn.Sequential(
161 | conv1x1(N, N // 2),
162 | act(),
163 | nn.Conv2d(N // 2, N // 2, kernel_size=3, stride=1, padding=1),
164 | act(),
165 | conv1x1(N // 2, N)
166 | )
167 |
168 | def forward(self, x):
169 | out = x + self.branch(x)
170 |
171 | return out
172 |
173 | class PCDM(nn.Module):
174 |     """Residual block with two 3x3 convolutions, conditioned on a 512-dim vector
175 |     that is projected and added after each convolution (forward takes an (x, cond) tuple).
176 |     Args:
177 |         in_ch (int): number of input channels
178 |         out_ch (int): number of output channels
179 |     """
180 |
181 | def __init__(self, in_ch: int, out_ch: int):
182 | super().__init__()
183 | self.conv1 = conv3x3(in_ch, out_ch)
184 | self.act = nn.GELU()
185 | self.conv2 = conv3x3(out_ch, out_ch)
186 | if in_ch != out_ch:
187 | self.skip = conv1x1(in_ch, out_ch)
188 | else:
189 | self.skip = None
190 | self._proj1 = nn.Sequential(
191 | nn.Linear(512, out_ch, bias=False)
192 | )
193 | self._proj2 = nn.Sequential(
194 | nn.Linear(512, out_ch, bias=False)
195 | )
196 |
197 | def forward(self, x):
198 | x, fourier_features_mlp = x
199 | identity = x
200 |
201 | out = self.conv1(x)
202 | out = self.act(out)
203 |
204 | proj1 = self._proj1(fourier_features_mlp)
205 | out += proj1[:, :, None, None]
206 |
207 | out = self.conv2(out)
208 | out = self.act(out)
209 |
210 | proj2 = self._proj2(fourier_features_mlp)
211 | out += proj2[:, :, None, None]
212 |
213 |
214 | if self.skip is not None:
215 | identity = self.skip(x)
216 |
217 | out = out + identity
218 | return out
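219 |
220 |
221 | # --- Usage sketch (not part of the original module) --- #
222 | # A minimal example of the conditioned PCDM block above: the forward pass takes a
223 | # (feature map, 512-dim conditioning vector) tuple; names with a leading underscore
224 | # are dummies for illustration only.
225 | if __name__ == "__main__":
226 |     _block = PCDM(in_ch=192, out_ch=192)
227 |     _feat = torch.randn(1, 192, 16, 16)
228 |     _cond = torch.randn(1, 512)       # e.g. a CLIP/Fourier embedding
229 |     _out = _block((_feat, _cond))
230 |     assert _out.shape == (1, 192, 16, 16)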
--------------------------------------------------------------------------------
/playground/train_mask.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import random
4 | import logging
5 | from PIL import ImageFile, Image
6 | import math
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import torch.nn.functional as F
11 | from torch.utils.tensorboard import SummaryWriter
12 | from torch.utils.data import DataLoader
13 | from torchvision import transforms
14 | from compressai.datasets import ImageFolder
15 | import sys
16 |
17 |
18 |
19 | sys.path.append("..")
20 | from models.mlicpp_con import ConditionalMLICPlusPlus
21 | from clip.Closs import CLIPConvLoss
22 | from utils.logger import setup_logger
23 | from utils.newtrain_data import TrainData, dataset_collate, TrainDataWithMask
24 | from utils.utils import CustomDataParallel, save_checkpoint
25 | from utils.optimizers import configure_optimizers, configure_optimizer
26 | from utils.training import train_one_epoch, train_clip_one_epoch, train_conditional_one_epoch, train_mask_conditional_one_epoch
27 | from utils.testing import test_one_epoch, test_clip_one_epoch, test_conditional_one_epoch, newtest_conditional_one_epoch
28 | from loss.rd_loss import RateDistortionLoss
29 | from config.args import train_options
30 | from config.config import model_config
31 | from models import *
32 | import random
33 |
34 |
35 | def main():
36 | torch.backends.cudnn.benchmark = True
37 | ImageFile.LOAD_TRUNCATED_IMAGES = True
38 | Image.MAX_IMAGE_PIXELS = None
39 |
40 | args = train_options()
41 | config = model_config()
42 |
43 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
44 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
45 |
46 | if args.seed is not None:
47 | # seed = args.seed
48 | # else:
49 | seed = 100 * random.random()
50 | torch.manual_seed(seed)
51 | random.seed(seed)
52 |
53 | if not os.path.exists(os.path.join('./experiments', args.experiment)):
54 | os.makedirs(os.path.join('./experiments', args.experiment))
55 |
56 | setup_logger('train', os.path.join('./experiments', args.experiment), 'train_' + args.experiment, level=logging.INFO,
57 | screen=True, tofile=True)
58 | setup_logger('val', os.path.join('./experiments', args.experiment), 'val_' + args.experiment, level=logging.INFO,
59 | screen=True, tofile=True)
60 |
61 | logger_train = logging.getLogger('train')
62 | logger_val = logging.getLogger('val')
63 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + args.experiment)
64 |
65 | if not os.path.exists(os.path.join('./experiments', args.experiment, 'checkpoints')):
66 | os.makedirs(os.path.join('./experiments', args.experiment, 'checkpoints'))
67 |
68 | train_dataset = TrainDataWithMask("../../../coco2voc/coco2voc/ImageSets/Main/trainval.txt")
69 | test_dataset = TrainData("../../../coco2voc/coco2voc/ImageSets/Main/test.txt")
70 | train_dataloader = DataLoader(
71 | train_dataset,
72 | batch_size=args.batch_size,
73 | num_workers=args.num_workers,
74 | shuffle=True,
75 | pin_memory=(device == "cuda"),
76 | )
77 |
78 | test_dataloader = DataLoader(
79 | test_dataset,
80 | batch_size=args.test_batch_size,
81 | num_workers=args.num_workers,
82 | shuffle=False,
83 | pin_memory=(device == "cuda"),
84 | )
85 |
86 | net = ConditionalMLICPlusPlus(config=config)
87 | if args.cuda and torch.cuda.device_count() > 1:
88 | net = CustomDataParallel(net)
89 | net = net.to(device)
90 | optimizer, aux_optimizer = configure_optimizers(net, args)
91 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 100], gamma=0.1)
92 | criterion = RateDistortionLoss(lmbda=args.lmbda, metrics=args.metrics)
93 | clip_loss = CLIPConvLoss(device)
94 |
95 |     if args.checkpoint is not None:
96 | if args.tune:
97 | checkpoint = torch.load(args.checkpoint)
98 | # new_ckpt = modify_checkpoint(checkpoint['state_dict'])
99 | net.load_state_dict(checkpoint['state_dict'])
100 | for name, param in net.named_parameters():
101 | if not ('g_s' in name or 'global_cond' in name):
102 | param.requires_grad = False
103 |             # Check which parameters are frozen and which remain trainable
104 | # for name, param in net.named_parameters():
105 | # print(f"{name}: requires_grad={param.requires_grad}")
106 | optimizer = configure_optimizer(net, args)
107 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1)
108 | start_epoch = 0
109 | best_loss = 1e10
110 | current_step = 0
111 | else:
112 | start_epoch = 0
113 | best_loss = 1e10
114 | current_step = 0
115 |
116 | # start_epoch = 0
117 | # best_loss = 1e10
118 | # current_step = 0
119 |
120 | logger_train.info(args)
121 | logger_train.info(config)
122 | logger_train.info(net)
123 | logger_train.info(optimizer)
124 | optimizer.param_groups[0]['lr'] = args.learning_rate
125 | save_dir = os.path.join('./experiments', args.experiment, 'val_images', '%03d' % (1 + 1))
126 | for epoch in range(start_epoch, args.epochs):
127 | logger_train.info(f"Learning rate: {optimizer.param_groups[0]['lr']}")
128 | current_step = train_mask_conditional_one_epoch(
129 | net,
130 | criterion,
131 | train_dataloader,
132 | optimizer,
133 | aux_optimizer,
134 | args.tune,
135 | epoch,
136 | args.clip_max_norm,
137 | logger_train,
138 | tb_logger,
139 | current_step,
140 | clip_loss,
141 | args.lambda_clip,
142 | args.lmbda,
143 | args.lambda_beta1,
144 | )
145 |
146 |
147 | loss = test_conditional_one_epoch(epoch, test_dataloader, net, criterion, save_dir, logger_val, tb_logger, clip_loss, args.lambda_clip)
148 |
149 | lr_scheduler.step()
150 | is_best = loss < best_loss
151 | best_loss = min(loss, best_loss)
152 |
153 | net.update(force=True)
154 | if args.save:
155 | save_checkpoint(
156 | {
157 | "epoch": epoch + 1,
158 | "state_dict": net.state_dict(),
159 | "loss": loss,
160 | "optimizer": optimizer.state_dict(),
161 | "aux_optimizer": aux_optimizer.state_dict(),
162 | "lr_scheduler": lr_scheduler.state_dict(),
163 | },
164 | is_best,
165 | os.path.join('./experiments', args.experiment, 'checkpoints', "checkpoint_%03d.pth.tar" % (epoch + 1))
166 | )
167 | if is_best:
168 | logger_val.info('best checkpoint saved.')
169 |
170 | if __name__ == '__main__':
171 | main()
--------------------------------------------------------------------------------
/utils/train_data.py:
--------------------------------------------------------------------------------
1 | """
2 | paper: GridDehazeNet: Attention-Based Multi-Scale Network for Image Dehazing
3 | file: train_data.py
4 | about: build the training dataset
5 | author: Xiaohong Liu
6 | date: 01/08/19
7 | """
8 |
9 | # --- Imports --- #
10 | import os
11 | import random
12 |
13 | import numpy as np
14 | import torch
15 | import torch.utils.data as data
16 | from PIL import Image
17 | import torch.nn.functional as F
18 |
19 | from torchvision.transforms import Compose, ToTensor, Resize, RandomCrop
20 |
21 |
22 | def compute_padding(in_h: int, in_w: int, *, out_h=None, out_w=None, min_div=1):
23 | """Returns tuples for padding and unpadding.
24 |
25 | Args:
26 | in_h: Input height.
27 | in_w: Input width.
28 | out_h: Output height.
29 | out_w: Output width.
30 | min_div: Length that output dimensions should be divisible by.
31 | """
32 | if out_h is None:
33 | out_h = (in_h + min_div - 1) // min_div * min_div
34 | if out_w is None:
35 | out_w = (in_w + min_div - 1) // min_div * min_div
36 |
37 | if out_h % min_div != 0 or out_w % min_div != 0:
38 | raise ValueError(
39 | f"Padded output height and width are not divisible by min_div={min_div}."
40 | )
41 |
42 | left = (out_w - in_w) // 2
43 | right = out_w - in_w - left
44 | top = (out_h - in_h) // 2
45 | bottom = out_h - in_h - top
46 |
47 | pad = (left, right, top, bottom)
48 | unpad = (-left, -right, -top, -bottom)
49 |
50 | return pad, unpad
51 |
52 | def pad(x, p=2**6):
53 | h, w = x.size(1), x.size(2)
54 | pad, _ = compute_padding(h, w, min_div=p)
55 | return F.pad(x, pad, mode="constant", value=0)
56 |
57 |
58 |
59 | class TestData(data.Dataset):
60 | def __init__(self, train_data_dir):
61 | super().__init__()
62 | train_list = train_data_dir
63 | with open(train_list) as f:
64 | txt = f.readlines()
65 | # annotations = [line.strip() for line in txt if len(line.strip().split()[1:]) != 0]
66 | annotations = [line.strip() for line in txt]
67 |
68 | self.annotations = annotations
69 | self.train_data_dir = train_data_dir
70 |
71 | def get_images(self, index):
72 | line = self.annotations[index].split()
73 | image_path = line[0]
74 | # print(image_path)
75 | img_name = image_path.split('/')[-1]
76 | # print(img_name)
77 | image_name = img_name.split('.')[0]
78 | # print(image_name)
79 | gt_name = image_name
80 | pgt_name = image_name
81 |
82 | # --- Transform to tensor --- #
83 | transform = Compose([ToTensor()])
84 |
85 | gts = []
86 | haze_img = Image.open("../../yolov3/" + 'VOCdevkit/VOC2007/JPEGImages/' + gt_name +'.jpg' )
87 | gt = transform(haze_img)
88 | # --- Check the channel is 3 or not --- #
89 |         if gt.shape[0] != 3:
90 | raise Exception('Bad image channel: {}'.format(gt_name))
91 | gts.append(gt)
92 |
93 | names = []
94 | names.append(pgt_name)
95 |
96 | return gts, names
97 |
98 | def __getitem__(self, index):
99 | res = self.get_images(index)
100 | return res
101 |
102 | def __len__(self):
103 | return len(self.annotations)
104 |
105 |
106 | class TestImagenet(data.Dataset):
107 | def __init__(self, train_data_dir):
108 | super().__init__()
109 | train_list = train_data_dir
110 | with open(train_list) as f:
111 | txt = f.readlines()
112 | # annotations = [line.strip() for line in txt if len(line.strip().split()[1:]) != 0]
113 | annotations = [line.strip() for line in txt]
114 |
115 | self.annotations = annotations
116 | self.train_data_dir = train_data_dir
117 |
118 | def get_images(self, index):
119 | image_path = self.annotations[index]
120 | # print(image_path)
121 | img_name = image_path.split('/')[-1]
122 | # print(img_name)
123 | image_name = img_name.split('.')[0]
124 | # print(image_name)
125 | gt_name = image_name
126 | pgt_name = image_name
127 |
128 | # --- Transform to tensor --- #
129 | transform = Compose([ToTensor()])
130 |
131 | gts = []
132 | haze_img = Image.open("../../data/liuquan/Imagenet_val25k/" + image_path)
133 | haze_img = haze_img.convert('RGB')
134 | gt = transform(haze_img)
135 | # --- Check the channel is 3 or not --- #
136 |         if gt.shape[0] != 3:
137 | raise Exception('Bad image channel: {}'.format(gt_name))
138 | gts.append(gt)
139 |
140 | names = []
141 | names.append(pgt_name)
142 |
143 | return gts, names
144 |
145 | def __getitem__(self, index):
146 | res = self.get_images(index)
147 | return res
148 |
149 | def __len__(self):
150 | return len(self.annotations)
151 |
152 | class NewTestData(data.Dataset):
153 | def __init__(self, img_dir):
154 | super().__init__()
155 | self.img_dir = img_dir
156 | self.img_labels = [f for f in os.listdir(img_dir) if f.endswith(".JPEG") or f.endswith(".png") or f.endswith(".jpg")]
157 |
158 | def get_images(self, index):
159 | img_label=self.img_labels[index]
160 | image_name = img_label.split('.')[0]
161 | img_path = os.path.join(self.img_dir, img_label)
162 |
163 | # --- Transform to tensor --- #
164 | transform = Compose([ToTensor()])
165 |
166 | gts = []
167 | haze_img = Image.open(img_path)
168 | haze_img = haze_img.convert('RGB')
169 | gt = transform(haze_img)
170 | # --- Check the channel is 3 or not --- #
171 |         if gt.shape[0] != 3:
172 | raise Exception('Bad image channel: {}'.format(image_name))
173 | gts.append(gt)
174 |
175 | names = []
176 | names.append(image_name)
177 |
178 | return gts, names
179 |
180 | def __getitem__(self, index):
181 | res = self.get_images(index)
182 | return res
183 |
184 | def __len__(self):
185 | return len(self.img_labels)
186 |
187 |
188 | # Used as the collate_fn for the DataLoader
189 | def dataset_collate(batch):
190 | gts = []
191 | pgts = []
192 | for gt, pgt in batch:
193 | gts.append(gt[0])
194 | pgts.append(pgt[0])
195 | gts = torch.from_numpy(np.array([item.numpy() for item in gts])).type(torch.FloatTensor)
196 | pgts = torch.from_numpy(np.array([item.numpy() for item in pgts])).type(torch.FloatTensor)
197 |
198 | return gts, pgts
199 |
200 | def testDataset_collate(batch):
201 | gts = []
202 | names = []
203 | for gt, name in batch:
204 | gts.append(gt[0])
205 | names.append(name[0])
206 | gts = torch.from_numpy(np.array([item.numpy() for item in gts])).type(torch.FloatTensor)
207 | return gts, names
--------------------------------------------------------------------------------
/playground/train_condi.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import logging
4 | from PIL import ImageFile, Image
5 | import math
6 | import torch
7 | import torch.nn as nn
8 | import torch.optim as optim
9 | import torch.nn.functional as F
10 | from torch.utils.tensorboard import SummaryWriter
11 | from torch.utils.data import DataLoader
12 | from torchvision import transforms
13 | from compressai.datasets import ImageFolder
14 | import sys
15 |
16 |
17 |
18 | sys.path.append("..")
19 | from models.mlicpp_con import ConditionalMLICPlusPlus
20 | from clip.Closs import CLIPConvLoss
21 | from utils.logger import setup_logger
22 | from utils.newtrain_data import TrainData, dataset_collate
23 | from utils.utils import CustomDataParallel, save_checkpoint
24 | from utils.optimizers import configure_optimizers, configure_optimizer
25 | from utils.training import train_one_epoch, train_clip_one_epoch, train_conditional_one_epoch
26 | from utils.testing import test_one_epoch, test_clip_one_epoch, test_conditional_one_epoch, newtest_conditional_one_epoch
27 | from loss.rd_loss import RateDistortionLoss
28 | from config.args import train_options
29 | from config.config import model_config
30 | from models import *
31 | import random
32 |
33 |
34 | def main():
35 | torch.backends.cudnn.benchmark = True
36 | ImageFile.LOAD_TRUNCATED_IMAGES = True
37 | Image.MAX_IMAGE_PIXELS = None
38 |
39 | args = train_options()
40 | config = model_config()
41 |
42 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
43 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu"
44 |
45 | if args.seed is not None:
46 | # seed = args.seed
47 | # else:
48 | seed = 100 * random.random()
49 | torch.manual_seed(seed)
50 | random.seed(seed)
51 |
52 | if not os.path.exists(os.path.join('./experiments', args.experiment)):
53 | os.makedirs(os.path.join('./experiments', args.experiment))
54 |
55 | setup_logger('train', os.path.join('./experiments', args.experiment), 'train_' + args.experiment, level=logging.INFO,
56 | screen=True, tofile=True)
57 | setup_logger('val', os.path.join('./experiments', args.experiment), 'val_' + args.experiment, level=logging.INFO,
58 | screen=True, tofile=True)
59 |
60 | logger_train = logging.getLogger('train')
61 | logger_val = logging.getLogger('val')
62 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + args.experiment)
63 |
64 | if not os.path.exists(os.path.join('./experiments', args.experiment, 'checkpoints')):
65 | os.makedirs(os.path.join('./experiments', args.experiment, 'checkpoints'))
66 |
67 | train_transforms = transforms.Compose(
68 | [transforms.RandomCrop(args.patch_size), transforms.ToTensor()]
69 | )
70 | test_transforms = transforms.Compose(
71 | [transforms.ToTensor()]
72 | )
73 |
74 | train_dataset = TrainData("../../../coco2voc/coco2voc/ImageSets/Main/trainval.txt")
75 | test_dataset = TrainData("../../../coco2voc/coco2voc/ImageSets/Main/test.txt")
76 | train_dataloader = DataLoader(
77 | train_dataset,
78 | batch_size=args.batch_size,
79 | num_workers=args.num_workers,
80 | shuffle=True,
81 | pin_memory=(device == "cuda"),
82 | )
83 |
84 | test_dataloader = DataLoader(
85 | test_dataset,
86 | batch_size=args.test_batch_size,
87 | num_workers=args.num_workers,
88 | shuffle=False,
89 | pin_memory=(device == "cuda"),
90 | )
91 |
92 | net = ConditionalMLICPlusPlus(config=config)
93 | if args.cuda and torch.cuda.device_count() > 1:
94 | net = CustomDataParallel(net)
95 | net = net.to(device)
96 | optimizer, aux_optimizer = configure_optimizers(net, args)
97 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 100], gamma=0.1)
98 | criterion = RateDistortionLoss(lmbda=args.lmbda, metrics=args.metrics)
99 | clip_loss = CLIPConvLoss(device)
100 |
101 |     if args.checkpoint is not None:
102 | if args.tune:
103 | checkpoint = torch.load(args.checkpoint)
104 | # new_ckpt = modify_checkpoint(checkpoint['state_dict'])
105 | net.load_state_dict(checkpoint['state_dict'])
106 | for name, param in net.named_parameters():
107 | if not ('g_s' in name or 'global_cond' in name):
108 | param.requires_grad = False
109 |             # Check which parameters are frozen and which remain trainable
110 | # for name, param in net.named_parameters():
111 | # print(f"{name}: requires_grad={param.requires_grad}")
112 | optimizer = configure_optimizer(net, args)
113 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1)
114 | start_epoch = 0
115 | best_loss = 1e10
116 | current_step = 0
117 | else:
118 | checkpoint = torch.load(args.checkpoint)
119 | # new_ckpt = modify_checkpoint(checkpoint['state_dict'])
120 | net.load_state_dict(checkpoint['state_dict'])
121 | optimizer.load_state_dict(checkpoint['optimizer'])
122 | aux_optimizer.load_state_dict(checkpoint['aux_optimizer'])
123 | # lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
124 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[450, 550], gamma=0.1)
125 | # lr_scheduler._step_count = checkpoint['lr_scheduler']['_step_count']
126 | # lr_scheduler.last_epoch = checkpoint['lr_scheduler']['last_epoch']
127 | # print(lr_scheduler.state_dict())
128 | start_epoch = checkpoint['epoch']
129 | best_loss = checkpoint['loss']
130 | current_step = start_epoch * math.ceil(len(train_dataloader.dataset) / args.batch_size)
131 | checkpoint = None
132 |
133 | else:
134 | start_epoch = 0
135 | best_loss = 1e10
136 | current_step = 0
137 |
138 | # start_epoch = 0
139 | # best_loss = 1e10
140 | # current_step = 0
141 |
142 | logger_train.info(args)
143 | logger_train.info(config)
144 | logger_train.info(net)
145 | logger_train.info(optimizer)
146 | optimizer.param_groups[0]['lr'] = args.learning_rate
147 | save_dir = os.path.join('./experiments', args.experiment, 'val_images', '%03d' % (1 + 1))
148 | for epoch in range(start_epoch, args.epochs):
149 | logger_train.info(f"Learning rate: {optimizer.param_groups[0]['lr']}")
150 | current_step = train_conditional_one_epoch(
151 | net,
152 | criterion,
153 | train_dataloader,
154 | optimizer,
155 | aux_optimizer,
156 | args.tune,
157 | epoch,
158 | args.clip_max_norm,
159 | logger_train,
160 | tb_logger,
161 | current_step,
162 | clip_loss,
163 | args.lambda_clip,
164 | args.lmbda,
165 | args.lambda_beta1,
166 | args.lambda_beta2
167 | )
168 |
169 |
170 | loss = test_conditional_one_epoch(epoch, test_dataloader, net, criterion, save_dir, logger_val, tb_logger, clip_loss, args.lambda_clip)
171 |
172 | lr_scheduler.step()
173 | is_best = loss < best_loss
174 | best_loss = min(loss, best_loss)
175 |
176 | net.update(force=True)
177 | if args.save:
178 | save_checkpoint(
179 | {
180 | "epoch": epoch + 1,
181 | "state_dict": net.state_dict(),
182 | "loss": loss,
183 | "optimizer": optimizer.state_dict(),
184 | "aux_optimizer": aux_optimizer.state_dict(),
185 | "lr_scheduler": lr_scheduler.state_dict(),
186 | },
187 | is_best,
188 | os.path.join('./experiments', args.experiment, 'checkpoints', "checkpoint_%03d.pth.tar" % (epoch + 1))
189 | )
190 | if is_best:
191 | logger_val.info('best checkpoint saved.')
192 |
193 | if __name__ == '__main__':
194 | main()
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 | import ssl
16 | ssl._create_default_https_context = ssl._create_unverified_context
17 |
18 | try:
19 | from torchvision.transforms import InterpolationMode
20 | BICUBIC = InterpolationMode.BICUBIC
21 | except ImportError:
22 | BICUBIC = Image.BICUBIC
23 |
24 |
25 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
26 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
27 |
28 |
29 | __all__ = ["available_models", "load", "tokenize"]
30 | _tokenizer = _Tokenizer()
31 |
32 | _MODELS = {
33 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
34 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
35 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
36 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
37 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
38 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
39 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
40 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
41 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
42 | }
43 |
44 |
45 | def _download(url: str, root: str):
46 | os.makedirs(root, exist_ok=True)
47 | filename = os.path.basename(url)
48 |
49 | expected_sha256 = url.split("/")[-2]
50 | download_target = os.path.join(root, filename)
51 |
52 | if os.path.exists(download_target) and not os.path.isfile(download_target):
53 | raise RuntimeError(f"{download_target} exists and is not a regular file")
54 |
55 | if os.path.isfile(download_target):
56 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
57 | return download_target
58 | else:
59 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
60 |
61 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
62 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
63 | while True:
64 | buffer = source.read(8192)
65 | if not buffer:
66 | break
67 |
68 | output.write(buffer)
69 | loop.update(len(buffer))
70 |
71 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
72 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
73 |
74 | return download_target
75 |
76 |
77 | def _convert_image_to_rgb(image):
78 | return image.convert("RGB")
79 |
80 |
81 | def _transform(n_px):
82 | return Compose([
83 | Resize(n_px, interpolation=BICUBIC),
84 | CenterCrop(n_px),
85 | _convert_image_to_rgb,
86 | ToTensor(),
87 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
88 | ])
89 |
90 |
91 | def available_models() -> List[str]:
92 | """Returns the names of available CLIP models"""
93 | return list(_MODELS.keys())
94 |
95 |
96 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
97 | """Load a CLIP model
98 |
99 | Parameters
100 | ----------
101 | name : str
102 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
103 |
104 | device : Union[str, torch.device]
105 | The device to put the loaded model
106 |
107 | jit : bool
108 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
109 |
110 | download_root: str
111 | path to download the model files; by default, it uses "~/.cache/clip"
112 |
113 | Returns
114 | -------
115 | model : torch.nn.Module
116 | The CLIP model
117 |
118 | preprocess : Callable[[PIL.Image], torch.Tensor]
119 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
120 | """
121 | if name in _MODELS:
122 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
123 | elif os.path.isfile(name):
124 | model_path = name
125 | else:
126 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
127 |
128 | with open(model_path, 'rb') as opened_file:
129 | try:
130 | # loading JIT archive
131 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
132 | state_dict = None
133 | except RuntimeError:
134 | # loading saved state dict
135 | if jit:
136 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
137 | jit = False
138 | state_dict = torch.load(opened_file, map_location="cpu")
139 |
140 | if not jit:
141 | model = build_model(state_dict or model.state_dict()).to(device)
142 | if str(device) == "cpu":
143 | model.float()
144 | return model, _transform(model.visual.input_resolution)
145 |
146 | # patch the device names
147 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
148 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
149 |
150 | def _node_get(node: torch._C.Node, key: str):
151 | """Gets attributes of a node which is polymorphic over return type.
152 |
153 | From https://github.com/pytorch/pytorch/pull/82628
154 | """
155 | sel = node.kindOf(key)
156 | return getattr(node, sel)(key)
157 |
158 | def patch_device(module):
159 | try:
160 | graphs = [module.graph] if hasattr(module, "graph") else []
161 | except RuntimeError:
162 | graphs = []
163 |
164 | if hasattr(module, "forward1"):
165 | graphs.append(module.forward1.graph)
166 |
167 | for graph in graphs:
168 | for node in graph.findAllNodes("prim::Constant"):
169 | if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
170 | node.copyAttributes(device_node)
171 |
172 | model.apply(patch_device)
173 | patch_device(model.encode_image)
174 | patch_device(model.encode_text)
175 |
176 | # patch dtype to float32 on CPU
177 | if str(device) == "cpu":
178 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
179 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
180 | float_node = float_input.node()
181 |
182 | def patch_float(module):
183 | try:
184 | graphs = [module.graph] if hasattr(module, "graph") else []
185 | except RuntimeError:
186 | graphs = []
187 |
188 | if hasattr(module, "forward1"):
189 | graphs.append(module.forward1.graph)
190 |
191 | for graph in graphs:
192 | for node in graph.findAllNodes("aten::to"):
193 | inputs = list(node.inputs())
194 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
195 | if _node_get(inputs[i].node(), "value") == 5:
196 | inputs[i].node().copyAttributes(float_node)
197 |
198 | model.apply(patch_float)
199 | patch_float(model.encode_image)
200 | patch_float(model.encode_text)
201 |
202 | model.float()
203 |
204 | return model, _transform(model.input_resolution.item())
205 |
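# Illustrative usage sketch (comment only; the file name and device below are hypothetical,
# not part of this module). `load` resolves a model name via _MODELS, downloads and
# SHA256-checks the checkpoint, and returns the model together with its matching
# torchvision preprocessing pipeline:
#
#     model, preprocess = load("ViT-B/32", device="cpu", jit=False)
#     image = preprocess(Image.open("photo.png")).unsqueeze(0)  # PIL image -> [1, 3, 224, 224]
#     with torch.no_grad():
#         features = model.encode_image(image)                  # [1, 512] for ViT-B/32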
206 |
207 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
208 | """
209 | Returns the tokenized representation of given input string(s)
210 |
211 | Parameters
212 | ----------
213 | texts : Union[str, List[str]]
214 | An input string or a list of input strings to tokenize
215 |
216 | context_length : int
217 | The context length to use; all CLIP models use 77 as the context length
218 |
219 | truncate: bool
220 | Whether to truncate the text in case its encoding is longer than the context length
221 |
222 | Returns
223 | -------
224 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
225 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
226 | """
227 | if isinstance(texts, str):
228 | texts = [texts]
229 |
230 | sot_token = _tokenizer.encoder["<|startoftext|>"]
231 | eot_token = _tokenizer.encoder["<|endoftext|>"]
232 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
233 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
234 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
235 | else:
236 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
237 |
238 | for i, tokens in enumerate(all_tokens):
239 | if len(tokens) > context_length:
240 | if truncate:
241 | tokens = tokens[:context_length]
242 | tokens[-1] = eot_token
243 | else:
244 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
245 | result[i, :len(tokens)] = torch.tensor(tokens)
246 |
247 | return result
248 |
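# Illustrative sketch: tokenize() wraps each string in <|startoftext|>/<|endoftext|> tokens
# and zero-pads (or, with truncate=True, truncates) to context_length, so the released CLIP
# models always see a [num_strings, 77] tensor:
#
#     tokens = tokenize(["a photo of a cat", "a photo of a dog"])
#     assert tokens.shape == (2, 77)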
--------------------------------------------------------------------------------
/clip/Closs.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import collections
3 | import torch.nn as nn
4 |
5 | sys.path.append("..")
6 | import clip.clip as clip
7 | import torch
8 | from torchvision import models, transforms
9 |
10 | class CLIPLoss(torch.nn.Module):
11 | def __init__(self, device, num, affine):
12 | super(CLIPLoss, self).__init__()
13 |
14 | self.model, clip_preprocess = clip.load(
15 | 'ViT-B/32', device, jit=False)
16 | self.model.eval()
17 | self.preprocess = transforms.Compose(
18 | [clip_preprocess.transforms[-1], clip_preprocess.transforms[0]]) # clip normalisation
19 | self.device = device
20 | self.NUM_AUGS = num
21 | self.mse = nn.MSELoss()
22 |         augmentations = []
23 |         if affine:
24 |             # augmentations.append(transforms.RandomPerspective(
25 |             #     fill=0, p=1.0, distortion_scale=0.5))
26 |             # augmentations.append(transforms.Resize([224, 224]))
27 |             augmentations.append(transforms.Resize([224, 224]))
28 |         augmentations.append(
29 |             transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
30 |         self.augment_trans = transforms.Compose(augmentations)
31 |
32 | self.calc_target = True
33 | self.counter = 0
34 |
35 | def forward(self, sketches, targets, mode="train"):
36 | if self.calc_target:
37 | targets_ = self.preprocess(targets).to(self.device)
38 | self.targets_features = self.model.encode_image(targets_).detach()
39 | self.calc_target = False
40 |
41 | if mode == "eval":
42 | # for regular clip distance, no augmentations
43 | with torch.no_grad():
44 | sketches = self.preprocess(sketches).to(self.device)
45 | sketches_features = self.model.encode_image(sketches)
46 | return 1. - torch.cosine_similarity(sketches_features, self.targets_features)
47 |
48 | loss_clip = 0
49 | sketch_augs = []
50 | img_augs = []
51 | for n in range(self.NUM_AUGS):
52 | # augmented_pair = self.augment_trans(torch.cat([sketches, targets]))
53 | augmented_pair = self.augment_trans(sketches)
54 | sketch_augs.append(augmented_pair[0].unsqueeze(0))
55 | sketch_batch = torch.cat(sketch_augs)
56 | # sketch_utils.plot_batch(img_batch, sketch_batch, self.args, self.counter, use_wandb=False, title="fc_aug{}_iter{}_{}.jpg".format(1, self.counter, mode))
57 | # if self.counter % 100 == 0:
58 | # sketch_utils.plot_batch(img_batch, sketch_batch, self.args, self.counter, use_wandb=False, title="aug{}_iter{}_{}.jpg".format(1, self.counter, mode))
59 |
60 | sketch_features = self.model.encode_image(sketch_batch)
61 |
62 | for n in range(self.NUM_AUGS):
63 | loss_clip += (1. - torch.cosine_similarity(
64 | sketch_features[n:n+1], self.targets_features, dim=1))
65 | # loss_clip += self.mse(sketch_features[n:n+1], self.targets_features)
66 | self.counter += 1
67 | return loss_clip
68 | # return 1. - torch.cosine_similarity(sketches_features, self.targets_features)
69 |
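# Illustrative usage sketch (hypothetical tensors/devices). CLIPLoss caches the targets'
# CLIP image embedding on the first call and then scores reconstructions against it;
# in "eval" mode it simply returns 1 - cosine similarity of the CLIP features:
#
#     clip_loss = CLIPLoss(device="cuda", num=4, affine=True)
#     x_hat = torch.rand(1, 3, 256, 256, device="cuda")   # e.g. a decoded image in [0, 1]
#     x = torch.rand(1, 3, 256, 256, device="cuda")       # matching ground truth
#     d = clip_loss(x_hat, x, mode="eval")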
70 |
71 | class CLIPVisualEncoder(nn.Module):
72 | def __init__(self, clip_model):
73 | super().__init__()
74 | self.clip_model = clip_model
75 | self.featuremaps = None
76 |
77 | for i in range(12): # 12 resblocks in VIT visual transformer
78 | self.clip_model.visual.transformer.resblocks[i].register_forward_hook(
79 | self.make_hook(i))
80 |
81 | def make_hook(self, name):
82 | def hook(module, input, output):
83 | if len(output.shape) == 3:
84 | self.featuremaps[name] = output.permute(
85 | 1, 0, 2) # LND -> NLD bs, smth, 768
86 | else:
87 | self.featuremaps[name] = output
88 |
89 | return hook
90 |
91 | def forward(self, x):
92 | self.featuremaps = collections.OrderedDict()
93 | fc_features = self.clip_model.encode_image(x).float()
94 | featuremaps = [self.featuremaps[k] for k in range(12)]
95 |
96 | return fc_features, featuremaps
97 |
98 |
99 | def l2_layers(xs_conv_features, ys_conv_features, clip_model_name):
100 | return [torch.square(x_conv - y_conv).mean() for x_conv, y_conv in
101 | zip(xs_conv_features, ys_conv_features)]
102 |
103 |
104 | def l1_layers(xs_conv_features, ys_conv_features, clip_model_name):
105 | return [torch.abs(x_conv - y_conv).mean() for x_conv, y_conv in
106 | zip(xs_conv_features, ys_conv_features)]
107 |
108 |
109 | def cos_layers(xs_conv_features, ys_conv_features, clip_model_name):
110 | if "RN" in clip_model_name:
111 |         return [torch.square(x_conv - y_conv).mean() for x_conv, y_conv in  # squared error; torch.square takes a single tensor
112 |                 zip(xs_conv_features, ys_conv_features)]
113 | return [(1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean() for x_conv, y_conv in
114 | zip(xs_conv_features, ys_conv_features)]
115 |
116 |
117 | class CLIPConvLoss(torch.nn.Module):
118 | def __init__(self, device):
119 | super(CLIPConvLoss, self).__init__()
120 | self.clip_model_name = "ViT-B/32"
121 | assert self.clip_model_name in [
122 | "RN50",
123 | "RN101",
124 | "RN50x4",
125 | "RN50x16",
126 | "ViT-B/32",
127 | "ViT-B/16",
128 | ]
129 |
130 | self.clip_conv_loss_type = "L2"
131 | self.clip_fc_loss_type = "Cos" # args.clip_fc_loss_type
132 | assert self.clip_conv_loss_type in [
133 | "L2", "Cos", "L1",
134 | ]
135 | assert self.clip_fc_loss_type in [
136 | "L2", "Cos", "L1",
137 | ]
138 |
139 | self.distance_metrics = \
140 | {
141 | "L2": l2_layers,
142 | "L1": l1_layers,
143 | "Cos": cos_layers
144 | }
145 |
146 | self.model, clip_preprocess = clip.load(
147 | self.clip_model_name, device, jit=False, download_root=".")
148 |
149 | if self.clip_model_name.startswith("ViT"):
150 | self.visual_encoder = CLIPVisualEncoder(self.model)
151 |
152 | else:
153 | self.visual_model = self.model.visual
154 | layers = list(self.model.visual.children())
155 | init_layers = torch.nn.Sequential(*layers)[:8]
156 | self.layer1 = layers[8]
157 | self.layer2 = layers[9]
158 | self.layer3 = layers[10]
159 | self.layer4 = layers[11]
160 | self.att_pool2d = layers[12]
161 |
162 |
163 | self.img_size = clip_preprocess.transforms[1].size
164 | self.model.eval()
165 | self.target_transform = transforms.Compose([
166 | transforms.ToTensor(),
167 |         ])  # convert target image to tensor; CLIP normalisation is applied in normalize_transform
168 | self.normalize_transform = transforms.Compose([
169 | clip_preprocess.transforms[0], # Resize
170 | clip_preprocess.transforms[1], # CenterCrop
171 | clip_preprocess.transforms[-1], # Normalize
172 | ])
173 |
174 | self.model.eval()
175 | self.device = device
176 | # self.num_augs = num
177 |
178 |         augmentations = []
179 |         # if affine:
180 |         #     augmentations.append(transforms.RandomPerspective(
181 |         #         fill=0, p=1.0, distortion_scale=0.5))
182 |         #     augmentations.append(transforms.RandomResizedCrop(
183 |         #         224, scale=(0.8, 0.8), ratio=(1.0, 1.0)))
184 |         # else:
185 |         augmentations.append(transforms.Resize([224, 224]))
186 |         augmentations.append(
187 |             transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
188 |         self.augment_trans = transforms.Compose(augmentations)
189 |
190 | self.clip_fc_layer_dims = None # self.args.clip_fc_layer_dims
191 | self.clip_conv_layer_dims = None # self.args.clip_conv_layer_dims
192 | self.clip_fc_loss_weight = 0.1
193 | self.counter = 0
194 | self.clip_conv_layer_weights = "0,0,0,0,0"
195 | self.clip_conv_layer_weights = [
196 | float(item) for item in self.clip_conv_layer_weights.split(',')]
197 |
198 | def forward(self, sketch, target, mode="train"):
199 | """
200 | Parameters
201 | ----------
202 | sketch: Torch Tensor [1, C, H, W]
203 | target: Torch Tensor [1, C, H, W]
204 | """
205 | # y = self.target_transform(target).to(self.args.device)
206 | conv_loss_dict = {}
207 | x = sketch.to(self.device)
208 | y = target.to(self.device)
209 | sketch_augs, img_augs = [self.normalize_transform(x)], [
210 | self.normalize_transform(y)]
211 | # if mode == "train":
212 | # for n in range(self.num_augs):
213 | # augmented_pair = self.augment_trans(torch.cat([x, y]))
214 | # sketch_augs.append(augmented_pair[0].unsqueeze(0))
215 | # img_augs.append(augmented_pair[1].unsqueeze(0))
216 |
217 | xs = torch.cat(sketch_augs, dim=0).to(self.device)
218 | ys = torch.cat(img_augs, dim=0).to(self.device)
219 |
220 | if self.clip_model_name.startswith("RN"):
221 | xs_fc_features, xs_conv_features = self.forward_inspection_clip_resnet(
222 | xs.contiguous())
223 | ys_fc_features, ys_conv_features = self.forward_inspection_clip_resnet(
224 | ys.detach())
225 |
226 | else:
227 | xs_fc_features, xs_conv_features = self.visual_encoder(xs)
228 | ys_fc_features, ys_conv_features = self.visual_encoder(ys)
229 |
230 | conv_loss = self.distance_metrics[self.clip_conv_loss_type](
231 | xs_conv_features, ys_conv_features, self.clip_model_name)
232 | conv_loss_total=0
233 | for layer, w in enumerate(self.clip_conv_layer_weights):
234 | if w:
235 | conv_loss_dict[f"clip_conv_loss_layer{layer}"] = conv_loss[layer] * w
236 | conv_loss_total = conv_loss_total + conv_loss_dict[f"clip_conv_loss_layer{layer}"]
237 | if self.clip_fc_loss_weight:
238 | # fc distance is always cos
239 | fc_loss = (1 - torch.cosine_similarity(xs_fc_features,
240 | ys_fc_features, dim=1)).mean()
241 | conv_loss_dict["fc"] = fc_loss * self.clip_fc_loss_weight
242 | conv_loss_total += conv_loss_dict["fc"]
243 | self.counter += 1
244 | return conv_loss_total
245 |
246 | def forward_inspection_clip_resnet(self, x):
247 | def stem(m, x):
248 | x = m.relu1(m.bn1(m.conv1(x)))
249 | x = m.relu2(m.bn2(m.conv2(x)))
250 | x = m.relu3(m.bn3(m.conv3(x)))
251 | x = m.avgpool(x)
252 | return x
253 |
254 | x = x.type(self.visual_model.conv1.weight.dtype)
255 | x = stem(self.visual_model, x)
256 | x1 = self.layer1(x)
257 | x2 = self.layer2(x1)
258 | x3 = self.layer3(x2)
259 | x4 = self.layer4(x3)
260 | y = self.att_pool2d(x4)
261 | return y, [x, x1, x2, x3, x4]
262 |
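# Illustrative sketch (hypothetical tensors). CLIPConvLoss sums per-layer distances between
# intermediate ViT-B/32 activations, weighted by clip_conv_layer_weights, plus a cosine
# distance between the final CLIP embeddings weighted by clip_fc_loss_weight. With the
# defaults above ("0,0,0,0,0" conv weights), only the fc term contributes:
#
#     conv_loss = CLIPConvLoss(device="cuda")
#     total = conv_loss(x_hat, x)   # x_hat, x: [1, 3, H, W]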
--------------------------------------------------------------------------------
/modules/transform/context.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.init import trunc_normal_
5 | from einops import rearrange
6 | from modules.layers import MLP, build_position_index
7 | from modules.layers import conv, deconv
8 | from utils.ckbd import *
9 |
10 |
11 | class LocalContext(nn.Module):
12 | def __init__(self,
13 | dim=32,
14 | window_size=5,
15 | mlp_ratio=2.,
16 | num_heads=2,
17 | qkv_bias=True,
18 | qk_scale=None
19 | ) -> None:
20 | super().__init__()
21 | self.H = -1
22 | self.W = -1
23 | self.dim = dim
24 | self.window_size = window_size
25 | self.num_heads = num_heads
26 | self.head_dim = dim // num_heads
27 | self.scale = qk_scale or self.head_dim ** -0.5
28 | self.qkv_proj = nn.Linear(dim, dim * 3, bias=qkv_bias)
29 | self.unfold = nn.Unfold(kernel_size=window_size, stride=1, padding=(window_size - 1) // 2)
30 | self.relative_position_table = nn.Parameter(
31 | torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
32 | )
33 | trunc_normal_(self.relative_position_table, std=0.02)
34 | self.softmax = nn.Softmax(dim=-1)
35 | self.proj = nn.Linear(dim * 2, dim * 2)
36 | self.mlp = MLP(in_dim=dim * 2, hidden_dim=int(dim * 2 * mlp_ratio), out_dim=dim * 2)
37 | self.norm1 = nn.LayerNorm(dim)
38 | self.norm2 = nn.LayerNorm(dim * 2)
39 | self.register_buffer("relative_position_index", build_position_index((window_size, window_size)))
40 | self.attn_mask = None
41 | self.fusion = nn.Conv2d(dim, dim * 2, kernel_size=window_size)
42 |
43 | def update_resolution(self, H, W, device, mask=None):
44 | updated=False
45 | if self.H != H or self.W != W:
46 | self.H = H
47 | self.W = W
48 | if mask is not None:
49 | self.attn_mask = mask.to(device)
50 | updated=True
51 | return updated
52 | ckbd = torch.zeros((1, 2, H, W), requires_grad=False)
53 | # anchor
54 | ckbd[:, :, 0::2, 1::2] = 1
55 | ckbd[:, :, 1::2, 0::2] = 1
56 | qk_windows = self.unfold(ckbd).permute(0, 2, 1)
57 | qk_windows = qk_windows.view(1, H * W, 2, 1, self.window_size, self.window_size).permute(2, 0, 1, 3, 4, 5)
58 | q_windows, k_windows = qk_windows[0], qk_windows[1]
59 | q = q_windows.reshape(1, H * W, 1, self.window_size * self.window_size).permute(0, 1, 3, 2)
60 | k = k_windows.reshape(1, H * W, 1, self.window_size * self.window_size).permute(0, 1, 3, 2)
61 | attn_mask = (q @ k.transpose(-2, -1))
62 | attn_mask = attn_mask.masked_fill(attn_mask == 0., float(-100.0)).masked_fill(attn_mask == 1, float(0.0))
63 | self.attn_mask = attn_mask[0].to(device).detach()
64 | updated=True
65 | return updated
66 |
67 | def forward(self, x):
68 | B, C, H, W = x.shape
69 | L = H * W
70 | self.update_resolution(H, W, x.device)
71 | # [B, L, C]
72 | x = x.reshape(B, C, L).permute(0, 2, 1)
73 | x = self.norm1(x)
74 |
75 | # [3, B, C, H, W]
76 | qkv = self.qkv_proj(x).reshape(B, H, W, 3, C).permute(3, 0, 4, 1, 2)
77 |
78 | # window partition
79 | q, k, v = qkv[0], qkv[1], qkv[2]
80 | qkv = torch.cat([q, k, v], dim=1)
81 | qkv_windows = self.unfold(qkv).permute(0, 2, 1)
82 | qkv_windows = qkv_windows.view(B, L, 3, C, self.window_size, self.window_size).permute(2, 0, 1, 3, 4, 5)
83 | # [B, L, C, window_size, window_size]
84 | q_windows, k_windows, v_windows = qkv_windows[0], qkv_windows[1], qkv_windows[2]
85 |
86 | # [B, L, num_heads, window_size * window_size, head_dim]
87 | q = q_windows.reshape(B, L, self.head_dim, self.num_heads, self.window_size * self.window_size).permute(0, 1, 3, 4, 2)
88 | k = k_windows.reshape(B, L, self.head_dim, self.num_heads, self.window_size * self.window_size).permute(0, 1, 3, 4, 2)
89 | v = v_windows.reshape(B, L, self.head_dim, self.num_heads, self.window_size * self.window_size).permute(0, 1, 3, 4, 2)
90 |
91 | q = q * self.scale
92 | # [B, L, num_heads, window_size * window_size, window_size * window_size]
93 | attn = (q @ k.transpose(-2, -1))
94 |
95 | relative_position_bias = self.relative_position_table[self.relative_position_index.view(-1)].view(
96 | self.window_size * self.window_size, self.window_size * self.window_size, -1
97 | )
98 | # [num_heads, window_size * window_size, window_size * window_size]
99 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
100 | attn = attn + relative_position_bias.unsqueeze(0).unsqueeze(1)
101 |
102 | attn = attn + self.attn_mask.unsqueeze(0).unsqueeze(2)
103 |
104 | attn = self.softmax(attn)
105 |
106 | x = (attn @ v).reshape(B, L, self.num_heads, self.window_size, self.window_size, self.head_dim).permute(0, 1, 3, 4, 2, 5)
107 | x = x.reshape(B * L, self.window_size, self.window_size, C).permute(0, 3, 1, 2)
108 | x = self.fusion(x).reshape(B, L, C * 2)
109 | x = self.proj(x)
110 | x = x + self.mlp(self.norm2(x))
111 | x = x.permute(0, 2, 1).reshape(B, C * 2, H, W)
112 | return x
113 |
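# Shape sketch (hypothetical sizes): LocalContext applies checkerboard-masked window
# attention to a [B, dim, H, W] latent slice and doubles the channel count:
#
#     ctx = LocalContext(dim=32, window_size=5, num_heads=2)
#     y = ctx(torch.randn(1, 32, 16, 16))   # -> [1, 64, 16, 16]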
114 |
115 | class ChannelContext(nn.Module):
116 | def __init__(self, in_dim, out_dim) -> None:
117 | super().__init__()
118 | self.fushion = nn.Sequential(
119 | nn.Conv2d(in_dim, 192, kernel_size=3, stride=1, padding=1),
120 | nn.GELU(),
121 | nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1),
122 | nn.GELU(),
123 | nn.Conv2d(128, out_dim * 4, kernel_size=3, stride=1, padding=1)
124 | )
125 |
126 | def forward(self, channel_params):
127 | """
128 | Args:
129 | channel_params(Tensor): [B, C * K, H, W]
130 | return:
131 | channel_params(Tensor): [B, C * 2, H, W]
132 | """
133 | channel_params = self.fushion(channel_params)
134 |
135 | return channel_params
136 |
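# Shape sketch (hypothetical sizes): ChannelContext fuses the already-decoded slices
# ([B, in_dim, H, W] with in_dim = slice_ch * num_decoded_slices) into a channel-wise
# context of out_dim * 4 channels:
#
#     cc = ChannelContext(in_dim=64, out_dim=32)
#     params = cc(torch.randn(1, 64, 16, 16))   # -> [1, 128, 16, 16]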
137 |
138 | class ChannelContextEX(nn.Module):
139 | def __init__(self, in_dim, out_dim, act=nn.GELU) -> None:
140 | super().__init__()
141 | self.fushion = nn.Sequential(
142 | nn.Conv2d(in_dim, 224, kernel_size=3, stride=1, padding=1),
143 | act(),
144 | nn.Conv2d(224, 128, kernel_size=3, stride=1, padding=1),
145 | act(),
146 | nn.Conv2d(128, out_dim, kernel_size=3, stride=1, padding=1)
147 | )
148 |
149 | def forward(self, channel_params):
150 | """
151 | Args:
152 | channel_params(Tensor): [B, C * K, H, W]
153 | return:
154 |             channel_params(Tensor): [B, out_dim, H, W]
155 | """
156 | channel_params = self.fushion(channel_params)
157 |
158 | return channel_params
159 |
160 | class LinearGlobalIntraContext(nn.Module):
161 | def __init__(
162 | self,
163 | dim=32,
164 | num_heads=2) -> None:
165 | super().__init__()
166 | self.dim = dim
167 | self.num_heads = num_heads
168 | self.keys = nn.Sequential(
169 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),
170 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
171 | )
172 | self.queries = nn.Sequential(
173 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),
174 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
175 | )
176 | self.values = nn.Sequential(
177 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),
178 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
179 | )
180 | self.reprojection = nn.Conv2d(dim, dim * 2, kernel_size=5, stride=1, padding=2)
181 | self.mlp = nn.Sequential(
182 | nn.Conv2d(dim * 2, dim * 4, kernel_size=1, stride=1),
183 | nn.GELU(),
184 | nn.Conv2d(dim * 4, dim * 4, kernel_size=3, stride=1, padding=1, groups=dim * 4),
185 | nn.GELU(),
186 | nn.Conv2d(dim * 4, dim * 2, kernel_size=1, stride=1)
187 | )
188 |
189 | def forward(self, x1, x2):
190 | B, C, H, W = x1.shape
191 | x1_ac = ckbd_anchor(x1)
192 | x1_na = ckbd_nonanchor(x1)
193 | queries = ckbd_nonanchor_sequeeze(self.queries(x1_na)).reshape(B, self.dim, H * W//2)
194 | keys = ckbd_anchor_sequeeze(self.keys(x1_ac)).reshape(B, self.dim, H * W//2)
195 | values = ckbd_anchor_sequeeze(self.values(x2)).reshape(B, self.dim, H * W//2)
196 | head_dim = self.dim // self.num_heads
197 |
198 | attended_values = []
199 | for i in range(self.num_heads):
200 | key = F.softmax(keys[:, i * head_dim: (i + 1) * head_dim, :], dim=2)
201 | query = F.softmax(queries[:, i * head_dim: (i + 1) * head_dim, :], dim=1)
202 | value = values[:, i * head_dim: (i + 1) * head_dim, :]
203 | key = ckbd_anchor_unsequeeze(key.reshape(B, head_dim, H, W //2)).reshape(B, head_dim, H * W)
204 | value = ckbd_anchor_unsequeeze(value.reshape(B, head_dim, H, W //2)).reshape(B, head_dim, H * W)
205 | query = ckbd_nonanchor_unsequeeze(query.reshape(B, head_dim, H, W //2)).reshape(B, head_dim, H * W)
206 | context = key @ value.transpose(1, 2)
207 | attended_value = (context.transpose(1, 2) @ query).reshape(B, head_dim, H, W)
208 | attended_values.append(attended_value)
209 |
210 | aggregated_values = torch.cat(attended_values, dim=1)
211 | attention = self.reprojection(aggregated_values)
212 |
213 | return attention + self.mlp(attention)
214 |
215 | class LinearGlobalInterContext(nn.Module):
216 | def __init__(
217 | self,
218 | dim=32,
219 | out_dim=64,
220 | num_heads=2) -> None:
221 | super().__init__()
222 | self.dim = dim
223 | self.num_heads = num_heads
224 | self.keys = nn.Sequential(
225 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),
226 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
227 | )
228 | self.queries = nn.Sequential(
229 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),
230 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
231 | )
232 | self.values = nn.Sequential(
233 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0),
234 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)
235 | )
236 | self.reprojection = nn.Conv2d(dim, out_dim * 3 // 2, kernel_size=5, stride=1, padding=2)
237 | self.mlp = nn.Sequential(
238 | nn.Conv2d(out_dim * 3 // 2, out_dim * 2, kernel_size=1, stride=1),
239 | nn.GELU(),
240 | nn.Conv2d(out_dim * 2, out_dim * 2, kernel_size=3, stride=1, padding=1, groups=out_dim * 2),
241 | nn.GELU(),
242 | nn.Conv2d(out_dim * 2, out_dim, kernel_size=1, stride=1)
243 | )
244 | self.skip = nn.Conv2d(out_dim * 3 // 2, out_dim, kernel_size=1, stride=1, padding=0)
245 |
246 | def forward(self, x1):
247 | B, C, H, W = x1.shape
248 | queries = self.queries(x1).reshape(B, self.dim, H * W)
249 | keys = self.keys(x1).reshape(B, self.dim, H * W)
250 | values = self.values(x1).reshape(B, self.dim, H * W)
251 | head_dim = self.dim // self.num_heads
252 |
253 | attended_values = []
254 | for i in range(self.num_heads):
255 | key = F.softmax(keys[:, i * head_dim: (i + 1) * head_dim, :], dim=2)
256 | query = F.softmax(queries[:, i * head_dim: (i + 1) * head_dim, :], dim=1)
257 | value = values[:, i * head_dim: (i + 1) * head_dim, :]
258 | context = key @ value.transpose(1, 2)
259 | attended_value = (context.transpose(1, 2) @ query).reshape(B, head_dim, H, W)
260 | attended_values.append(attended_value)
261 |
262 | aggregated_values = torch.cat(attended_values, dim=1)
263 | attention = self.reprojection(aggregated_values)
264 |
265 | return self.skip(attention) + self.mlp(attention)
266 |
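# Note on the two global-context modules above: both use linear attention, i.e. softmax-
# normalised keys/queries combined through key @ value^T first, so the cost grows linearly
# with H * W rather than quadratically. LinearGlobalIntraContext additionally draws
# keys/values from the anchor half of the checkerboard and queries from the non-anchor half.
# Shape sketch (hypothetical sizes):
#
#     inter = LinearGlobalInterContext(dim=32, out_dim=64, num_heads=2)
#     y = inter(torch.randn(1, 32, 16, 16))   # -> [1, 64, 16, 16]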
--------------------------------------------------------------------------------
/utils/testing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from utils.metrics import compute_metrics, compute_metrics2
6 | from utils.utils import *
7 |
8 |
9 | def test_one_epoch(epoch, test_dataloader, model, criterion, save_dir, logger_val, tb_logger):
10 | model.eval()
11 | device = next(model.parameters()).device
12 |
13 | loss = AverageMeter()
14 | bpp_loss = AverageMeter()
15 | mse_loss = AverageMeter()
16 | ms_ssim_loss = AverageMeter()
17 | aux_loss = AverageMeter()
18 | psnr = AverageMeter()
19 | ms_ssim = AverageMeter()
20 |
21 | with torch.no_grad():
22 | for i, d in enumerate(test_dataloader):
23 | d = d.to(device)
24 | out_net = model(d)
25 | out_criterion = criterion(out_net, d)
26 |
27 | aux_loss.update(model.aux_loss())
28 | bpp_loss.update(out_criterion["bpp_loss"])
29 | loss.update(out_criterion["loss"])
30 | if out_criterion["mse_loss"] is not None:
31 | mse_loss.update(out_criterion["mse_loss"])
32 | if out_criterion["ms_ssim_loss"] is not None:
33 | ms_ssim_loss.update(out_criterion["ms_ssim_loss"])
34 |
35 | rec = torch2img(out_net['x_hat'][0])
36 | img = torch2img(d[0])
37 | p, m = compute_metrics(rec, img)
38 | psnr.update(p)
39 | ms_ssim.update(m)
40 |
41 | # if not os.path.exists(save_dir):
42 | # os.makedirs(save_dir)
43 | # rec.save(os.path.join(save_dir, '%03d_rec.png' % i))
44 | # img.save(os.path.join(save_dir, '%03d_gt.png' % i))
45 |
46 | tb_logger.add_scalar('{}'.format('[val]: loss'), loss.avg, epoch + 1)
47 | tb_logger.add_scalar('{}'.format('[val]: bpp_loss'), bpp_loss.avg, epoch + 1)
48 | tb_logger.add_scalar('{}'.format('[val]: psnr'), psnr.avg, epoch + 1)
49 | tb_logger.add_scalar('{}'.format('[val]: ms-ssim'), ms_ssim.avg, epoch + 1)
50 |
51 | if out_criterion["mse_loss"] is not None:
52 | logger_val.info(
53 | f"Test epoch {epoch}: Average losses: "
54 | f"Loss: {loss.avg:.4f} | "
55 | f"MSE loss: {mse_loss.avg:.6f} | "
56 | f"Bpp loss: {bpp_loss.avg:.4f} | "
57 | f"Aux loss: {aux_loss.avg:.2f} | "
58 | f"PSNR: {psnr.avg:.6f} | "
59 | f"MS-SSIM: {ms_ssim.avg:.6f}"
60 | )
61 | tb_logger.add_scalar('{}'.format('[val]: mse_loss'), mse_loss.avg, epoch + 1)
62 | if out_criterion["ms_ssim_loss"] is not None:
63 | logger_val.info(
64 | f"Test epoch {epoch}: Average losses: "
65 | f"Loss: {loss.avg:.4f} | "
66 | f"MS-SSIM loss: {ms_ssim_loss.avg:.6f} | "
67 | f"Bpp loss: {bpp_loss.avg:.4f} | "
68 | f"Aux loss: {aux_loss.avg:.2f} | "
69 | f"PSNR: {psnr.avg:.6f} | "
70 | f"MS-SSIM: {ms_ssim.avg:.6f}"
71 | )
72 | tb_logger.add_scalar('{}'.format('[val]: ms_ssim_loss'), ms_ssim_loss.avg, epoch + 1)
73 |
74 | return loss.avg
75 |
76 | def test_clip_one_epoch(epoch, test_dataloader, model, criterion, save_dir, logger_val, tb_logger, clip_loss, lambda_clip):
77 | model.eval()
78 | device = next(model.parameters()).device
79 |
80 | loss = AverageMeter()
81 | bpp_loss = AverageMeter()
82 | mse_loss = AverageMeter()
83 | ms_ssim_loss = AverageMeter()
84 | aux_loss = AverageMeter()
85 | psnr = AverageMeter()
86 | ms_ssim = AverageMeter()
87 | cliploss = AverageMeter()
88 |
89 | with torch.no_grad():
90 | for i, d in enumerate(test_dataloader):
91 | d = d.to(device)
92 | out_net = model(d)
93 | out_criterion = criterion(out_net, d)
94 | closs = clip_loss(out_net["x_hat"], d, "eval")
95 | aux_loss.update(model.aux_loss())
96 | bpp_loss.update(out_criterion["bpp_loss"])
97 | cliploss.update(closs)
98 | loss.update(out_criterion["loss"]+(lambda_clip * closs))
99 | if out_criterion["mse_loss"] is not None:
100 | mse_loss.update(out_criterion["mse_loss"])
101 | if out_criterion["ms_ssim_loss"] is not None:
102 | ms_ssim_loss.update(out_criterion["ms_ssim_loss"])
103 |
104 |
105 |
106 | rec = torch2img(out_net['x_hat'])
107 | img = torch2img(d)
108 | p, m = compute_metrics(rec, img)
109 | psnr.update(p)
110 | ms_ssim.update(m)
111 |
112 |
113 | if not os.path.exists(save_dir):
114 | os.makedirs(save_dir)
115 | rec.save(os.path.join(save_dir, '%03d_rec.png' % i))
116 | img.save(os.path.join(save_dir, '%03d_gt.png' % i))
117 |
118 | tb_logger.add_scalar('{}'.format('[val]: loss'), loss.avg, epoch + 1)
119 | tb_logger.add_scalar('{}'.format('[val]: bpp_loss'), bpp_loss.avg, epoch + 1)
120 | tb_logger.add_scalar('{}'.format('[val]: cliploss'), cliploss.avg, epoch + 1)
121 | tb_logger.add_scalar('{}'.format('[val]: psnr'), psnr.avg, epoch + 1)
122 | tb_logger.add_scalar('{}'.format('[val]: ms-ssim'), ms_ssim.avg, epoch + 1)
123 |
124 |
125 | if out_criterion["mse_loss"] is not None:
126 | logger_val.info(
127 | f"Test epoch {epoch}: Average losses: "
128 | f"Loss: {loss.avg:.4f} | "
129 | f"MSE loss: {mse_loss.avg:.6f} | "
130 | f"CLIP loss: {cliploss.avg:.6f} | "
131 | f"Bpp loss: {bpp_loss.avg:.4f} | "
132 | f"Aux loss: {aux_loss.avg:.2f} | "
133 | f"PSNR: {psnr.avg:.6f} | "
134 | f"MS-SSIM: {ms_ssim.avg:.6f}"
135 | )
136 | tb_logger.add_scalar('{}'.format('[val]: mse_loss'), mse_loss.avg, epoch + 1)
137 | if out_criterion["ms_ssim_loss"] is not None:
138 | logger_val.info(
139 | f"Test epoch {epoch}: Average losses: "
140 | f"Loss: {loss.avg:.4f} | "
141 | f"MS-SSIM loss: {ms_ssim_loss.avg:.6f} | "
142 | f"CLIP loss: {cliploss.avg:.6f} | "
143 | f"Bpp loss: {bpp_loss.avg:.4f} | "
144 | f"Aux loss: {aux_loss.avg:.2f} | "
145 | f"PSNR: {psnr.avg:.6f} | "
146 | f"MS-SSIM: {ms_ssim.avg:.6f}"
147 | )
148 | tb_logger.add_scalar('{}'.format('[val]: ms_ssim_loss'), ms_ssim_loss.avg, epoch + 1)
149 |
150 | return loss.avg
151 |
152 | def test_conditional_one_epoch(epoch, test_dataloader, model, criterion, save_dir, logger_val, tb_logger, clip_loss, lambda_clip):
153 | model.eval()
154 | device = next(model.parameters()).device
155 |
156 | loss = AverageMeter()
157 | bpp_loss = AverageMeter()
158 | mse_loss = AverageMeter()
159 | ms_ssim_loss = AverageMeter()
160 | aux_loss = AverageMeter()
161 | psnr = AverageMeter()
162 | ms_ssim = AverageMeter()
163 | cliploss = AverageMeter()
164 |
165 | with torch.no_grad():
166 | for i, d in enumerate(test_dataloader):
167 | bs = d.size(0)
168 | betas = torch.randint(0, 2, (bs,))
169 |
170 |
171 |
172 | d = d.to(device)
173 | betas = betas.to(device)
174 |
175 | out_net = model((d, betas))
176 | out_criterion = criterion(out_net, d)
177 | closs = clip_loss(out_net["x_hat"], d, "eval")
178 | aux_loss.update(model.aux_loss())
179 | bpp_loss.update(out_criterion["bpp_loss"])
180 | cliploss.update(closs)
181 | loss.update(out_criterion["loss"]+(lambda_clip * closs))
182 | if out_criterion["mse_loss"] is not None:
183 | mse_loss.update(out_criterion["mse_loss"])
184 | if out_criterion["ms_ssim_loss"] is not None:
185 | ms_ssim_loss.update(out_criterion["ms_ssim_loss"])
186 |
187 |
188 |
189 | rec = torch2img(out_net['x_hat'][0])
190 | img = torch2img(d[0])
191 | p, m = compute_metrics(rec, img)
192 | psnr.update(p)
193 | ms_ssim.update(m)
194 |
195 |
196 | # if not os.path.exists(save_dir):
197 | # os.makedirs(save_dir)
198 | # rec.save(os.path.join(save_dir, '%03d_rec.png' % i))
199 | # img.save(os.path.join(save_dir, '%03d_gt.png' % i))
200 |
201 | tb_logger.add_scalar('{}'.format('[val]: loss'), loss.avg, epoch + 1)
202 | tb_logger.add_scalar('{}'.format('[val]: bpp_loss'), bpp_loss.avg, epoch + 1)
203 | tb_logger.add_scalar('{}'.format('[val]: cliploss'), cliploss.avg, epoch + 1)
204 | tb_logger.add_scalar('{}'.format('[val]: psnr'), psnr.avg, epoch + 1)
205 | tb_logger.add_scalar('{}'.format('[val]: ms-ssim'), ms_ssim.avg, epoch + 1)
206 |
207 |
208 | if out_criterion["mse_loss"] is not None:
209 | logger_val.info(
210 | f"Test epoch {epoch}: Average losses: "
211 | f"Loss: {loss.avg:.4f} | "
212 | f"MSE loss: {mse_loss.avg:.6f} | "
213 | f"CLIP loss: {cliploss.avg:.6f} | "
214 | f"Bpp loss: {bpp_loss.avg:.4f} | "
215 | f"Aux loss: {aux_loss.avg:.2f} | "
216 | f"PSNR: {psnr.avg:.6f} | "
217 | f"MS-SSIM: {ms_ssim.avg:.6f}"
218 | )
219 | tb_logger.add_scalar('{}'.format('[val]: mse_loss'), mse_loss.avg, epoch + 1)
220 | if out_criterion["ms_ssim_loss"] is not None:
221 | logger_val.info(
222 | f"Test epoch {epoch}: Average losses: "
223 | f"Loss: {loss.avg:.4f} | "
224 | f"MS-SSIM loss: {ms_ssim_loss.avg:.6f} | "
225 | f"CLIP loss: {cliploss.avg:.6f} | "
226 | f"Bpp loss: {bpp_loss.avg:.4f} | "
227 | f"Aux loss: {aux_loss.avg:.2f} | "
228 | f"PSNR: {psnr.avg:.6f} | "
229 | f"MS-SSIM: {ms_ssim.avg:.6f}"
230 | )
231 | tb_logger.add_scalar('{}'.format('[val]: ms_ssim_loss'), ms_ssim_loss.avg, epoch + 1)
232 |
233 | return loss.avg
234 |
235 | def newtest_conditional_one_epoch(epoch, test_dataloader, model, criterion, save_dir, logger_val, tb_logger, clip_loss, lambda_clip):
236 | model.eval()
237 | device = next(model.parameters()).device
238 |
239 | loss = AverageMeter()
240 | bpp_loss = AverageMeter()
241 | mse_loss = AverageMeter()
242 | ms_ssim_loss = AverageMeter()
243 | aux_loss = AverageMeter()
244 | psnr = AverageMeter()
245 | ms_ssim = AverageMeter()
246 | cliploss = AverageMeter()
247 |
248 | with torch.no_grad():
249 | for i, d in enumerate(test_dataloader):
250 | bs = d.size(0)
251 | betas = torch.randint(0, 2, (bs,))
252 |
253 |
254 |
255 | d = d.to(device)
256 | betas = betas.to(device)
257 |
258 | out_net = model((d, betas))
259 | out_criterion = criterion(out_net, d)
260 | clip_out = clip_loss(out_net["x_hat"], d, "eval")
261 | closs = clip_out["fc"] + clip_out["clip_conv_all"]
262 | closs = closs.mean()
263 | aux_loss.update(model.aux_loss())
264 | bpp_loss.update(out_criterion["bpp_loss"])
265 | cliploss.update(closs)
266 | loss.update(out_criterion["loss"]+(lambda_clip * closs))
267 | if out_criterion["mse_loss"] is not None:
268 | mse_loss.update(out_criterion["mse_loss"])
269 | if out_criterion["ms_ssim_loss"] is not None:
270 | ms_ssim_loss.update(out_criterion["ms_ssim_loss"])
271 |
272 |
273 |
274 | rec = torch2img(out_net['x_hat'][0])
275 | img = torch2img(d[0])
276 | p, m = compute_metrics(rec, img)
277 | psnr.update(p)
278 | ms_ssim.update(m)
279 |
280 |
281 | # if not os.path.exists(save_dir):
282 | # os.makedirs(save_dir)
283 | # rec.save(os.path.join(save_dir, '%03d_rec.png' % i))
284 | # img.save(os.path.join(save_dir, '%03d_gt.png' % i))
285 |
286 | tb_logger.add_scalar('{}'.format('[val]: loss'), loss.avg, epoch + 1)
287 | tb_logger.add_scalar('{}'.format('[val]: bpp_loss'), bpp_loss.avg, epoch + 1)
288 | tb_logger.add_scalar('{}'.format('[val]: cliploss'), cliploss.avg, epoch + 1)
289 | tb_logger.add_scalar('{}'.format('[val]: psnr'), psnr.avg, epoch + 1)
290 | tb_logger.add_scalar('{}'.format('[val]: ms-ssim'), ms_ssim.avg, epoch + 1)
291 |
292 |
293 | if out_criterion["mse_loss"] is not None:
294 | logger_val.info(
295 | f"Test epoch {epoch}: Average losses: "
296 | f"Loss: {loss.avg:.4f} | "
297 | f"MSE loss: {mse_loss.avg:.6f} | "
298 | f"CLIP loss: {cliploss.avg:.6f} | "
299 | f"Bpp loss: {bpp_loss.avg:.4f} | "
300 | f"Aux loss: {aux_loss.avg:.2f} | "
301 | f"PSNR: {psnr.avg:.6f} | "
302 | f"MS-SSIM: {ms_ssim.avg:.6f}"
303 | )
304 | tb_logger.add_scalar('{}'.format('[val]: mse_loss'), mse_loss.avg, epoch + 1)
305 | if out_criterion["ms_ssim_loss"] is not None:
306 | logger_val.info(
307 | f"Test epoch {epoch}: Average losses: "
308 | f"Loss: {loss.avg:.4f} | "
309 | f"MS-SSIM loss: {ms_ssim_loss.avg:.6f} | "
310 | f"CLIP loss: {cliploss.avg:.6f} | "
311 | f"Bpp loss: {bpp_loss.avg:.4f} | "
312 | f"Aux loss: {aux_loss.avg:.2f} | "
313 | f"PSNR: {psnr.avg:.6f} | "
314 | f"MS-SSIM: {ms_ssim.avg:.6f}"
315 | )
316 | tb_logger.add_scalar('{}'.format('[val]: ms_ssim_loss'), ms_ssim_loss.avg, epoch + 1)
317 |
318 | return loss.avg
319 |
320 | def compress_one_image(model, x, stream_path, H, W, img_name):
321 | with torch.no_grad():
322 | out = model.compress(x)
323 |
324 | shape = out["shape"]
325 | output = os.path.join(stream_path, img_name)
326 | with Path(output).open("wb") as f:
327 | write_uints(f, (H, W))
328 | write_body(f, shape, out["strings"])
329 |
330 | size = filesize(output)
331 | bpp = float(size) * 8 / (H * W)
332 | return bpp, out["cost_time"]
333 |
334 |
335 | def decompress_one_image(model, stream_path, img_name):
336 | output = os.path.join(stream_path, img_name)
337 | with Path(output).open("rb") as f:
338 | original_size = read_uints(f, 2)
339 | strings, shape = read_body(f)
340 |
341 | with torch.no_grad():
342 | out = model.decompress(strings, shape)
343 |
344 | x_hat = out["x_hat"]
345 | x_hat = x_hat[:, :, 0 : original_size[0], 0 : original_size[1]]
346 | cost_time = out["cost_time"]
347 | return x_hat, cost_time
348 |
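# Illustrative round trip (hypothetical paths/names). compress_one_image writes the original
# (H, W) and the entropy-coded strings to stream_path/img_name and returns the resulting
# bits-per-pixel; decompress_one_image reads them back and crops the reconstruction to the
# stored size:
#
#     bpp, enc_time = compress_one_image(net, img_pad, "./streams", H, W, "kodim01")
#     x_hat, dec_time = decompress_one_image(net, "./streams", "kodim01")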
349 |
350 |
351 | def test_model(test_dataloader, net, logger_test, save_dir, epoch):
352 | net.eval()
353 | device = next(net.parameters()).device
354 |
355 | avg_psnr = AverageMeter()
356 | avg_ms_ssim = AverageMeter()
357 | avg_bpp = AverageMeter()
358 | avg_enc_time = AverageMeter()
359 | avg_dec_time = AverageMeter()
360 |
361 | with torch.no_grad():
362 | for i, data in enumerate(test_dataloader):
363 | img, names = data
364 | img = img.to(device)
365 | B, C, H, W = img.shape
366 |             if H > 1000 or W > 1000:
367 | img = F.interpolate(img, size=512, mode='bilinear', align_corners=False)
368 | B, C, H, W = img.shape
369 | pad_h = 0
370 | pad_w = 0
371 | if H % 64 != 0:
372 | pad_h = 64 * (H // 64 + 1) - H
373 | if W % 64 != 0:
374 | pad_w = 64 * (W // 64 + 1) - W
375 | img_pad = F.pad(img, (0, pad_w, 0, pad_h), mode='constant', value=0)
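            # Worked example of the padding above (hypothetical size): H = 500 gives
            # pad_h = 64 * (500 // 64 + 1) - 500 = 12, so the padded height becomes 512,
            # keeping the input aligned to the 64-pixel multiple the transforms expect.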
376 | # warmup GPU
377 | if i == 0:
378 | bpp, enc_time = compress_one_image(model=net, x=img_pad, stream_path=save_dir, H=H, W=W, img_name=str(i))
379 | # avoid resolution leakage
380 | net.update_resolutions(16, 16)
381 | bpp, enc_time = compress_one_image(model=net, x=img_pad, stream_path=save_dir, H=H, W=W, img_name=str(i))
382 | # avoid resolution leakage
383 | net.update_resolutions(16, 16)
384 | x_hat, dec_time = decompress_one_image(model=net, stream_path=save_dir, img_name=str(i))
385 | rec = torch2img(x_hat)
386 | img = torch2img(img)
387 | # img.save(os.path.join(save_dir, '%03d_gt.png' % i))
388 | rec.save(os.path.join(save_dir, names[0]) + ".png")
389 | p, m = compute_metrics2(rec, img)
390 | avg_psnr.update(p)
391 | avg_ms_ssim.update(m)
392 | avg_bpp.update(bpp)
393 | avg_enc_time.update(enc_time)
394 | avg_dec_time.update(dec_time)
395 | logger_test.info(
396 | f"Image[{i}] | "
397 |                 f"Bpp: {bpp:.4f} | "
398 | f"PSNR: {p:.4f} | "
399 | f"MS-SSIM: {m:.4f} | "
400 | f"Encoding Latency: {enc_time:.4f} | "
401 | f"Decoding Latency: {dec_time:.4f}"
402 | )
403 | logger_test.info(
404 | f"Epoch:[{epoch}] | "
405 | f"Avg Bpp: {avg_bpp.avg:.4f} | "
406 | f"Avg PSNR: {avg_psnr.avg:.4f} | "
407 | f"Avg MS-SSIM: {avg_ms_ssim.avg:.4f} | "
408 |         f"Avg Encoding Latency: {avg_enc_time.avg:.4f} | "
409 |         f"Avg Decoding Latency: {avg_dec_time.avg:.4f}"
410 | )
411 |
412 |
413 |
--------------------------------------------------------------------------------
/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.relu1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.relu2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.relu3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.relu1(self.bn1(self.conv1(x)))
46 | out = self.relu2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.relu3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x[:1], key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 | return x.squeeze(0)
92 |
93 |
94 | class ModifiedResNet(nn.Module):
95 | """
96 | A ResNet class that is similar to torchvision's but contains the following changes:
97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99 | - The final pooling layer is a QKV attention instead of an average pool
100 | """
101 |
102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103 | super().__init__()
104 | self.output_dim = output_dim
105 | self.input_resolution = input_resolution
106 |
107 | # the 3-layer stem
108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109 | self.bn1 = nn.BatchNorm2d(width // 2)
110 | self.relu1 = nn.ReLU(inplace=True)
111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112 | self.bn2 = nn.BatchNorm2d(width // 2)
113 | self.relu2 = nn.ReLU(inplace=True)
114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115 | self.bn3 = nn.BatchNorm2d(width)
116 | self.relu3 = nn.ReLU(inplace=True)
117 | self.avgpool = nn.AvgPool2d(2)
118 |
119 | # residual layers
120 | self._inplanes = width # this is a *mutable* variable used during construction
121 | self.layer1 = self._make_layer(width, layers[0])
122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125 |
126 | embed_dim = width * 32 # the ResNet feature dimension
127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128 |
129 | def _make_layer(self, planes, blocks, stride=1):
130 | layers = [Bottleneck(self._inplanes, planes, stride)]
131 |
132 | self._inplanes = planes * Bottleneck.expansion
133 | for _ in range(1, blocks):
134 | layers.append(Bottleneck(self._inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | def stem(x):
140 | x = self.relu1(self.bn1(self.conv1(x)))
141 | x = self.relu2(self.bn2(self.conv2(x)))
142 | x = self.relu3(self.bn3(self.conv3(x)))
143 | x = self.avgpool(x)
144 | return x
145 |
146 | x = x.type(self.conv1.weight.dtype)
147 | x = stem(x)
148 | x = self.layer1(x)
149 | x = self.layer2(x)
150 | x = self.layer3(x)
151 | x = self.layer4(x)
152 | x = self.attnpool(x)
153 |
154 | return x
155 |
156 |
157 | class LayerNorm(nn.LayerNorm):
158 | """Subclass torch's LayerNorm to handle fp16."""
159 |
160 | def forward(self, x: torch.Tensor):
161 | orig_type = x.dtype
162 | ret = super().forward(x.type(torch.float32))
163 | return ret.type(orig_type)
164 |
165 |
166 | class QuickGELU(nn.Module):
167 | def forward(self, x: torch.Tensor):
168 | return x * torch.sigmoid(1.702 * x)
169 |
170 |
171 | class ResidualAttentionBlock(nn.Module):
172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173 | super().__init__()
174 |
175 | self.attn = nn.MultiheadAttention(d_model, n_head)
176 | self.ln_1 = LayerNorm(d_model)
177 | self.mlp = nn.Sequential(OrderedDict([
178 | ("c_fc", nn.Linear(d_model, d_model * 4)),
179 | ("gelu", QuickGELU()),
180 | ("c_proj", nn.Linear(d_model * 4, d_model))
181 | ]))
182 | self.ln_2 = LayerNorm(d_model)
183 | self.attn_mask = attn_mask
184 |
185 | def attention(self, x: torch.Tensor):
186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188 |
189 | def forward(self, x: torch.Tensor):
190 | x = x + self.attention(self.ln_1(x))
191 | x = x + self.mlp(self.ln_2(x))
192 | return x
193 |
194 |
195 | class Transformer(nn.Module):
196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197 | super().__init__()
198 | self.width = width
199 | self.layers = layers
200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201 |
202 | def forward(self, x: torch.Tensor):
203 | return self.resblocks(x)
204 |
205 |
206 | class VisionTransformer(nn.Module):
207 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208 | super().__init__()
209 | self.input_resolution = input_resolution
210 | self.output_dim = output_dim
211 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212 |
213 | scale = width ** -0.5
214 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
215 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216 | self.ln_pre = LayerNorm(width)
217 |
218 | self.transformer = Transformer(width, layers, heads)
219 |
220 | self.ln_post = LayerNorm(width)
221 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222 |
223 | def forward(self, x: torch.Tensor):
224 | x = self.conv1(x) # shape = [*, width, grid, grid]
225 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228 | x = x + self.positional_embedding.to(x.dtype)
229 | x = self.ln_pre(x)
230 |
231 | x = x.permute(1, 0, 2) # NLD -> LND
232 | x = self.transformer(x)
233 | x = x.permute(1, 0, 2) # LND -> NLD
234 |
235 | x = self.ln_post(x[:, 0, :])
236 |
237 | if self.proj is not None:
238 | x = x @ self.proj
239 |
240 | return x
241 |
242 |
243 | class CLIP(nn.Module):
244 | def __init__(self,
245 | embed_dim: int,
246 | # vision
247 | image_resolution: int,
248 | vision_layers: Union[Tuple[int, int, int, int], int],
249 | vision_width: int,
250 | vision_patch_size: int,
251 | # text
252 | context_length: int,
253 | vocab_size: int,
254 | transformer_width: int,
255 | transformer_heads: int,
256 | transformer_layers: int
257 | ):
258 | super().__init__()
259 |
260 | self.context_length = context_length
261 |
262 | if isinstance(vision_layers, (tuple, list)):
263 | vision_heads = vision_width * 32 // 64
264 | self.visual = ModifiedResNet(
265 | layers=vision_layers,
266 | output_dim=embed_dim,
267 | heads=vision_heads,
268 | input_resolution=image_resolution,
269 | width=vision_width
270 | )
271 | else:
272 | vision_heads = vision_width // 64
273 | self.visual = VisionTransformer(
274 | input_resolution=image_resolution,
275 | patch_size=vision_patch_size,
276 | width=vision_width,
277 | layers=vision_layers,
278 | heads=vision_heads,
279 | output_dim=embed_dim
280 | )
281 |
282 | self.transformer = Transformer(
283 | width=transformer_width,
284 | layers=transformer_layers,
285 | heads=transformer_heads,
286 | attn_mask=self.build_attention_mask()
287 | )
288 |
289 | self.vocab_size = vocab_size
290 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
291 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
292 | self.ln_final = LayerNorm(transformer_width)
293 |
294 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
295 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
296 |
297 | self.initialize_parameters()
298 |
299 | def initialize_parameters(self):
300 | nn.init.normal_(self.token_embedding.weight, std=0.02)
301 | nn.init.normal_(self.positional_embedding, std=0.01)
302 |
303 | if isinstance(self.visual, ModifiedResNet):
304 | if self.visual.attnpool is not None:
305 | std = self.visual.attnpool.c_proj.in_features ** -0.5
306 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
307 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
308 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
309 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
310 |
311 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
312 | for name, param in resnet_block.named_parameters():
313 | if name.endswith("bn3.weight"):
314 | nn.init.zeros_(param)
315 |
316 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
317 | attn_std = self.transformer.width ** -0.5
318 | fc_std = (2 * self.transformer.width) ** -0.5
319 | for block in self.transformer.resblocks:
320 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
321 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
322 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
323 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
324 |
325 | if self.text_projection is not None:
326 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
327 |
328 | def build_attention_mask(self):
329 | # lazily create causal attention mask, with full attention between the vision tokens
330 | # pytorch uses additive attention mask; fill with -inf
331 | mask = torch.empty(self.context_length, self.context_length)
332 | mask.fill_(float("-inf"))
333 | mask.triu_(1) # zero out the lower diagonal
334 | return mask
335 |
336 | @property
337 | def dtype(self):
338 | return self.visual.conv1.weight.dtype
339 |
340 | def encode_image(self, image):
341 | return self.visual(image.type(self.dtype))
342 |
343 | def encode_text(self, text):
344 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
345 |
346 | x = x + self.positional_embedding.type(self.dtype)
347 | x = x.permute(1, 0, 2) # NLD -> LND
348 | x = self.transformer(x)
349 | x = x.permute(1, 0, 2) # LND -> NLD
350 | x = self.ln_final(x).type(self.dtype)
351 |
352 | # x.shape = [batch_size, n_ctx, transformer.width]
353 | # take features from the eot embedding (eot_token is the highest number in each sequence)
354 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
355 |
356 | return x
357 |
358 | def forward(self, image, text):
359 | image_features = self.encode_image(image)
360 | text_features = self.encode_text(text)
361 |
362 | # normalized features
363 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
364 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
365 |
366 | # cosine similarity as logits
367 | logit_scale = self.logit_scale.exp()
368 | logits_per_image = logit_scale * image_features @ text_features.t()
369 | logits_per_text = logits_per_image.t()
370 |
371 | # shape = [global_batch_size, global_batch_size]
372 | return logits_per_image, logits_per_text
373 |
374 |
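# Illustrative sketch (model/images/text_tokens are hypothetical placeholders): the logits
# returned by CLIP.forward are temperature-scaled cosine similarities, so image-to-text
# probabilities are obtained with a softmax over the text axis:
#
#     logits_per_image, logits_per_text = model(images, text_tokens)
#     probs = logits_per_image.softmax(dim=-1)   # [num_images, num_texts]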
375 | def convert_weights(model: nn.Module):
376 | """Convert applicable model parameters to fp16"""
377 |
378 | def _convert_weights_to_fp16(l):
379 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
380 | l.weight.data = l.weight.data.half()
381 | if l.bias is not None:
382 | l.bias.data = l.bias.data.half()
383 |
384 | if isinstance(l, nn.MultiheadAttention):
385 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
386 | tensor = getattr(l, attr)
387 | if tensor is not None:
388 | tensor.data = tensor.data.half()
389 |
390 | for name in ["text_projection", "proj"]:
391 | if hasattr(l, name):
392 | attr = getattr(l, name)
393 | if attr is not None:
394 | attr.data = attr.data.half()
395 |
396 | model.apply(_convert_weights_to_fp16)
397 |
398 |
399 | def build_model(state_dict: dict):
400 | vit = "visual.proj" in state_dict
401 |
402 | if vit:
403 | vision_width = state_dict["visual.conv1.weight"].shape[0]
404 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
405 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
406 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
407 | image_resolution = vision_patch_size * grid_size
408 | else:
409 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
410 | vision_layers = tuple(counts)
411 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
412 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
413 | vision_patch_size = None
414 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
415 | image_resolution = output_width * 32
416 |
417 | embed_dim = state_dict["text_projection"].shape[1]
418 | context_length = state_dict["positional_embedding"].shape[0]
419 | vocab_size = state_dict["token_embedding.weight"].shape[0]
420 | transformer_width = state_dict["ln_final.weight"].shape[0]
421 | transformer_heads = transformer_width // 64
422 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
423 |
424 | model = CLIP(
425 | embed_dim,
426 | image_resolution, vision_layers, vision_width, vision_patch_size,
427 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
428 | )
429 |
430 | for key in ["input_resolution", "context_length", "vocab_size"]:
431 | if key in state_dict:
432 | del state_dict[key]
433 |
434 | convert_weights(model)
435 | model.load_state_dict(state_dict)
436 | return model.eval()
437 |
--------------------------------------------------------------------------------
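A minimal usage sketch of the CLIP model defined above (not part of the repository; the checkpoint path is hypothetical). It assumes the checkpoint is a TorchScript archive, as the official CLIP releases are, so the weights are recovered via .state_dict() before calling build_model; in practice the text ids would come from clip.tokenize rather than random integers.

import torch
from clip.model import build_model

state_dict = torch.jit.load("ViT-B-32.pt", map_location="cpu").state_dict()  # hypothetical path
model = build_model(state_dict).float()  # .float(): convert_weights() casts applicable weights to fp16

image = torch.randn(1, 3, model.visual.input_resolution, model.visual.input_resolution)
text = torch.randint(0, model.vocab_size, (1, model.context_length))  # placeholder; use clip.tokenize in practice
with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)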
/models/mlicpp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import time
5 | from compressai.models import CompressionModel
6 | from compressai.ops import ste_round
7 | from compressai.ans import BufferedRansEncoder, RansDecoder
8 | from utils.func import update_registered_buffers, get_scale_table
9 | from utils.ckbd import *
10 | from modules.transform import *
11 |
12 |
13 | class MLICPlusPlus(CompressionModel):
14 | def __init__(self, config, **kwargs):
15 | super().__init__(config.N, **kwargs)
16 | N = config.N
17 | M = config.M
18 | context_window = config.context_window
19 | slice_num = config.slice_num
20 | slice_ch = M // slice_num
21 | assert slice_ch * slice_num == M
22 |
23 | self.N = N
24 | self.M = M
25 | self.context_window = context_window
26 | self.slice_num = slice_num
27 | self.slice_ch = slice_ch
28 |
29 | self.g_a = AnalysisTransform(N=N, M=M)
30 | self.g_s = SynthesisTransform(N=N, M=M)
31 |
32 | self.h_a = HyperAnalysis(M=M, N=N)
33 | self.h_s = HyperSynthesis(M=M, N=N)
34 |
35 | # Gaussian Conditional
36 | self.gaussian_conditional = GaussianConditional(None)
37 |
38 | self.local_context = nn.ModuleList(
39 | LocalContext(dim=slice_ch)
40 | for _ in range(slice_num)
41 | )
42 |
43 | self.channel_context = nn.ModuleList(
44 | ChannelContext(in_dim=slice_ch * i, out_dim=slice_ch) if i else None
45 | for i in range(slice_num)
46 | )
47 |
48 | # Global Reference for non-anchors
49 | self.global_inter_context = nn.ModuleList(
50 | LinearGlobalInterContext(dim=slice_ch * i, out_dim=slice_ch * 2, num_heads=slice_ch * i // 32) if i else None
51 | for i in range(slice_num)
52 | )
53 | self.global_intra_context = nn.ModuleList(
54 | LinearGlobalIntraContext(dim=slice_ch) if i else None
55 | for i in range(slice_num)
56 | )
57 | self.entropy_parameters_anchor = nn.ModuleList(
58 | EntropyParameters(in_dim=M * 2 + slice_ch * 6, out_dim=slice_ch * 2)
59 | if i else EntropyParameters(in_dim=M * 2, out_dim=slice_ch * 2)
60 | for i in range(slice_num)
61 | )
62 | self.entropy_parameters_nonanchor = nn.ModuleList(
63 | EntropyParameters(in_dim=M * 2 + slice_ch * 10, out_dim=slice_ch * 2)
64 | if i else EntropyParameters(in_dim=M * 2 + slice_ch * 2, out_dim=slice_ch * 2)
65 | for i in range(slice_num)
66 | )
67 |
68 | # Latent Residual Prediction
69 | self.lrp_anchor = nn.ModuleList(
70 | LatentResidualPrediction(in_dim=M + (i + 1) * slice_ch, out_dim=slice_ch)
71 | for i in range(slice_num)
72 | )
73 | self.lrp_nonanchor = nn.ModuleList(
74 | LatentResidualPrediction(in_dim=M + (i + 1) * slice_ch, out_dim=slice_ch)
75 | for i in range(slice_num)
76 | )
77 |
78 |
79 | def forward(self, x):
80 | """
81 | Uses a checkerboard context model with masked attention,
82 | which divides y into anchor and non-anchor parts;
83 | the non-anchor part uses the anchor part as spatial context.
84 | A channel-wise entropy model is used as well.
85 | Args:
86 | x: [B, 3, H, W]
87 | Returns:
88 | x_hat: [B, 3, H, W]
89 | y_likelihoods: [B, M, H // 16, W // 16]
90 | z_likelihoods: [B, N, H // 64, W // 64]
91 | likelihoods: y_likelihoods, z_likelihoods
92 | """
93 | self.update_resolutions(x.size(2) // 16, x.size(3) // 16)
94 | y = self.g_a(x)
95 | z = self.h_a(y)
96 | _, z_likelihoods = self.entropy_bottleneck(z)
97 | z_offset = self.entropy_bottleneck._get_medians()
98 | z_hat = ste_round(z - z_offset) + z_offset
99 |
100 | # Hyper-parameters
101 | hyper_params = self.h_s(z_hat)
102 | hyper_scales, hyper_means = hyper_params.chunk(2, 1)
103 |
104 | y_slices = y.chunk(self.slice_num, dim=1)
105 | y_hat_slices = []
106 | y_likelihoods = []
107 | for idx, y_slice in enumerate(y_slices):
108 | slice_anchor, slice_nonanchor = ckbd_split(y_slice)
109 | if idx == 0:
110 | # Anchor
111 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params)
112 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
113 | # split means and scales of anchor
114 | scales_anchor = ckbd_anchor(scales_anchor)
115 | means_anchor = ckbd_anchor(means_anchor)
116 | # round anchor
117 | slice_anchor = ste_round(slice_anchor - means_anchor) + means_anchor
118 | # predict residuals caused by rounding
119 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
120 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
121 | # Non-anchor
122 | # local_ctx: [B, 2 * C, H, W]
123 | local_ctx = self.local_context[idx](slice_anchor)
124 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1))
125 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
126 | # split means and scales of nonanchor
127 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
128 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
129 | # merge means and scales of anchor and nonanchor
130 | scales_slice = ckbd_merge(scales_anchor, scales_nonanchor)
131 | means_slice = ckbd_merge(means_anchor, means_nonanchor)
132 | _, y_slice_likelihoods = self.gaussian_conditional(y_slice, scales_slice, means_slice)
133 | # round slice_nonanchor
134 | slice_nonanchor = ste_round(slice_nonanchor - means_nonanchor) + means_nonanchor
135 | y_hat_slice = slice_anchor + slice_nonanchor
136 | # predict residuals caused by rounding
137 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [y_hat_slice]), dim=1))
138 | y_hat_slice = y_hat_slice + ckbd_nonanchor(lrp_nonanchor)
139 | y_hat_slices.append(y_hat_slice)
140 | y_likelihoods.append(y_slice_likelihoods)
141 |
142 | else:
143 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1))
144 | channel_ctx = self.channel_context[idx](torch.cat(y_hat_slices, dim=1))
145 | # Anchor (uses global inter context, channel context and hyper params)
146 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1))
147 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
148 | # split means and scales of anchor
149 | scales_anchor = ckbd_anchor(scales_anchor)
150 | means_anchor = ckbd_anchor(means_anchor)
151 | # round anchor
152 | slice_anchor = ste_round(slice_anchor - means_anchor) + means_anchor
153 | # predict residuals caused by rounding
154 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
155 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
156 | # Non-anchor (uses local context, global intra/inter context, channel context and hyper params)
157 | global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor)
158 | # local_ctx: [B, 2 * C, H, W]
159 | local_ctx = self.local_context[idx](slice_anchor)
160 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1))
161 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
162 | # split means and scales of nonanchor
163 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
164 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
165 | # merge means and scales of anchor and nonanchor
166 | scales_slice = ckbd_merge(scales_anchor, scales_nonanchor)
167 | means_slice = ckbd_merge(means_anchor, means_nonanchor)
168 | _, y_slice_likelihoods = self.gaussian_conditional(y_slice, scales_slice, means_slice)
169 | # round slice_nonanchor
170 | slice_nonanchor = ste_round(slice_nonanchor - means_nonanchor) + means_nonanchor
171 | y_hat_slice = slice_anchor + slice_nonanchor
172 | # predict residuals caused by rounding
173 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [y_hat_slice]), dim=1))
174 | y_hat_slice = y_hat_slice + ckbd_nonanchor(lrp_nonanchor)
175 | y_hat_slices.append(y_hat_slice)
176 | y_likelihoods.append(y_slice_likelihoods)
177 |
178 | y_hat = torch.cat(y_hat_slices, dim=1)
179 | y_likelihoods = torch.cat(y_likelihoods, dim=1)
180 | x_hat = self.g_s(y_hat)
181 |
182 | return {
183 | "x_hat": x_hat,
184 | "likelihoods": {"y_likelihoods": y_likelihoods, "z_likelihoods": z_likelihoods}
185 | }
186 |
187 | def update_resolutions(self, H, W):
188 | for i in range(len(self.global_intra_context)):
189 | if i == 0:
190 | self.local_context[i].update_resolution(H, W, next(self.parameters()).device, mask=None)
191 | else:
192 | self.local_context[i].update_resolution(H, W, next(self.parameters()).device, mask=self.local_context[0].attn_mask)
193 |
194 | def compress(self, x):
195 | torch.cuda.synchronize()
196 | start_time = time.time()
197 | self.update_resolutions(x.size(2) // 16, x.size(3) // 16)
198 | y = self.g_a(x)
199 | z = self.h_a(y)
200 | z_strings = self.entropy_bottleneck.compress(z)
201 | z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
202 | hyper_params = self.h_s(z_hat)
203 | hyper_scales, hyper_means = hyper_params.chunk(2, 1)
204 | y_slices = y.chunk(self.slice_num, dim=1)
205 | y_hat_slices = []
206 |
207 | cdf = self.gaussian_conditional.quantized_cdf.tolist()
208 | cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist()
209 | offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist()
210 | encoder = BufferedRansEncoder()
211 | symbols_list = []
212 | indexes_list = []
213 | y_strings = []
214 |
215 | for idx, y_slice in enumerate(y_slices):
216 | slice_anchor, slice_nonanchor = ckbd_split(y_slice)
217 | if idx == 0:
218 | # Anchor
219 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params)
220 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
221 | # split means and scales of anchor
222 | scales_anchor = ckbd_anchor(scales_anchor)
223 | means_anchor = ckbd_anchor(means_anchor)
224 | # round and compress anchor
225 | slice_anchor = compress_anchor(self.gaussian_conditional, slice_anchor, scales_anchor, means_anchor, symbols_list, indexes_list)
226 | # predict residuals caused by round
227 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
228 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
229 | # Non-anchor
230 | # local_ctx: [B,2 * C, H, W]
231 | local_ctx = self.local_context[idx](slice_anchor)
232 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1))
233 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
234 | # split means and scales of nonanchor
235 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
236 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
237 | # round and compress nonanchor
238 | slice_nonanchor = compress_nonanchor(self.gaussian_conditional, slice_nonanchor, scales_nonanchor, means_nonanchor, symbols_list, indexes_list)
239 | # predict residuals caused by round
240 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1))
241 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor)
242 | y_hat_slices.append(slice_nonanchor + slice_anchor)
243 |
244 | else:
245 | # Anchor
246 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1))
247 | channel_ctx = self.channel_context[idx](torch.cat(y_hat_slices, dim=1))
248 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1))
249 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
250 | # split means and scales of anchor
251 | scales_anchor = ckbd_anchor(scales_anchor)
252 | means_anchor = ckbd_anchor(means_anchor)
253 | # round and compress anchor
254 | slice_anchor = compress_anchor(self.gaussian_conditional, slice_anchor, scales_anchor, means_anchor, symbols_list, indexes_list)
255 | # predict residuals caused by round
256 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
257 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
258 | # Non-anchor
259 | global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor)
260 | # local_ctx: [B,2 * C, H, W]
261 | local_ctx = self.local_context[idx](slice_anchor)
262 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1))
263 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
264 | # split means and scales of nonanchor
265 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
266 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
267 | # round and compress nonanchor
268 | slice_nonanchor = compress_nonanchor(self.gaussian_conditional, slice_nonanchor, scales_nonanchor, means_nonanchor, symbols_list, indexes_list)
269 | # predict residuals caused by round
270 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1))
271 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor)
272 | y_hat_slices.append(slice_nonanchor + slice_anchor)
273 |
274 | encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets)
275 | y_string = encoder.flush()
276 | y_strings.append(y_string)
277 | torch.cuda.synchronize()
278 | end_time = time.time()
279 |
280 | cost_time = end_time - start_time
281 | return {
282 | "strings": [y_strings, z_strings],
283 | "shape": z.size()[-2:],
284 | "cost_time": cost_time
285 | }
286 |
287 | def decompress(self, strings, shape):
288 | torch.cuda.synchronize()
289 | start_time = time.time()
290 | y_strings = strings[0][0]
291 | z_strings = strings[1]
292 | z_hat = self.entropy_bottleneck.decompress(z_strings, shape)
293 | self.update_resolutions(z_hat.size(2) * 4, z_hat.size(3) * 4)
294 | hyper_params = self.h_s(z_hat)
295 | hyper_scales, hyper_means = hyper_params.chunk(2, 1)
296 | y_hat_slices = []
297 |
298 | cdf = self.gaussian_conditional.quantized_cdf.tolist()
299 | cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist()
300 | offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist()
301 | decoder = RansDecoder()
302 | decoder.set_stream(y_strings)
303 |
304 | for idx in range(self.slice_num):
305 | if idx == 0:
306 | # Anchor
307 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params)
308 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
309 | # split means and scales of anchor
310 | scales_anchor = ckbd_anchor(scales_anchor)
311 | means_anchor = ckbd_anchor(means_anchor)
312 | # decompress anchor
313 | slice_anchor = decompress_anchor(self.gaussian_conditional, scales_anchor, means_anchor, decoder, cdf, cdf_lengths, offsets)
314 | # predict residuals caused by round
315 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
316 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
317 | # Non-anchor
318 | # local_ctx: [B,2 * C, H, W]
319 | local_ctx = self.local_context[idx](slice_anchor)
320 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1))
321 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
322 | # split means and scales of nonanchor
323 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
324 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
325 | # decompress non-anchor
326 | slice_nonanchor = decompress_nonanchor(self.gaussian_conditional, scales_nonanchor, means_nonanchor, decoder, cdf, cdf_lengths, offsets)
327 | # predict residuals caused by round
328 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1))
329 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor)
330 | y_hat_slices.append(slice_nonanchor + slice_anchor)
331 |
332 | else:
333 | # Anchor
334 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1))
335 | channel_ctx = self.channel_context[idx](torch.cat(y_hat_slices, dim=1))
336 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1))
337 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
338 | # split means and scales of anchor
339 | scales_anchor = ckbd_anchor(scales_anchor)
340 | means_anchor = ckbd_anchor(means_anchor)
341 | # decompress anchor
342 | slice_anchor = decompress_anchor(self.gaussian_conditional, scales_anchor, means_anchor, decoder, cdf, cdf_lengths, offsets)
343 | # predict residuals caused by round
344 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
345 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
346 | # Non-anchor
348 | global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor)
349 | # local_ctx: [B,2 * C, H, W]
350 | local_ctx = self.local_context[idx](slice_anchor)
351 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1))
352 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
353 | # split means and scales of nonanchor
354 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
355 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
356 | # decompress non-anchor
357 | slice_nonanchor = decompress_nonanchor(self.gaussian_conditional, scales_nonanchor, means_nonanchor, decoder, cdf, cdf_lengths, offsets)
358 | # predict residuals caused by round
359 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1))
360 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor)
361 | y_hat_slices.append(slice_nonanchor + slice_anchor)
362 |
363 | y_hat = torch.cat(y_hat_slices, dim=1)
364 | x_hat = self.g_s(y_hat)
365 | torch.cuda.synchronize()
366 | end_time = time.time()
367 |
368 | cost_time = end_time - start_time
369 |
370 | return {
371 | "x_hat": x_hat,
372 | "cost_time": cost_time
373 | }
374 |
375 | def load_state_dict(self, state_dict):
376 | update_registered_buffers(
377 | self.gaussian_conditional,
378 | "gaussian_conditional",
379 | ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
380 | state_dict,
381 | )
382 | super().load_state_dict(state_dict)
383 |
384 | def update(self, scale_table=None, force=False):
385 | if scale_table is None:
386 | scale_table = get_scale_table()
387 | updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
388 | updated |= super().update(force=force)
389 | return updated
390 |
--------------------------------------------------------------------------------
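A minimal, self-contained sketch of how MLICPlusPlus is typically driven. The config values below are illustrative assumptions (MLIC++-style sizes), not necessarily this repository's settings; see config/config.py for the real ones.

import torch
from types import SimpleNamespace
from models.mlicpp import MLICPlusPlus

config = SimpleNamespace(N=192, M=320, context_window=5, slice_num=10)  # assumed values
model = MLICPlusPlus(config).cuda().eval()
model.update(force=True)  # build the entropy coders' CDF tables before compress()/decompress()

x = torch.rand(1, 3, 256, 256).cuda()  # spatial size should be a multiple of 64
out = model(x)                         # training/eval path: x_hat plus y/z likelihoods
enc = model.compress(x)                # {"strings": [y_strings, z_strings], "shape", "cost_time"}
dec = model.decompress(enc["strings"], enc["shape"])
x_hat = dec["x_hat"].clamp(0, 1)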
/models/ug_mlicpp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import time
5 | from compressai.models import CompressionModel
6 | from compressai.ops import ste_round
7 | from compressai.ans import BufferedRansEncoder, RansDecoder
8 |
9 | from models.fourier_cond import GlobalConditioning
10 | from utils.func import update_registered_buffers, get_scale_table
11 | from utils.ckbd import *
12 | from modules.transform import *
13 |
14 |
15 | class UG_MLICPlusPlus(CompressionModel):
16 | def __init__(self, config, **kwargs):
17 | super().__init__(config.N, **kwargs)
18 |
19 | self.beta = 0
20 | N = config.N
21 | M = config.M
22 | context_window = config.context_window
23 | slice_num = config.slice_num
24 | slice_ch = M // slice_num
25 | assert slice_ch * slice_num == M
26 |
27 | self.N = N
28 | self.M = M
29 | self.context_window = context_window
30 | self.slice_num = slice_num
31 | self.slice_ch = slice_ch
32 |
33 | self.g_a = AnalysisTransform(N=N, M=M)
34 |
35 | self.h_a = HyperAnalysis(M=M, N=N)
36 | self.h_s = HyperSynthesis(M=M, N=N)
37 |
38 | # Conditional synthesis: decoder g_s conditioned on beta via Fourier features
39 | self.global_cond = GlobalConditioning()
40 | self.g_s = ConditionalSynthesisTransform(N=N, M=M)
41 |
42 | # Gaussian Conditional
43 | self.gaussian_conditional = GaussianConditional(None)
44 |
45 | self.local_context = nn.ModuleList(
46 | LocalContext(dim=slice_ch)
47 | for _ in range(slice_num)
48 | )
49 |
50 | self.channel_context = nn.ModuleList(
51 | ChannelContext(in_dim=slice_ch * i, out_dim=slice_ch) if i else None
52 | for i in range(slice_num)
53 | )
54 |
55 | # Global Reference for non-anchors
56 | self.global_inter_context = nn.ModuleList(
57 | LinearGlobalInterContext(dim=slice_ch * i, out_dim=slice_ch * 2, num_heads=slice_ch * i // 32) if i else None
58 | for i in range(slice_num)
59 | )
60 | self.global_intra_context = nn.ModuleList(
61 | LinearGlobalIntraContext(dim=slice_ch) if i else None
62 | for i in range(slice_num)
63 | )
64 | self.entropy_parameters_anchor = nn.ModuleList(
65 | EntropyParameters(in_dim=M * 2 + slice_ch * 6, out_dim=slice_ch * 2)
66 | if i else EntropyParameters(in_dim=M * 2, out_dim=slice_ch * 2)
67 | for i in range(slice_num)
68 | )
69 | self.entropy_parameters_nonanchor = nn.ModuleList(
70 | EntropyParameters(in_dim=M * 2 + slice_ch * 10, out_dim=slice_ch * 2)
71 | if i else EntropyParameters(in_dim=M * 2 + slice_ch * 2, out_dim=slice_ch * 2)
72 | for i in range(slice_num)
73 | )
74 |
75 | # Latent Residual Prediction
76 | self.lrp_anchor = nn.ModuleList(
77 | LatentResidualPrediction(in_dim=M + (i + 1) * slice_ch, out_dim=slice_ch)
78 | for i in range(slice_num)
79 | )
80 | self.lrp_nonanchor = nn.ModuleList(
81 | LatentResidualPrediction(in_dim=M + (i + 1) * slice_ch, out_dim=slice_ch)
82 | for i in range(slice_num)
83 | )
84 |
85 |
86 | def forward(self, x):
87 | """
88 | Uses a checkerboard context model with masked attention,
89 | which divides y into anchor and non-anchor parts;
90 | the non-anchor part uses the anchor part as spatial context.
91 | A channel-wise entropy model is used as well.
92 | Args:
93 | x: tuple of (image [B, 3, H, W], betas [B])
94 | Returns:
95 | x_hat: [B, 3, H, W]
96 | y_likelihoods: [B, M, H // 16, W // 16]
97 | z_likelihoods: [B, N, H // 64, W // 64]
98 | likelihoods: y_likelihoods, z_likelihoods
99 | """
100 | x, betas = x
101 | fourier_features_mlp = self.global_cond(betas)
102 | self.update_resolutions(x.size(2) // 16, x.size(3) // 16)
103 | y = self.g_a(x)
104 | z = self.h_a(y)
105 | _, z_likelihoods = self.entropy_bottleneck(z)
106 | z_offset = self.entropy_bottleneck._get_medians()
107 | z_hat = ste_round(z - z_offset) + z_offset
108 |
109 | # Hyper-parameters
110 | hyper_params = self.h_s(z_hat)
111 | hyper_scales, hyper_means = hyper_params.chunk(2, 1)
112 |
113 | y_slices = y.chunk(self.slice_num, dim=1)
114 | y_hat_slices = []
115 | y_likelihoods = []
116 | for idx, y_slice in enumerate(y_slices):
117 | slice_anchor, slice_nonanchor = ckbd_split(y_slice)
118 | if idx == 0:
119 | # Anchor
120 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params)
121 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
122 | # split means and scales of anchor
123 | scales_anchor = ckbd_anchor(scales_anchor)
124 | means_anchor = ckbd_anchor(means_anchor)
125 | # round anchor
126 | slice_anchor = ste_round(slice_anchor - means_anchor) + means_anchor
127 | # predict residuals caused by rounding
128 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
129 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
130 | # Non-anchor
131 | # local_ctx: [B, 2 * C, H, W]
132 | local_ctx = self.local_context[idx](slice_anchor)
133 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1))
134 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
135 | # split means and scales of nonanchor
136 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
137 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
138 | # merge means and scales of anchor and nonanchor
139 | scales_slice = ckbd_merge(scales_anchor, scales_nonanchor)
140 | means_slice = ckbd_merge(means_anchor, means_nonanchor)
141 | _, y_slice_likelihoods = self.gaussian_conditional(y_slice, scales_slice, means_slice)
142 | # round slice_nonanchor
143 | slice_nonanchor = ste_round(slice_nonanchor - means_nonanchor) + means_nonanchor
144 | y_hat_slice = slice_anchor + slice_nonanchor
145 | # predict residuals caused by rounding
146 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [y_hat_slice]), dim=1))
147 | y_hat_slice = y_hat_slice + ckbd_nonanchor(lrp_nonanchor)
148 | y_hat_slices.append(y_hat_slice)
149 | y_likelihoods.append(y_slice_likelihoods)
150 |
151 | else:
152 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1))
153 | channel_ctx = self.channel_context[idx](torch.cat(y_hat_slices, dim=1))
154 | # Anchor (uses global inter context, channel context and hyper params)
155 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1))
156 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
157 | # split means and scales of anchor
158 | scales_anchor = ckbd_anchor(scales_anchor)
159 | means_anchor = ckbd_anchor(means_anchor)
160 | # round anchor
161 | slice_anchor = ste_round(slice_anchor - means_anchor) + means_anchor
162 | # predict residuals caused by rounding
163 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
164 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
165 | # Non-anchor (uses local context, global intra/inter context, channel context and hyper params)
166 | global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor)
167 | # local_ctx: [B, 2 * C, H, W]
168 | local_ctx = self.local_context[idx](slice_anchor)
169 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1))
170 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
171 | # split means and scales of nonanchor
172 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
173 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
174 | # merge means and scales of anchor and nonanchor
175 | scales_slice = ckbd_merge(scales_anchor, scales_nonanchor)
176 | means_slice = ckbd_merge(means_anchor, means_nonanchor)
177 | _, y_slice_likelihoods = self.gaussian_conditional(y_slice, scales_slice, means_slice)
178 | # round slice_nonanchor
179 | slice_nonanchor = ste_round(slice_nonanchor - means_nonanchor) + means_nonanchor
180 | y_hat_slice = slice_anchor + slice_nonanchor
181 | # predict residuals caused by rounding
182 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [y_hat_slice]), dim=1))
183 | y_hat_slice = y_hat_slice + ckbd_nonanchor(lrp_nonanchor)
184 | y_hat_slices.append(y_hat_slice)
185 | y_likelihoods.append(y_slice_likelihoods)
186 |
187 | y_hat = torch.cat(y_hat_slices, dim=1)
188 | y_likelihoods = torch.cat(y_likelihoods, dim=1)
189 | x_hat = self.g_s((y_hat, fourier_features_mlp))
190 |
191 | return {
192 | "x_hat": x_hat,
193 | "likelihoods": {"y_likelihoods": y_likelihoods, "z_likelihoods": z_likelihoods}
194 | }
195 |
196 | def update_resolutions(self, H, W):
197 | for i in range(len(self.global_intra_context)):
198 | if i == 0:
199 | self.local_context[i].update_resolution(H, W, next(self.parameters()).device, mask=None)
200 | else:
201 | self.local_context[i].update_resolution(H, W, next(self.parameters()).device, mask=self.local_context[0].attn_mask)
202 |
203 | def compress(self, x):
204 | torch.cuda.synchronize()
205 | start_time = time.time()
206 | self.update_resolutions(x.size(2) // 16, x.size(3) // 16)
207 | y = self.g_a(x)
208 | z = self.h_a(y)
209 | z_strings = self.entropy_bottleneck.compress(z)
210 | z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
211 | hyper_params = self.h_s(z_hat)
212 | hyper_scales, hyper_means = hyper_params.chunk(2, 1)
213 | y_slices = y.chunk(self.slice_num, dim=1)
214 | y_hat_slices = []
215 |
216 | cdf = self.gaussian_conditional.quantized_cdf.tolist()
217 | cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist()
218 | offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist()
219 | encoder = BufferedRansEncoder()
220 | symbols_list = []
221 | indexes_list = []
222 | y_strings = []
223 |
224 | for idx, y_slice in enumerate(y_slices):
225 | slice_anchor, slice_nonanchor = ckbd_split(y_slice)
226 | if idx == 0:
227 | # Anchor
228 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params)
229 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
230 | # split means and scales of anchor
231 | scales_anchor = ckbd_anchor(scales_anchor)
232 | means_anchor = ckbd_anchor(means_anchor)
233 | # round and compress anchor
234 | slice_anchor = compress_anchor(self.gaussian_conditional, slice_anchor, scales_anchor, means_anchor, symbols_list, indexes_list)
235 | # predict residuals caused by round
236 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
237 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
238 | # Non-anchor
239 | # local_ctx: [B,2 * C, H, W]
240 | local_ctx = self.local_context[idx](slice_anchor)
241 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1))
242 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
243 | # split means and scales of nonanchor
244 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
245 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
246 | # round and compress nonanchor
247 | slice_nonanchor = compress_nonanchor(self.gaussian_conditional, slice_nonanchor, scales_nonanchor, means_nonanchor, symbols_list, indexes_list)
248 | # predict residuals caused by round
249 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1))
250 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor)
251 | y_hat_slices.append(slice_nonanchor + slice_anchor)
252 |
253 | else:
254 | # Anchor
255 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1))
256 | channel_ctx = self.channel_context[idx](torch.cat(y_hat_slices, dim=1))
257 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1))
258 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
259 | # split means and scales of anchor
260 | scales_anchor = ckbd_anchor(scales_anchor)
261 | means_anchor = ckbd_anchor(means_anchor)
262 | # round and compress anchor
263 | slice_anchor = compress_anchor(self.gaussian_conditional, slice_anchor, scales_anchor, means_anchor, symbols_list, indexes_list)
264 | # predict residuals caused by round
265 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
266 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
267 | # Non-anchor
268 | global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor)
269 | # local_ctx: [B,2 * C, H, W]
270 | local_ctx = self.local_context[idx](slice_anchor)
271 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1))
272 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
273 | # split means and scales of nonanchor
274 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
275 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
276 | # round and compress nonanchor
277 | slice_nonanchor = compress_nonanchor(self.gaussian_conditional, slice_nonanchor, scales_nonanchor, means_nonanchor, symbols_list, indexes_list)
278 | # predict residuals caused by round
279 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1))
280 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor)
281 | y_hat_slices.append(slice_nonanchor + slice_anchor)
282 |
283 | encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets)
284 | y_string = encoder.flush()
285 | y_strings.append(y_string)
286 | torch.cuda.synchronize()
287 | end_time = time.time()
288 |
289 | cost_time = end_time - start_time
290 | return {
291 | "strings": [y_strings, z_strings],
292 | "shape": z.size()[-2:],
293 | "cost_time": cost_time
294 | }
295 |
296 | def decompress(self, strings, shape, beta):
297 | torch.cuda.synchronize()
298 | start_time = time.time()
299 | y_strings = strings[0][0]
300 | z_strings = strings[1]
301 | z_hat = self.entropy_bottleneck.decompress(z_strings, shape)
302 | self.update_resolutions(z_hat.size(2) * 4, z_hat.size(3) * 4)
303 | hyper_params = self.h_s(z_hat)
304 | hyper_scales, hyper_means = hyper_params.chunk(2, 1)
305 | y_hat_slices = []
306 |
307 | cdf = self.gaussian_conditional.quantized_cdf.tolist()
308 | cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist()
309 | offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist()
310 | decoder = RansDecoder()
311 | decoder.set_stream(y_strings)
312 |
313 | for idx in range(self.slice_num):
314 | if idx == 0:
315 | # Anchor
316 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params)
317 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
318 | # split means and scales of anchor
319 | scales_anchor = ckbd_anchor(scales_anchor)
320 | means_anchor = ckbd_anchor(means_anchor)
321 | # decompress anchor
322 | slice_anchor = decompress_anchor(self.gaussian_conditional, scales_anchor, means_anchor, decoder, cdf, cdf_lengths, offsets)
323 | # predict residuals caused by round
324 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
325 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
326 | # Non-anchor
327 | # local_ctx: [B,2 * C, H, W]
328 | local_ctx = self.local_context[idx](slice_anchor)
329 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1))
330 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
331 | # split means and scales of nonanchor
332 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
333 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
334 | # decompress non-anchor
335 | slice_nonanchor = decompress_nonanchor(self.gaussian_conditional, scales_nonanchor, means_nonanchor, decoder, cdf, cdf_lengths, offsets)
336 | # predict residuals caused by round
337 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1))
338 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor)
339 | y_hat_slices.append(slice_nonanchor + slice_anchor)
340 |
341 | else:
342 | # Anchor
343 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1))
344 | channel_ctx = self.channel_context[idx](torch.cat(y_hat_slices, dim=1))
345 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1))
346 | scales_anchor, means_anchor = params_anchor.chunk(2, 1)
347 | # split means and scales of anchor
348 | scales_anchor = ckbd_anchor(scales_anchor)
349 | means_anchor = ckbd_anchor(means_anchor)
350 | # decompress anchor
351 | slice_anchor = decompress_anchor(self.gaussian_conditional, scales_anchor, means_anchor, decoder, cdf, cdf_lengths, offsets)
352 | # predict residuals caused by round
353 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1))
354 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor)
355 | # Non-anchor
357 | global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor)
358 | # local_ctx: [B,2 * C, H, W]
359 | local_ctx = self.local_context[idx](slice_anchor)
360 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1))
361 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1)
362 | # split means and scales of nonanchor
363 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor)
364 | means_nonanchor = ckbd_nonanchor(means_nonanchor)
365 | # decompress non-anchor
366 | slice_nonanchor = decompress_nonanchor(self.gaussian_conditional, scales_nonanchor, means_nonanchor, decoder, cdf, cdf_lengths, offsets)
367 | # predict residuals caused by round
368 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1))
369 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor)
370 | y_hat_slices.append(slice_nonanchor + slice_anchor)
371 |
372 | y_hat = torch.cat(y_hat_slices, dim=1)
373 |
374 | bs = y_hat.size(0)
375 | betas = torch.full((bs,), beta).to(y_hat.device)
376 | fourier_features_mlp = self.global_cond(betas)
377 | x_hat = self.g_s((y_hat, fourier_features_mlp))
378 | torch.cuda.synchronize()
379 | end_time = time.time()
380 |
381 | cost_time = end_time - start_time
382 |
383 | return {
384 | "x_hat": x_hat,
385 | "cost_time": cost_time
386 | }
387 |
388 | def load_state_dict(self, state_dict):
389 | update_registered_buffers(
390 | self.gaussian_conditional,
391 | "gaussian_conditional",
392 | ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"],
393 | state_dict,
394 | )
395 | super().load_state_dict(state_dict)
396 |
397 | def update(self, scale_table=None, force=False):
398 | if scale_table is None:
399 | scale_table = get_scale_table()
400 | updated = self.gaussian_conditional.update_scale_table(scale_table, force=force)
401 | updated |= super().update(force=force)
402 | return updated
403 |
--------------------------------------------------------------------------------
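Compared with MLICPlusPlus, UG_MLICPlusPlus threads a per-image conditioning value beta through the synthesis transform only: the encoder and entropy model ignore it, so one bitstream can be decoded with different betas. A minimal sketch of that interface (the config values and the beta of 0.5 are illustrative assumptions):

import torch
from types import SimpleNamespace
from models import UG_MLICPlusPlus

config = SimpleNamespace(N=192, M=320, context_window=5, slice_num=10)  # assumed values
model = UG_MLICPlusPlus(config).cuda().eval()
model.update(force=True)  # build the entropy coders' CDF tables before compress()/decompress()

x = torch.rand(1, 3, 256, 256).cuda()
betas = torch.full((x.size(0),), 0.5, device=x.device)  # one conditioning value per image
out = model((x, betas))                                 # forward() expects the (image, betas) pair

enc = model.compress(x)                                           # beta is not needed to encode
dec = model.decompress(enc["strings"], enc["shape"], beta=0.5)    # beta steers the conditional g_s at decode time
x_hat = dec["x_hat"].clamp(0, 1)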