├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── __pycache__ │ ├── clip.cpython-37.pyc │ ├── clip.cpython-38.pyc │ ├── Closs.cpython-37.pyc │ ├── Closs.cpython-38.pyc │ ├── model.cpython-37.pyc │ ├── model.cpython-38.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── simple_tokenizer.cpython-37.pyc │ └── simple_tokenizer.cpython-38.pyc ├── simple_tokenizer.py ├── clip.py ├── Closs.py └── model.py ├── loss ├── __init__.py └── rd_loss.py ├── models ├── __init__.py ├── fourier_cond.py ├── mlicpp.py └── ug_mlicpp.py ├── img ├── image1.png ├── image2.png └── image3.png ├── modules ├── layers │ ├── __init__.py │ ├── conv.py │ ├── attention.py │ └── res_blk.py └── transform │ ├── __init__.py │ ├── quantization.py │ ├── entropy.py │ ├── analysis.py │ ├── synthesis.py │ └── context.py ├── playground ├── warmup.sh ├── test.py ├── test_condi.py ├── train.py ├── warmup.py ├── train_mask.py └── train_condi.py ├── tests ├── test.sh └── test.py ├── config ├── config.py └── args.py ├── utils ├── logger.py ├── metrics.py ├── optimizers.py ├── utils.py ├── func.py ├── newtrain_data.py ├── ckbd.py ├── train_data.py └── testing.py ├── LICENSE ├── classification └── test.py ├── instance segmantation └── test.py └── README.md /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /loss/__init__.py: -------------------------------------------------------------------------------- 1 | from .rd_loss import * 2 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ug_mlicpp import UG_MLICPlusPlus 2 | -------------------------------------------------------------------------------- /img/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/img/image1.png -------------------------------------------------------------------------------- /img/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/img/image2.png -------------------------------------------------------------------------------- /img/image3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/img/image3.png -------------------------------------------------------------------------------- /modules/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import * 2 | from .conv import * 3 | from .res_blk import * 4 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/__pycache__/clip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/clip.cpython-37.pyc -------------------------------------------------------------------------------- 
/clip/__pycache__/clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/clip.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/Closs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/Closs.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/Closs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/Closs.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/simple_tokenizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/simple_tokenizer.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YinKangSheng/UG-ICM/HEAD/clip/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /modules/transform/__init__.py: -------------------------------------------------------------------------------- 1 | from .analysis import * 2 | from .synthesis import * 3 | from .context import * 4 | from .quantization import * 5 | from .entropy import * 6 | -------------------------------------------------------------------------------- /playground/warmup.sh: -------------------------------------------------------------------------------- 1 | work_path=$(dirname $0) 2 | export PYTHONPATH=..:$PYTHONPATH 3 | nohup python warmup.py --metrics mse --exp mlicpp_mse_q1 --gpu_id 0 --lambda 0.0018 -lr 1e-4 --clip_max_norm 1.0 --seed 1984 --batch-size 8 > 0018v2.txt & 4 | -------------------------------------------------------------------------------- /tests/test.sh:
-------------------------------------------------------------------------------- 1 | work_path=$(dirname $0) 2 | export PYTHONPATH=..:$PYTHONPATH 3 | CUDA_VISIBLE_DEVICES='0' python test.py -exp test_human --gpu_id 0 --beta 0 -c /path/to/checkpoint -d /path/to/dataset 4 | CUDA_VISIBLE_DEVICES='0' python test.py -exp test_machine --gpu_id 0 --beta 1 -c /path/to/checkpoint -d /path/to/dataset 5 | 6 | -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from utils.utils import Config 3 | 4 | def model_config(): 5 | config = Config({ 6 | # MLIC and MLIC+ 7 | "N": 192, 8 | "M": 320, 9 | "slice_num": 10, 10 | "context_window": 5, 11 | "act": nn.GELU, 12 | }) 13 | 14 | return config 15 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from datetime import datetime 4 | 5 | 6 | def get_timestamp(): 7 | return datetime.now().strftime('%y%m%d-%H%M%S') 8 | 9 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): 10 | '''set up logger''' 11 | lg = logging.getLogger(logger_name) 12 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 13 | datefmt='%y-%m-%d %H:%M:%S') 14 | lg.setLevel(level) 15 | if tofile: 16 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) 17 | fh = logging.FileHandler(log_file, mode='w') 18 | fh.setFormatter(formatter) 19 | lg.addHandler(fh) 20 | if screen: 21 | sh = logging.StreamHandler() 22 | sh.setFormatter(formatter) 23 | lg.addHandler(sh) 24 | -------------------------------------------------------------------------------- /modules/layers/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from compressai.layers import conv3x3 6 | 7 | 8 | def conv1x1(in_ch: int, out_ch: int, stride: int = 1) -> nn.Module: 9 | """1x1 convolution.""" 10 | return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride) 11 | 12 | 13 | def conv(in_channels, out_channels, kernel_size=5, stride=2): 14 | return nn.Conv2d( 15 | in_channels, 16 | out_channels, 17 | kernel_size=kernel_size, 18 | stride=stride, 19 | padding=kernel_size // 2, 20 | ) 21 | 22 | 23 | def deconv(in_channels, out_channels, kernel_size=5, stride=2): 24 | return nn.ConvTranspose2d( 25 | in_channels, 26 | out_channels, 27 | kernel_size=kernel_size, 28 | stride=stride, 29 | output_padding=stride - 1, 30 | padding=kernel_size // 2, 31 | ) 32 | -------------------------------------------------------------------------------- /modules/transform/quantization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from compressai.layers import subpel_conv3x3 5 | from modules.layers.conv import conv1x1, conv3x3, conv, deconv 6 | from modules.layers.res_blk import * 7 | 8 | 9 | class LatentResidualPrediction(nn.Module): 10 | def __init__(self, in_dim, out_dim, act=nn.GELU): 11 | super().__init__() 12 | diff = abs(out_dim - in_dim) 13 | # such setting leads to much more parameters, you'd better use the setting in Minnen'20 ICIP paper. 
14 | self.lrp_transform = nn.Sequential( 15 | conv3x3(in_dim, in_dim - diff // 4), 16 | act(), 17 | conv3x3(in_dim - diff // 4, in_dim - diff // 2), 18 | act(), 19 | conv3x3(in_dim - diff // 2, in_dim - diff * 3 // 4), 20 | act(), 21 | conv3x3(in_dim - diff * 3 // 4, out_dim), 22 | ) 23 | 24 | def forward(self, x): 25 | x = self.lrp_transform(x) 26 | x = 0.5 * torch.tanh(x) 27 | return x 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 YinKangSheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /loss/rd_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from pytorch_msssim import ms_ssim 5 | 6 | 7 | class RateDistortionLoss(nn.Module): 8 | """Custom rate distortion loss with a Lagrangian parameter.""" 9 | 10 | def __init__(self, lmbda=1e-2, metrics='mse'): 11 | super().__init__() 12 | self.mse = nn.MSELoss() 13 | self.lmbda = lmbda 14 | self.metrics = metrics 15 | 16 | def set_lmbda(self, lmbda): 17 | self.lmbda = lmbda 18 | 19 | def forward(self, output, target): 20 | N, _, H, W = target.size() 21 | out = {} 22 | num_pixels = N * H * W 23 | 24 | out["bpp_loss"] = sum( 25 | (torch.log(likelihoods).sum() / (-math.log(2) * num_pixels)) 26 | for likelihoods in output["likelihoods"].values() 27 | ) 28 | if self.metrics == 'mse': 29 | out["mse_loss"] = self.mse(output["x_hat"], target) 30 | out["ms_ssim_loss"] = None 31 | out["loss"] = self.lmbda * 255 ** 2 * out["mse_loss"] + out["bpp_loss"] 32 | elif self.metrics == 'ms-ssim': 33 | out["mse_loss"] = None 34 | out["ms_ssim_loss"] = 1 - ms_ssim(output["x_hat"], target, data_range=1.0) 35 | out["loss"] = self.lmbda * out["ms_ssim_loss"] + out["bpp_loss"] 36 | 37 | return out 38 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import PIL.Image as Image 4 | from typing import Dict, List, Optional, Tuple, Union 5 | from pytorch_msssim import ms_ssim 6 | import math 7 | from skimage.metrics import structural_similarity as ssim 8 | def PSNR(img1, img2): 9 | mse = np.mean((img1 / 255. 
- img2 / 255.) ** 2) 10 | if mse < 1.0e-10: 11 | return 100 12 | PIXEL_MAX = 1 13 | return 20 * math.log10(PIXEL_MAX / math.sqrt(mse)) 14 | 15 | 16 | def SSIM(img1, img2): 17 | 18 | return ssim(img1,img2,multichannel=True) 19 | 20 | def compute_metrics( 21 | a: Union[np.array, Image.Image], 22 | b: Union[np.array, Image.Image], 23 | max_val: float = 255.0, 24 | ) -> Tuple[float, float]: 25 | """Returns PSNR and MS-SSIM between images `a` and `b`. """ 26 | if isinstance(a, Image.Image): 27 | a = np.asarray(a) 28 | if isinstance(b, Image.Image): 29 | b = np.asarray(b) 30 | 31 | a = torch.from_numpy(a.copy()).float().unsqueeze(0) 32 | if a.size(3) == 3: 33 | a = a.permute(0, 3, 1, 2) 34 | b = torch.from_numpy(b.copy()).float().unsqueeze(0) 35 | if b.size(3) == 3: 36 | b = b.permute(0, 3, 1, 2) 37 | 38 | mse = torch.mean((a - b) ** 2).item() 39 | p = 20 * np.log10(max_val) - 10 * np.log10(mse) 40 | m = ms_ssim(a, b, data_range=max_val).item() 41 | return p, m 42 | 43 | def compute_metrics2( 44 | a: Union[np.array, Image.Image], 45 | b: Union[np.array, Image.Image], 46 | max_val: float = 255.0, 47 | ) -> Tuple[float, float]: 48 | 49 | 50 | p = PSNR(np.array(b), np.array(a)) 51 | m = SSIM(np.array(b), np.array(a)) 52 | return p, m -------------------------------------------------------------------------------- /modules/layers/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.init import trunc_normal_ 4 | from torch.nn import functional as F 5 | from timm.models.layers import to_2tuple 6 | from modules.layers.res_blk import ResidualBottleneck 7 | 8 | class MLP(nn.Module): 9 | 10 | def __init__(self, in_dim, hidden_dim=None, out_dim=None, act_layer=nn.GELU, drop=0.): 11 | super().__init__() 12 | out_dim = out_dim or in_dim 13 | hidden_dim = hidden_dim or in_dim 14 | self.fc1 = nn.Linear(in_dim, hidden_dim) 15 | self.act = act_layer() 16 | self.fc2 = nn.Linear(hidden_dim, out_dim) 17 | self.drop = nn.Dropout(drop) 18 | 19 | def forward(self, x): 20 | x = self.fc1(x) 21 | x = self.act(x) 22 | x = self.drop(x) 23 | x = self.fc2(x) 24 | x = self.drop(x) 25 | return x 26 | 27 | 28 | def build_position_index(window_size): 29 | coords_h = torch.arange(window_size[0]) 30 | coords_w = torch.arange(window_size[1]) 31 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 32 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 33 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 34 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 35 | relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 36 | relative_coords[:, :, 1] += window_size[1] - 1 37 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 38 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 39 | return relative_position_index -------------------------------------------------------------------------------- /config/args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def test_options(): 5 | parser = argparse.ArgumentParser(description="Testing script.") 6 | parser.add_argument( 7 | "-exp", 8 | "--experiment", 9 | default="test", 10 | type=str, 11 | required=False, 12 | help="Experiment name" 13 | ) 14 | parser.add_argument( 15 | "-d", 16 | "--dataset", 17 | default="/home/npr/dataset/", 18 | type=str, 19 | required=False, 20 | help="Training dataset" 21 
| ) 22 | parser.add_argument( 23 | "-n", 24 | "--num-workers", 25 | type=int, 26 | default=1, 27 | help="Dataloaders threads (default: %(default)s)", 28 | ) 29 | parser.add_argument( 30 | "--metrics", 31 | type=str, 32 | default="mse", 33 | help="Optimized for (default: %(default)s)", 34 | ) 35 | parser.add_argument( 36 | "--test-batch-size", 37 | type=int, 38 | default=1, 39 | help="Test batch size (default: %(default)s)", 40 | ) 41 | parser.add_argument( 42 | "--gpu_id", 43 | type=int, 44 | default=3, 45 | help="GPU ID" 46 | ) 47 | parser.add_argument( 48 | "--cuda", 49 | default=True, 50 | help="Use cuda" 51 | ) 52 | parser.add_argument( 53 | "--save", 54 | default=True, 55 | help="Save model to disk" 56 | ) 57 | parser.add_argument( 58 | "-c", 59 | "--checkpoint", 60 | default=None, 61 | type=str, 62 | help="pretrained model path" 63 | ) 64 | args = parser.parse_args() 65 | return args 66 | -------------------------------------------------------------------------------- /classification/test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | import torchvision.datasets as datasets 4 | import torchvision.models as models 5 | from torchvision.models import ResNet101_Weights 6 | from torch.utils.data import DataLoader 7 | import torch.nn.functional as F 8 | import sys 9 | import os 10 | from tqdm import tqdm 11 | # Pretrain ResNet101 12 | net = models.resnet101(weights=None, progress=True) 13 | checkpoint_path = 'path/to/resnet101-cd907fc2.pth' 14 | checkpoint = torch.load(checkpoint_path) 15 | 16 | net.load_state_dict(checkpoint) 17 | 18 | 19 | net.eval() 20 | 21 | # Data 22 | preprocess = transforms.Compose([ 23 | transforms.Resize(256), 24 | transforms.CenterCrop(224), 25 | transforms.ToTensor(), 26 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 27 | ]) 28 | 29 | # load ImageNet val 30 | imagenet_val_dataset = datasets.ImageFolder(root='path/to/val', transform=preprocess) 31 | val_loader = DataLoader(imagenet_val_dataset, batch_size=32, shuffle=False, num_workers=4) 32 | 33 | 34 | def evaluate_model(model, data_loader): 35 | model.eval() 36 | correct = 0 37 | total = 0 38 | with torch.no_grad(): 39 | for images, labels in tqdm(data_loader): 40 | images = images.to('cuda') 41 | labels = labels.to('cuda') 42 | outputs = model(images) 43 | _, predicted = torch.max(outputs.data, 1) 44 | total += labels.size(0) 45 | correct += (predicted == labels).sum().item() 46 | # if predicted!=labels: 47 | # print(total%25) 48 | # print(labels) 49 | accuracy = 100 * correct / total 50 | return accuracy 51 | 52 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 53 | net.to(device) 54 | 55 | accuracy = evaluate_model(net, val_loader) 56 | print(f'Accuracy of the model on the ImageNet validation set: {accuracy:.2f}%') 57 | -------------------------------------------------------------------------------- /utils/optimizers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | 5 | 6 | def configure_optimizers(net, args): 7 | """Separate parameters for the main optimizer and the auxiliary optimizer. 
8 | Return two optimizers""" 9 | 10 | parameters = [ 11 | p for n, p in net.named_parameters() if not n.endswith(".quantiles") 12 | ] 13 | aux_parameters = [ 14 | p for n, p in net.named_parameters() if n.endswith(".quantiles") 15 | ] 16 | # Make sure we don't have an intersection of parameters 17 | params_dict = dict(net.named_parameters()) 18 | inter_params = set(parameters) & set(aux_parameters) 19 | union_params = set(parameters) | set(aux_parameters) 20 | 21 | assert len(inter_params) == 0 22 | assert len(union_params) - len(params_dict.keys()) == 0 23 | 24 | optimizer = optim.Adam( 25 | (p for p in parameters if p.requires_grad), 26 | lr=args.learning_rate, 27 | ) 28 | aux_optimizer = optim.Adam( 29 | (p for p in aux_parameters if p.requires_grad), 30 | lr=args.aux_learning_rate, 31 | ) 32 | return optimizer, aux_optimizer 33 | def configure_optimizer(net, args): 34 | """Separate parameters for the main optimizer and the auxiliary optimizer. 35 | Return two optimizers""" 36 | 37 | parameters = [ 38 | p for n, p in net.named_parameters() if not n.endswith(".quantiles") 39 | ] 40 | aux_parameters = [ 41 | p for n, p in net.named_parameters() if n.endswith(".quantiles") 42 | ] 43 | # Make sure we don't have an intersection of parameters 44 | params_dict = dict(net.named_parameters()) 45 | inter_params = set(parameters) & set(aux_parameters) 46 | union_params = set(parameters) | set(aux_parameters) 47 | 48 | assert len(inter_params) == 0 49 | assert len(union_params) - len(params_dict.keys()) == 0 50 | 51 | optimizer = optim.Adam( 52 | (p for p in parameters if p.requires_grad), 53 | lr=args.learning_rate, 54 | ) 55 | 56 | return optimizer 57 | -------------------------------------------------------------------------------- /instance segmantation/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import random 5 | import detectron2 6 | from detectron2.engine import DefaultPredictor, DefaultTrainer 7 | from detectron2.config import get_cfg 8 | from detectron2.data import DatasetCatalog, MetadataCatalog 9 | from detectron2.data.datasets import register_pascal_voc 10 | from detectron2.utils.visualizer import Visualizer 11 | from detectron2.evaluation import COCOEvaluator, DatasetEvaluators, inference_on_dataset 12 | from detectron2.data import build_detection_test_loader 13 | from detectron2.utils.logger import setup_logger 14 | from detectron2.data.datasets import register_coco_instances 15 | # Initialize the logger 16 | # from detectron2.model_zoo import get_config_file, 17 | from detectron2 import model_zoo 18 | from PIL import ImageFile 19 | import numpy as np 20 | # ImageFile.LOAD_TRUNCATED_IMAGES = True 21 | 22 | setup_logger() 23 | 24 | dataset_name = "coco_8k_hy" 25 | 26 | json_file = "Coco/coco_2017.json" 27 | 28 | image_dir = "path/to/images" 29 | 30 | register_coco_instances(dataset_name, {}, json_file, image_dir) 31 | 32 | # Get dataset metadata and dataset dictionaries 33 | metadata = MetadataCatalog.get(dataset_name) 34 | dataset_dicts = DatasetCatalog.get(dataset_name) 35 | 36 | # Configure and load the pre-trained mask_rcnn model 37 | cfg = get_cfg() 38 | 39 | print("===============load model==============") 40 | cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml")) 41 | cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml") 42 | print("==========complete==================") 43 | 44 | 
cfg.MODEL.DEVICE = "cuda" # 45 | cfg.DATASETS.TEST = (dataset_name, ) 46 | cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(metadata.thing_classes) 47 | 48 | predictor = DefaultPredictor(cfg) 49 | 50 | evaluator = COCOEvaluator(dataset_name, cfg, False, output_dir="./mask/") 51 | val_loader = build_detection_test_loader(cfg, dataset_name) 52 | 53 | output_dir = "./mask/" 54 | os.makedirs(output_dir, exist_ok=True) 55 | 56 | # Perform inference and evaluation 57 | print("Running inference...") 58 | results = inference_on_dataset(predictor.model, val_loader, evaluator) 59 | -------------------------------------------------------------------------------- /modules/transform/entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modules.layers.attention import MLP 5 | 6 | 7 | class EntropyParameters(nn.Module): 8 | def __init__(self, in_dim, out_dim, act=nn.GELU) -> None: 9 | super().__init__() 10 | self.fusion = nn.Sequential( 11 | nn.Conv2d(in_dim, 320, kernel_size=1, stride=1, padding=0), 12 | act(), 13 | nn.Conv2d(320, 256, kernel_size=1, stride=1, padding=0), 14 | act(), 15 | nn.Conv2d(256, 128, kernel_size=1, stride=1, padding=0), 16 | act(), 17 | nn.Conv2d(128, out_dim, kernel_size=1, stride=1, padding=0), 18 | ) 19 | 20 | def forward(self, params): 21 | """ 22 | Args: 23 | params(Tensor): [B, C * K, H, W] 24 | return: 25 | gaussian_params(Tensor): [B, C * 2, H, W] 26 | """ 27 | gaussian_params = self.fusion(params) 28 | 29 | return gaussian_params 30 | 31 | 32 | class EntropyParametersEX(nn.Module): 33 | def __init__(self, in_dim, out_dim, act=nn.GELU) -> None: 34 | super().__init__() 35 | self.fusion = nn.Sequential( 36 | nn.Conv2d(in_dim, out_dim * 5 // 3, 1), 37 | act(), 38 | nn.Conv2d(out_dim * 5 // 3, out_dim * 4 // 3, 1), 39 | act(), 40 | nn.Conv2d(out_dim * 4 // 3, out_dim, 1), 41 | ) 42 | 43 | def forward(self, params): 44 | """ 45 | Args: 46 | params(Tensor): [B, C * K, H, W] 47 | return: 48 | gaussian_params(Tensor): [B, C * 2, H, W] 49 | """ 50 | gaussian_params = self.fusion(params) 51 | 52 | return gaussian_params 53 | 54 | 55 | class ChannelWiseEntropyParameters(nn.Module): 56 | def __init__(self, in_channels=192, out_channels=192): 57 | super().__init__() 58 | diff = (in_channels - out_channels) // 3 59 | self.layers = nn.Sequential( 60 | nn.Conv2d(in_channels, in_channels - diff, 1), 61 | nn.LeakyReLU(inplace=True), 62 | nn.Conv2d(in_channels - diff, in_channels - 2 * diff, 1), 63 | nn.LeakyReLU(inplace=True), 64 | nn.Conv2d(in_channels - 2 * diff, out_channels, 1), 65 | ) 66 | 67 | def forward(self, x): 68 | x = self.layers(x) 69 | return x 70 | -------------------------------------------------------------------------------- /playground/test.py: -------------------------------------------------------------------------------- 1 | from re import T 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import os 6 | import logging 7 | import sys 8 | 9 | 10 | 11 | sys.path.append("..") 12 | from utils.train_data import TestData, testDataset_collate, NewTestData, TestImagenet 13 | from config.args import test_options 14 | from config.config import model_config 15 | from compressai.datasets import ImageFolder 16 | from torchvision import transforms 17 | from torch.utils.data import DataLoader 18 | from PIL import ImageFile, Image 19 | from models import * 20 | from utils.testing import test_model 21 | from utils.logger import 
setup_logger 22 | from utils.newtrain_data import TrainData 23 | 24 | def main(): 25 | ImageFile.LOAD_TRUNCATED_IMAGES = True 26 | Image.MAX_IMAGE_PIXELS = None 27 | 28 | args = test_options() 29 | config = model_config() 30 | 31 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 32 | 33 | torch.backends.cudnn.deterministic = True 34 | 35 | if not os.path.exists(os.path.join('./experiments', args.experiment)): 36 | os.makedirs(os.path.join('./experiments', args.experiment)) 37 | setup_logger('test', os.path.join('./experiments', args.experiment), 'test_' + args.experiment, level=logging.INFO, 38 | screen=True, tofile=True) 39 | logger_test = logging.getLogger('test') 40 | 41 | # 42 | # test_dataset = TestData("../voc_norm_test.txt") 43 | # test_dataset = TestImagenet("../../data/liuquan/Imagenet_val25k/val_set.txt") 44 | # test_dataset = NewTestData("../../data") 45 | test_dataset = NewTestData("../../data/Kodak24") 46 | # test_dataset = NewTestData("../../yolov3/coco/val2017") 47 | test_dataloader = DataLoader( 48 | test_dataset, 49 | batch_size=args.test_batch_size, 50 | num_workers=args.num_workers, 51 | shuffle=False, 52 | pin_memory=True, 53 | collate_fn=testDataset_collate, 54 | ) 55 | 56 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" 57 | 58 | net = MLICPlusPlus(config=config) 59 | net = net.to(device) 60 | checkpoint = torch.load(args.checkpoint) 61 | # new_ckpt = modify_checkpoint(checkpoint['state_dict']) 62 | net.load_state_dict(checkpoint['state_dict']) 63 | epoch = checkpoint["epoch"] 64 | logger_test.info(f"Start testing!" ) 65 | save_dir = os.path.join('./experiments', args.experiment) 66 | if not os.path.exists(save_dir): 67 | os.makedirs(save_dir) 68 | test_model(net=net, test_dataloader=test_dataloader, logger_test=logger_test, save_dir=save_dir, epoch=epoch) 69 | 70 | 71 | if __name__ == '__main__': 72 | main() 73 | 74 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | from re import T 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import os 6 | import logging 7 | import sys 8 | 9 | 10 | 11 | sys.path.append("..") 12 | from utils.utils import save_checkpoint 13 | from utils.train_data import TestData, testDataset_collate 14 | 15 | from config.args import test_options 16 | from config.config import model_config 17 | 18 | from torch.utils.data import DataLoader 19 | from PIL import ImageFile, Image 20 | from models import * 21 | from utils.testing import test_model 22 | from utils.logger import setup_logger 23 | 24 | 25 | def main(): 26 | ImageFile.LOAD_TRUNCATED_IMAGES = True 27 | Image.MAX_IMAGE_PIXELS = None 28 | 29 | args = test_options() 30 | config = model_config() 31 | 32 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 33 | 34 | torch.backends.cudnn.deterministic = True 35 | 36 | if not os.path.exists(os.path.join('./experiments', args.experiment)): 37 | os.makedirs(os.path.join('./experiments', args.experiment)) 38 | setup_logger('test', os.path.join('./experiments', args.experiment), 'test_' + args.experiment, level=logging.INFO, 39 | screen=True, tofile=True) 40 | logger_test = logging.getLogger('test') 41 | 42 | 43 | test_dataset = TestData(args.dataset) 44 | test_dataloader = DataLoader( 45 | test_dataset, 46 | batch_size=args.test_batch_size, 47 | num_workers=args.num_workers, 48 | shuffle=False, 49 | pin_memory=True, 50 | collate_fn=testDataset_collate, 51 
| ) 52 | 53 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" 54 | 55 | net = UG_MLICPlusPlus(config=config) 56 | net = net.to(device) 57 | checkpoint = torch.load(args.checkpoint) 58 | net.load_state_dict(checkpoint['state_dict']) 59 | logger_test.info(f"Start testing!" ) 60 | save_dir = os.path.join('./experiments', args.experiment, "bitstream") 61 | save_dir_human = os.path.join('./experiments', args.experiment, "human") 62 | save_dir_machine = os.path.join('./experiments', args.experiment, "machine") 63 | if not os.path.exists(save_dir): 64 | os.makedirs(save_dir) 65 | if not os.path.exists(save_dir_human): 66 | os.makedirs(save_dir_human) 67 | if not os.path.exists(save_dir_machine): 68 | os.makedirs(save_dir_machine) 69 | test_model(net=net, test_dataloader=test_dataloader, logger_test=logger_test, save_dir=save_dir ,save_dir_human=save_dir_human ,save_dir_machine=save_dir_machine) 70 | 71 | 72 | if __name__ == '__main__': 73 | main() 74 | 75 | -------------------------------------------------------------------------------- /playground/test_condi.py: -------------------------------------------------------------------------------- 1 | from re import T 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import os 6 | import logging 7 | import sys 8 | 9 | 10 | 11 | sys.path.append("..") 12 | from utils.train_data import TestData, testDataset_collate, NewTestData, TestImagenet 13 | from models.mlicpp_con import ConditionalMLICPlusPlus 14 | from config.args import test_options 15 | from config.config import model_config 16 | from compressai.datasets import ImageFolder 17 | from torchvision import transforms 18 | from torch.utils.data import DataLoader 19 | from PIL import ImageFile, Image 20 | from models import * 21 | from utils.testing import test_model 22 | from utils.logger import setup_logger 23 | from utils.newtrain_data import TrainData 24 | 25 | def main(): 26 | ImageFile.LOAD_TRUNCATED_IMAGES = True 27 | Image.MAX_IMAGE_PIXELS = None 28 | 29 | args = test_options() 30 | config = model_config() 31 | 32 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 33 | 34 | torch.backends.cudnn.deterministic = True 35 | 36 | if not os.path.exists(os.path.join('./experiments', args.experiment)): 37 | os.makedirs(os.path.join('./experiments', args.experiment)) 38 | setup_logger('test', os.path.join('./experiments', args.experiment), 'test_' + args.experiment, level=logging.INFO, 39 | screen=True, tofile=True) 40 | logger_test = logging.getLogger('test') 41 | 42 | 43 | # test_dataset = TestData("../voc_norm_test.txt") 44 | test_dataset = NewTestData("../../data/segment") 45 | # test_dataset = TestImagenet("../../data/liuquan/Imagenet_val25k/val_set.txt") 46 | # test_dataset = NewTestData("../../yolov3/coco/val2017") 47 | # test_dataset = NewTestData("../../data/Kodak24") 48 | test_dataloader = DataLoader( 49 | test_dataset, 50 | batch_size=args.test_batch_size, 51 | num_workers=args.num_workers, 52 | shuffle=False, 53 | pin_memory=True, 54 | collate_fn=testDataset_collate, 55 | ) 56 | 57 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" 58 | 59 | net = ConditionalMLICPlusPlus(config=config) 60 | net.beta= args.beta 61 | net = net.to(device) 62 | checkpoint = torch.load(args.checkpoint) 63 | # new_ckpt = modify_checkpoint(checkpoint['state_dict']) 64 | net.load_state_dict(checkpoint['state_dict']) 65 | epoch = checkpoint["epoch"] 66 | logger_test.info(f"Start testing!" 
) 67 | save_dir = os.path.join('./experiments', args.experiment) 68 | if not os.path.exists(save_dir): 69 | os.makedirs(save_dir) 70 | test_model(net=net, test_dataloader=test_dataloader, logger_test=logger_test, save_dir=save_dir, epoch=epoch) 71 | 72 | 73 | if __name__ == '__main__': 74 | main() 75 | 76 | -------------------------------------------------------------------------------- /modules/transform/analysis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from compressai.layers import subpel_conv3x3 5 | from modules.layers.conv import conv1x1, conv3x3, conv, deconv 6 | from modules.layers.res_blk import * 7 | 8 | 9 | class AnalysisTransform(nn.Module): 10 | def __init__(self, N, M): 11 | super().__init__() 12 | self.analysis_transform = nn.Sequential( 13 | ResidualBlockWithStride(3, N, stride=2), 14 | ResidualBlock(N, N), 15 | ResidualBlockWithStride(N, N, stride=2), 16 | ResidualBlock(N, N), 17 | ResidualBlockWithStride(N, N, stride=2), 18 | ResidualBlock(N, N), 19 | conv3x3(N, M, stride=2) 20 | ) 21 | 22 | def forward(self, x): 23 | x = self.analysis_transform(x) 24 | 25 | return x 26 | 27 | 28 | class HyperAnalysis(nn.Module): 29 | """ 30 | Local reference 31 | """ 32 | def __init__(self, M=192, N=192): 33 | super().__init__() 34 | self.M = M 35 | self.N = N 36 | self.reduction = nn.Sequential( 37 | conv3x3(M, N), 38 | nn.GELU(), 39 | conv3x3(N, N), 40 | nn.GELU(), 41 | conv3x3(N, N, stride=2), 42 | nn.GELU(), 43 | conv3x3(N, N), 44 | nn.GELU(), 45 | conv3x3(N, N, stride=2), 46 | ) 47 | 48 | def forward(self, x): 49 | x = self.reduction(x) 50 | 51 | return x 52 | 53 | class AnalysisTransformEX(nn.Module): 54 | def __init__(self, N, M, act=nn.GELU): 55 | super().__init__() 56 | self.analysis_transform = nn.Sequential( 57 | conv(3, N), 58 | ResidualBottleneck(N, act=act, groups=N * 2), 59 | ResidualBottleneck(N, act=act, groups=N * 2), 60 | ResidualBottleneck(N, act=act, groups=N * 2), 61 | conv(N, N), 62 | ResidualBottleneck(N, act=act, groups=N * 2), 63 | ResidualBottleneck(N, act=act, groups=N * 2), 64 | ResidualBottleneck(N, act=act, groups=N * 2), 65 | AttentionBlock(N), 66 | conv(N, N), 67 | ResidualBottleneck(N, act=act, groups=N * 2), 68 | ResidualBottleneck(N, act=act, groups=N * 2), 69 | ResidualBottleneck(N, act=act, groups=N * 2), 70 | conv(N, M), 71 | AttentionBlock(M) 72 | ) 73 | 74 | def forward(self, x): 75 | x = self.analysis_transform(x) 76 | return x 77 | 78 | 79 | class HyperAnalysisEX(nn.Module): 80 | def __init__(self, N, M, act=nn.GELU) -> None: 81 | super().__init__() 82 | self.M = M 83 | self.N = N 84 | self.reduction = nn.Sequential( 85 | conv3x3(M, N), 86 | act(), 87 | conv(N, N), 88 | act(), 89 | conv(N, N) 90 | ) 91 | 92 | def forward(self, x): 93 | x = self.reduction(x) 94 | return x -------------------------------------------------------------------------------- /models/fourier_cond.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BetaMlp(nn.Module): 6 | def __init__(self, channels=512, act_layer=nn.ReLU): 7 | super(BetaMlp, self).__init__() 8 | self.output_features = channels 9 | self.fc1 = nn.Linear(21, self.output_features) 10 | self.act = act_layer() 11 | self.fc2 = nn.Linear(self.output_features, self.output_features) 12 | 13 | def forward(self, x): 14 | 15 | x = self.fc1(x) 16 | x = self.act(x) 17 | x = self.fc2(x) 18 | x = 
self.act(x) 19 | return x 20 | 21 | class Embedder: 22 | def __init__(self, **kwargs): 23 | self.kwargs = kwargs 24 | self.create_embedding_fn() 25 | 26 | def create_embedding_fn(self): 27 | embed_fns = [] 28 | d = self.kwargs['input_dims'] 29 | out_dim = 0 30 | if self.kwargs['include_input']: 31 | embed_fns.append(lambda x: x) 32 | out_dim += d 33 | 34 | max_freq = self.kwargs['max_freq_log2'] 35 | N_freqs = self.kwargs['num_freqs'] 36 | 37 | if self.kwargs['log_sampling']: 38 | freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs) 39 | else: 40 | freq_bands = torch.linspace(2. ** 0., 2. ** max_freq, N_freqs) 41 | 42 | for freq in freq_bands: 43 | for p_fn in self.kwargs['periodic_fns']: 44 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 45 | out_dim += d 46 | 47 | self.embed_fns = embed_fns 48 | self.out_dim = out_dim 49 | 50 | def embed(self, inputs): 51 | # Apply each function in self.embed_fns to inputs and stack the results 52 | stacked_outputs = torch.stack([fn(inputs) for fn in self.embed_fns]) 53 | 54 | # Transpose the stacked tensor 55 | transposed_outputs = stacked_outputs.permute(1, 0) 56 | return transposed_outputs 57 | 58 | def get_embedder(multires, i=0): 59 | if i == -1: 60 | return lambda x: x, 3 61 | 62 | embed_kwargs = { 63 | 'include_input': True, 64 | 'input_dims': 1, 65 | 'max_freq_log2': multires - 1, 66 | 'num_freqs': multires, 67 | 'log_sampling': True, 68 | 'periodic_fns': [torch.sin, torch.cos], 69 | } 70 | 71 | embedder_obj = Embedder(**embed_kwargs) 72 | 73 | def embed(x, eo=embedder_obj): return eo.embed(x) 74 | 75 | return embed, embedder_obj.out_dim 76 | 77 | # def build_global_conditioning(): 78 | # beta = torch.randn(1) # Dummy input, replace with actual input tensor 79 | # 80 | # embed_fn, input_ch = get_embedder(multires=10, i=0) 81 | # beta_mlp = BetaMlp() 82 | # fourier_features = embed_fn(beta) 83 | # fourier_features_mlp = beta_mlp(fourier_features) 84 | # 85 | # print(fourier_features_mlp) 86 | # 87 | # return beta_mlp # Return the model instance 88 | 89 | class GlobalConditioning(nn.Module): 90 | def __init__(self): 91 | super(GlobalConditioning, self).__init__() 92 | self.beta_mlp = BetaMlp() 93 | self.embed_fn, self.input_ch = get_embedder(multires=10, i=0) # Assuming get_embedder is defined elsewhere 94 | 95 | def forward(self, beta): 96 | fourier_features = self.embed_fn(beta) 97 | fourier_features_mlp = self.beta_mlp(fourier_features) 98 | return fourier_features_mlp -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | import PIL.Image as Image 2 | import shutil 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import struct 7 | from pathlib import Path 8 | from torchvision.transforms import ToPILImage 9 | import json 10 | 11 | 12 | """ configuration json """ 13 | class Config(dict): 14 | __getattr__ = dict.__getitem__ 15 | __setattr__ = dict.__setitem__ 16 | 17 | @classmethod 18 | def load(cls, file): 19 | with open(file, 'r') as f: 20 | config = json.loads(f.read()) 21 | return Config(config) 22 | 23 | 24 | def write_uchars(fd, values, fmt=">{:d}B"): 25 | fd.write(struct.pack(fmt.format(len(values)), *values)) 26 | return len(values) * 1 27 | 28 | 29 | def write_uints(fd, values, fmt=">{:d}I"): 30 | fd.write(struct.pack(fmt.format(len(values)), *values)) 31 | return len(values) * 4 32 | 33 | 34 | def read_uints(fd, n, fmt=">{:d}I"): 35 | sz = struct.calcsize("I") 36 | return
struct.unpack(fmt.format(n), fd.read(n * sz)) 37 | 38 | 39 | def read_uchars(fd, n, fmt=">{:d}B"): 40 | sz = struct.calcsize("B") 41 | return struct.unpack(fmt.format(n), fd.read(n * sz)) 42 | 43 | 44 | def write_bytes(fd, values, fmt=">{:d}s"): 45 | if len(values) == 0: 46 | return 47 | fd.write(struct.pack(fmt.format(len(values)), values)) 48 | return len(values) * 1 49 | 50 | 51 | def read_bytes(fd, n, fmt=">{:d}s"): 52 | sz = struct.calcsize("s") 53 | return struct.unpack(fmt.format(n), fd.read(n * sz))[0] 54 | 55 | 56 | def read_body(fd): 57 | lstrings = [] 58 | shape = read_uints(fd, 2) 59 | n_strings = read_uints(fd, 1)[0] 60 | for _ in range(n_strings): 61 | s = read_bytes(fd, read_uints(fd, 1)[0]) 62 | lstrings.append([s]) 63 | 64 | return lstrings, shape 65 | 66 | 67 | def write_body(fd, shape, out_strings): 68 | bytes_cnt = 0 69 | bytes_cnt = write_uints(fd, (shape[0], shape[1], len(out_strings))) 70 | for s in out_strings: 71 | bytes_cnt += write_uints(fd, (len(s[0]),)) 72 | bytes_cnt += write_bytes(fd, s[0]) 73 | return bytes_cnt 74 | 75 | 76 | def filesize(filepath: str) -> int: 77 | if not Path(filepath).is_file(): 78 | raise ValueError(f'Invalid file "{filepath}".') 79 | return Path(filepath).stat().st_size 80 | 81 | 82 | def torch2img(x: torch.Tensor) -> Image.Image: 83 | return ToPILImage()(x.clamp_(0, 1).squeeze()) 84 | 85 | 86 | class AverageMeter: 87 | """Compute running average.""" 88 | 89 | def __init__(self): 90 | self.val = 0 91 | self.avg = 0 92 | self.sum = 0 93 | self.count = 0 94 | 95 | def update(self, val, n=1): 96 | self.val = val 97 | self.sum += val * n 98 | self.count += n 99 | self.avg = self.sum / self.count 100 | 101 | 102 | class CustomDataParallel(nn.DataParallel): 103 | """Custom DataParallel to access the module methods.""" 104 | 105 | def __getattr__(self, key): 106 | try: 107 | return super().__getattr__(key) 108 | except AttributeError: 109 | return getattr(self.module, key) 110 | 111 | 112 | def save_checkpoint(state, is_best, filename="checkpoint.pth.tar"): 113 | torch.save(state, filename) 114 | if is_best: 115 | best_filename = filename.replace(filename.split('/')[-1], "checkpoint_best_loss.pth.tar") 116 | shutil.copyfile(filename, best_filename) 117 | 118 | 119 | def split_data(source_path, destination_path, train_file): 120 | f = open(train_file, encoding='utf-8') 121 | while True: 122 | line = f.readline() 123 | if line: 124 | line = line.strip('\n') 125 | print(line) 126 | img_path = source_path + line 127 | shutil.move(img_path, destination_path + line) 128 | -------------------------------------------------------------------------------- /modules/transform/synthesis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from compressai.layers import subpel_conv3x3, AttentionBlock 5 | from modules.layers.conv import conv1x1, conv3x3, conv, deconv 6 | from modules.layers.res_blk import * 7 | 8 | 9 | class HyperSynthesis(nn.Module): 10 | """ 11 | Local Reference 12 | """ 13 | def __init__(self, M=192, N=192) -> None: 14 | super().__init__() 15 | self.M = M 16 | self.N = N 17 | 18 | self.increase = nn.Sequential( 19 | conv3x3(N, M), 20 | nn.GELU(), 21 | subpel_conv3x3(M, M, 2), 22 | nn.GELU(), 23 | conv3x3(M, M * 3 // 2), 24 | nn.GELU(), 25 | subpel_conv3x3(M * 3 // 2, M * 3 // 2, 2), 26 | nn.GELU(), 27 | conv3x3(M * 3 // 2, M * 2), 28 | ) 29 | 30 | def forward(self, x): 31 | x = self.increase(x) 32 | 33 | return x 34 | 35 | 36 
| class SynthesisTransform(nn.Module): 37 | def __init__(self, N, M): 38 | super().__init__() 39 | self.synthesis_transform = nn.Sequential( 40 | ResidualBlock(M, N), 41 | ResidualBlockUpsample(N, N, 2), 42 | ResidualBlock(N, N), 43 | ResidualBlockUpsample(N, N, 2), 44 | ResidualBlock(N, N), 45 | ResidualBlockUpsample(N, N, 2), 46 | ResidualBlock(N, N), 47 | subpel_conv3x3(N, 3, 2), 48 | ) 49 | 50 | def forward(self, x): 51 | x = self.synthesis_transform(x) 52 | 53 | return x 54 | 55 | class ConditionalSynthesisTransform(nn.Module): 56 | def __init__(self, N, M): 57 | super().__init__() 58 | self.synthesis_transform = nn.Sequential( 59 | PCDM(M, N), 60 | ResidualBlockUpsample(N, N, 2), 61 | PCDM(N, N), 62 | ResidualBlockUpsample(N, N, 2), 63 | PCDM(N, N), 64 | ResidualBlockUpsample(N, N, 2), 65 | PCDM(N, N), 66 | subpel_conv3x3(N, 3, 2), 67 | ) 68 | 69 | def forward(self, x): 70 | x, fourier_features_mlp = x 71 | for block in self.synthesis_transform: 72 | if isinstance(block, PCDM): 73 | x = block((x, fourier_features_mlp)) 74 | else: 75 | x = block(x) 76 | 77 | return x 78 | 79 | 80 | class SynthesisTransformEX(nn.Module): 81 | def __init__(self, N, M, act=nn.GELU) -> None: 82 | super().__init__() 83 | self.synthesis_transform = nn.Sequential( 84 | AttentionBlock(M), 85 | deconv(M, N), 86 | ResidualBottleneck(N, act=act, groups=N * 2), 87 | ResidualBottleneck(N, act=act, groups=N * 2), 88 | ResidualBottleneck(N, act=act, groups=N * 2), 89 | deconv(N, N), 90 | AttentionBlock(N), 91 | ResidualBottleneck(N, act=act, groups=N * 2), 92 | ResidualBottleneck(N, act=act, groups=N * 2), 93 | ResidualBottleneck(N, act=act), 94 | deconv(N, N), 95 | ResidualBottleneck(N, act=act, groups=N * 2), 96 | ResidualBottleneck(N, act=act, groups=N * 2), 97 | ResidualBottleneck(N, act=act, groups=N * 2), 98 | deconv(N, 3) 99 | ) 100 | 101 | def forward(self, x): 102 | x = self.synthesis_transform(x) 103 | return x 104 | 105 | 106 | class HyperSynthesisEX(nn.Module): 107 | def __init__(self, N, M, act=nn.GELU) -> None: 108 | super().__init__() 109 | self.increase = nn.Sequential( 110 | deconv(N, M), 111 | act(), 112 | deconv(M, M * 3 // 2), 113 | act(), 114 | deconv(M * 3 // 2, M * 2, kernel_size=3, stride=1), 115 | ) 116 | 117 | def forward(self, x): 118 | x = self.increase(x) 119 | return x -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Unified Coding for Both Human Perception and Generalized Machine Analytics with CLIP Supervision(UG-ICM) 2 | This repo contains the official PyTorch implementation for the paper “Unified Coding for Both Human Perception and Generalized Machine Analytics with CLIP Supervision”. 3 | 4 | ## Updates 5 | 6 | #### 2025/11/17 7 | The training script has been open-sourced. 8 | 9 | #### 2024/12/13 10 | The training script is about to be open-sourced. 11 | 12 | #### 2024/12/10 13 | *Unified Coding for Both Human Perception and Generalized Machine Analytics with CLIP Supervision(UG-ICM)* is accepted at **AAAI 2025**! 14 | 15 | 16 | ## Abstract 17 | The image compression model has long struggled with adaptability and generalization, as the decoded bitstream typically serves only human or machine needs and fails to preserve information for unseen visual tasks. 
Therefore, this paper innovatively introduces supervision obtained from multimodal pre-training models and incorporates adaptive multi-objective optimization tailored to support both human visual perception and machine vision simultaneously with a single bitstream, denoted as Unified and Generalized Image Coding for Machine (UG-ICM). Specifically, to remove the reliance of compression models on downstream task supervision, we introduce Contrastive Language-Image Pre-training (CLIP) models into the training constraint for improved generalization. Global-to-instance-wise CLIP supervision is applied to help obtain hierarchical semantics that make the models more generalizable to tasks that rely on information of different granularity. Furthermore, to support both human and machine vision with a single unified bitstream, we incorporate a conditional decoding strategy that takes human or machine preferences as conditions, enabling the bitstream to be decoded into different versions for the corresponding preferences. As such, our proposed UG-ICM is fully trained in a self-supervised manner, i.e., without awareness of any specific downstream models and tasks. Extensive experiments show that the proposed UG-ICM achieves remarkable improvements in various unseen machine analytics tasks, while simultaneously providing perceptually satisfying images. 18 | 19 | ![image-20240309205241968](./img/image2.png) 20 | 21 | ![image-20240309205241969](./img/image3.png) 22 | 23 | ## Environment 24 | 25 | Python 3.7.16 26 | 27 | CompressAI 1.2.0b3 28 | 29 | PyTorch 1.13.0 30 | 31 | ## Weights 32 | 33 | 34 |
35 | 36 | | Model | Link | 37 |:--------:|:--------:| 38 | | UG-ICM | [Baidu Drive](https://pan.baidu.com/s/1bEcDDbiSIPj67yVrUt6tGQ?pwd=5doz) | 39 | 40 |
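The released checkpoint can be loaded in the same way `tests/test.py` does. A minimal sketch (the checkpoint path is a placeholder, and the weights are assumed to be stored under the `state_dict` key, as in the test script):

```python
import torch

from config.config import model_config
from models import UG_MLICPlusPlus

# Build the model with the default configuration and move it to the GPU.
config = model_config()
net = UG_MLICPlusPlus(config=config).to("cuda")

# Load the released weights (path is a placeholder).
checkpoint = torch.load("/path/to/checkpoint_best_loss.pth.tar", map_location="cuda")
net.load_state_dict(checkpoint["state_dict"])
net.eval()
```
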
41 | 42 | ## Training: 43 | Stage One 44 | ```bash 45 | cd ./playground && python train_condi.py --metrics mse --exp mlicpp_condi_q1 --gpu_id 0 --lambda 0.0022 --lambda_beta1 0 --lambda_clip 22 -lr 1e-5 --seed 2000 --batch-size 128 --test-batch-size 128 --tune -e 20 46 | ``` 47 | Stage Two 48 | ```bash 49 | cd ./playground && python train_mask.py --metrics mse --exp mlicpp_condi_q4_tune --gpu_id 0 --lambda 0.001 --lambda_beta1 1 --lambda_clip 1 -lr 1e-5 --seed 2000 --batch-size 128 --test-batch-size 128 --tune -e 20 -c experiments/mlicpp_condi_q4_2/checkpoints/checkpoint_best_loss.pth.tar 50 | ``` 51 | ## Testing: 52 | 53 | ### Compress: 54 | Encode and compress the test images, then obtain the human-oriented and the machine-oriented decoded images from the unified bitstreams. 55 | ```bash 56 | cd tests 57 | 58 | python test.py -exp test --gpu_id 0 -c /path/to/checkpoint -d /path/to/dataset 59 | ``` 60 | 61 | ### Classification: 62 | #### Dataset: ImageNet-1k 63 | 64 | Evaluate the performance of the decoded images on the classification task (modify the decoded image path in `test.py`). 65 | ```bash 66 | cd classification 67 | 68 | python test.py 69 | ``` 70 | 71 | ### Instance segmentation: 72 | #### Dataset: COCO2017_val 73 | 74 | Evaluate the performance of the decoded images on the instance segmentation task (modify the decoded image path in `test.py`). 75 | ```bash 76 | cd "instance segmantation" 77 | 78 | python test.py 79 | ``` 80 | 81 | ## License 82 | 83 | [MIT License](https://opensource.org/licenses/MIT) 84 | 85 | ## Acknowledgments 86 | Thanks to [Compressai](https://github.com/InterDigitalInc/CompressAI), [MLIC](https://github.com/JiangWeibeta/MLIC), [CLIP](https://github.com/openai/CLIP), [TransTIC](https://github.com/NYCU-MAPL/TransTIC) for their public code and released models. 87 | 88 | 89 | -------------------------------------------------------------------------------- /utils/func.py: -------------------------------------------------------------------------------- 1 | # From CompressAI 2 | # Modified by Wei Jiang 3 | 4 | import torch 5 | import torch.nn as nn 6 | import math 7 | import einops 8 | 9 | def calc_params(model): 10 | """ 11 | Calculate the number of parameters in a model 12 | """ 13 | return sum(p.numel() for p in model.parameters()) 14 | 15 | 16 | def get_scale_table( 17 | min=0.11, max=256, levels=64 18 | ): # pylint: disable=W0622 19 | return torch.exp(torch.linspace(math.log(min), math.log(max), levels)) # Why take the log first and then exponentiate? Is it for higher numerical precision? 20 | 21 | 22 | def find_named_module(module, query): 23 | """Helper function to find a named module. Returns a `nn.Module` or `None` 24 | 25 | Args: 26 | module (nn.Module): the root module 27 | query (str): the module name to find 28 | 29 | Returns: 30 | nn.Module or None 31 | """ 32 | 33 | return next((m for n, m in module.named_modules() if n == query), None) 34 | 35 | 36 | def find_named_buffer(module, query): 37 | """Helper function to find a named buffer. 
Returns a `torch.Tensor` or `None` 38 | 39 | Args: 40 | module (nn.Module): the root module 41 | query (str): the buffer name to find 42 | 43 | Returns: 44 | torch.Tensor or None 45 | """ 46 | return next((b for n, b in module.named_buffers() if n == query), None) 47 | 48 | 49 | def _update_registered_buffer( 50 | module, 51 | buffer_name, 52 | state_dict_key, 53 | state_dict, 54 | policy="resize_if_empty", 55 | dtype=torch.int, 56 | ): 57 | new_size = state_dict[state_dict_key].size() 58 | registered_buf = find_named_buffer(module, buffer_name) 59 | 60 | if policy in ("resize_if_empty", "resize"): 61 | if registered_buf is None: 62 | raise RuntimeError(f'buffer "{buffer_name}" was not registered') 63 | 64 | if policy == "resize" or registered_buf.numel() == 0: 65 | registered_buf.resize_(new_size) 66 | 67 | elif policy == "register": 68 | if registered_buf is not None: 69 | raise RuntimeError(f'buffer "{buffer_name}" was already registered') 70 | 71 | module.register_buffer(buffer_name, torch.empty(new_size, dtype=dtype).fill_(0)) 72 | 73 | else: 74 | raise ValueError(f'Invalid policy "{policy}"') 75 | 76 | 77 | def update_registered_buffers( 78 | module, 79 | module_name, 80 | buffer_names, 81 | state_dict, 82 | policy="resize_if_empty", 83 | dtype=torch.int, 84 | ): 85 | """Update the registered buffers in a module according to the tensors sized 86 | in a state_dict. 87 | 88 | (There's no way in torch to directly load a buffer with a dynamic size) 89 | 90 | Args: 91 | module (nn.Module): the module 92 | module_name (str): module name in the state dict 93 | buffer_names (list(str)): list of the buffer names to resize in the module 94 | state_dict (dict): the state dict 95 | policy (str): Update policy, choose from 96 | ('resize_if_empty', 'resize', 'register') 97 | dtype (dtype): Type of buffer to be registered (when policy is 'register') 98 | """ 99 | valid_buffer_names = [n for n, _ in module.named_buffers()] 100 | for buffer_name in buffer_names: 101 | if buffer_name not in valid_buffer_names: 102 | raise ValueError(f'Invalid buffer name "{buffer_name}"') 103 | 104 | for buffer_name in buffer_names: 105 | _update_registered_buffer( 106 | module, 107 | buffer_name, 108 | f"{module_name}.{buffer_name}", 109 | state_dict, 110 | policy, 111 | dtype, 112 | ) 113 | 114 | def cal_params(model): 115 | """ 116 | Calculate the number of parameters in a model 117 | """ 118 | return sum(p.numel() for p in model.parameters()) 119 | 120 | 121 | 122 | def image2patch(x, patch_size): 123 | """Image to patches.""" 124 | batch, channels, height, width = x.shape 125 | grid_height = height // patch_size[0] 126 | grid_width = width // patch_size[1] 127 | x = einops.rearrange( 128 | x, "n c (gh fh) (gw fw) -> n c (gh gw) (fh fw)", 129 | gh=grid_height, gw=grid_width, fh=patch_size[0], fw=patch_size[1]) 130 | return x 131 | 132 | 133 | def patch2image(x, grid_size, patch_size): 134 | """patches to images.""" 135 | x = einops.rearrange( 136 | x, "n c (gh gw) (fh fw) -> n c (gh fh) (gw fw)", 137 | gh=grid_size[0], gw=grid_size[1], fh=patch_size[0], fw=patch_size[1]) 138 | return x 139 | -------------------------------------------------------------------------------- /utils/newtrain_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | paper: GridDehazeNet: Attention-Based Multi-Scale Network for Image Dehazing 3 | file: train_data.py 4 | about: build the training dataset 5 | author: Xiaohong Liu 6 | date: 01/08/19 7 | """ 8 | 9 | # --- Imports --- # 10 | import 
gzip 11 | import random 12 | 13 | import numpy as np 14 | import torch 15 | import torch.utils.data as data 16 | from PIL import Image 17 | import torch.nn.functional as F 18 | 19 | from torchvision.transforms import Compose, ToTensor, Resize, RandomCrop 20 | 21 | 22 | def compute_padding(in_h: int, in_w: int, *, out_h=None, out_w=None, min_div=1): 23 | """Returns tuples for padding and unpadding. 24 | 25 | Args: 26 | in_h: Input height. 27 | in_w: Input width. 28 | out_h: Output height. 29 | out_w: Output width. 30 | min_div: Length that output dimensions should be divisible by. 31 | """ 32 | if out_h is None: 33 | out_h = (in_h + min_div - 1) // min_div * min_div 34 | if out_w is None: 35 | out_w = (in_w + min_div - 1) // min_div * min_div 36 | 37 | if out_h % min_div != 0 or out_w % min_div != 0: 38 | raise ValueError( 39 | f"Padded output height and width are not divisible by min_div={min_div}." 40 | ) 41 | 42 | left = (out_w - in_w) // 2 43 | right = out_w - in_w - left 44 | top = (out_h - in_h) // 2 45 | bottom = out_h - in_h - top 46 | 47 | pad = (left, right, top, bottom) 48 | unpad = (-left, -right, -top, -bottom) 49 | 50 | return pad, unpad 51 | 52 | def pad(x, p=2**6): 53 | h, w = x.size(1), x.size(2) 54 | pad, _ = compute_padding(h, w, min_div=p) 55 | return F.pad(x, pad, mode="constant", value=0) 56 | 57 | 58 | # --- Training dataset --- # 59 | class TrainData(data.Dataset): 60 | def __init__(self, train_data_dir): 61 | super().__init__() 62 | train_list = train_data_dir 63 | with open(train_list) as f: 64 | txt = f.readlines() 65 | annotations = [line.strip() for line in txt] 66 | 67 | self.annotations = annotations 68 | 69 | def get_images(self, index): 70 | seed = torch.random.seed() 71 | 72 | 73 | img_name = self.annotations[index] 74 | 75 | 76 | image = Image.open("../../../coco2voc/coco2voc/" + 'JPEGImages/' + str(img_name) + '.jpg') 77 | # --- Transform to tensor --- # 78 | transform1 = Compose([RandomCrop((256, 256)), ToTensor()]) 79 | transform2 = Compose([Resize((256, 256)), ToTensor()]) 80 | image_width, image_height = image.size 81 | if image_width >= 256 and image_height >= 256: 82 | torch.random.manual_seed(seed) 83 | image = transform1(image) 84 | else: 85 | image = transform2(image) 86 | 87 | return image 88 | 89 | def __getitem__(self, index): 90 | res = self.get_images(index) 91 | return res 92 | 93 | def __len__(self): 94 | return len(self.annotations) 95 | 96 | class TrainDataWithMask(data.Dataset): 97 | def __init__(self, train_data_dir): 98 | super().__init__() 99 | train_list = train_data_dir 100 | with open(train_list) as f: 101 | txt = f.readlines() 102 | annotations = [line.strip() for line in txt] 103 | 104 | self.annotations = annotations 105 | 106 | def get_images(self, index): 107 | seed = torch.random.seed() 108 | 109 | img_name = self.annotations[index] 110 | 111 | image1 = Image.open("../../../coco2voc/coco2voc/" + 'JPEGImages/' + str(img_name) + '.jpg') 112 | image2 = image1.copy() 113 | # --- Transform to tensor --- # 114 | transform1 = Compose([RandomCrop((256, 256)), ToTensor()]) 115 | transform2 = Compose([Resize((256, 256)), ToTensor()]) 116 | image_width, image_height = image1.size 117 | if image_width >= 256 and image_height >= 256: 118 | torch.random.manual_seed(seed) 119 | image1 = transform1(image1) 120 | else: 121 | image1 = transform2(image1) 122 | image2 = transform2(image2) 123 | with gzip.open("../../cocomask/cocomask/" + str(img_name) + '.pth.gz', 'rb') as f: 124 | MaskList = torch.load(f) 125 | 126 | return image1, image2, 
MaskList 127 | 128 | 129 | 130 | def __getitem__(self, index): 131 | res = self.get_images(index) 132 | return res 133 | 134 | def __len__(self): 135 | return len(self.annotations) 136 | 137 | 138 | # DataLoader中collate_fn使用 139 | def dataset_collate(batch): 140 | 141 | imagesets = [] 142 | for imageset in batch: 143 | imagesets.append(imageset[0]) 144 | imagesets = torch.from_numpy(np.array([item.numpy() for item in imageset])).type(torch.FloatTensor) 145 | print(imagesets.shape) 146 | return imagesets 147 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
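    Example (illustrative): get_pairs(('h', 'e', 'll', 'o')) returns
    {('h', 'e'), ('e', 'll'), ('ll', 'o')}; bpe() picks the lowest-ranked of these pairs as the next merge.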
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /utils/ckbd.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from compressai.entropy_models import GaussianConditional, EntropyModel 4 | 5 | 6 | def ckbd_split(y): 7 | """ 8 | Split y to anchor and non-anchor 9 | anchor : 10 | 0 1 0 1 0 11 | 1 0 1 0 1 12 | 0 1 0 1 0 13 | 1 0 1 0 1 14 | 0 1 0 1 0 15 | non-anchor: 16 | 1 0 1 0 1 17 | 0 1 0 1 0 18 | 1 0 1 0 1 19 | 0 1 0 1 0 20 | 1 0 1 0 1 21 | 
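    The two phases occupy complementary spatial positions, so ckbd_merge(anchor, nonanchor) simply adds
    them and reconstructs y exactly. Illustrative usage:
        anchor, nonanchor = ckbd_split(y)
        assert torch.equal(ckbd_merge(anchor, nonanchor), y)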
""" 22 | anchor = ckbd_anchor(y) 23 | nonanchor = ckbd_nonanchor(y) 24 | return anchor, nonanchor 25 | 26 | def ckbd_merge(anchor, nonanchor): 27 | # out = torch.zeros_like(anchor).to(anchor.device) 28 | # out[:, :, 0::2, 0::2] = non_anchor[:, :, 0::2, 0::2] 29 | # out[:, :, 1::2, 1::2] = non_anchor[:, :, 1::2, 1::2] 30 | # out[:, :, 0::2, 1::2] = anchor[:, :, 0::2, 1::2] 31 | # out[:, :, 1::2, 0::2] = anchor[:, :, 1::2, 0::2] 32 | 33 | return anchor + nonanchor 34 | 35 | def ckbd_anchor(y): 36 | anchor = torch.zeros_like(y).to(y.device) 37 | anchor[:, :, 0::2, 1::2] = y[:, :, 0::2, 1::2] 38 | anchor[:, :, 1::2, 0::2] = y[:, :, 1::2, 0::2] 39 | return anchor 40 | 41 | def ckbd_nonanchor(y): 42 | nonanchor = torch.zeros_like(y).to(y.device) 43 | nonanchor[:, :, 0::2, 0::2] = y[:, :, 0::2, 0::2] 44 | nonanchor[:, :, 1::2, 1::2] = y[:, :, 1::2, 1::2] 45 | return nonanchor 46 | 47 | def ckbd_anchor_sequeeze(y): 48 | B, C, H, W = y.shape 49 | anchor = torch.zeros([B, C, H, W // 2]).to(y.device) 50 | anchor[:, :, 0::2, :] = y[:, :, 0::2, 1::2] 51 | anchor[:, :, 1::2, :] = y[:, :, 1::2, 0::2] 52 | return anchor 53 | 54 | def ckbd_nonanchor_sequeeze(y): 55 | B, C, H, W = y.shape 56 | nonanchor = torch.zeros([B, C, H, W // 2]).to(y.device) 57 | nonanchor[:, :, 0::2, :] = y[:, :, 0::2, 0::2] 58 | nonanchor[:, :, 1::2, :] = y[:, :, 1::2, 1::2] 59 | return nonanchor 60 | 61 | def ckbd_anchor_unsequeeze(anchor): 62 | B, C, H, W = anchor.shape 63 | y_anchor = torch.zeros([B, C, H, W * 2]).to(anchor.device) 64 | y_anchor[:, :, 0::2, 1::2] = anchor[:, :, 0::2, :] 65 | y_anchor[:, :, 1::2, 0::2] = anchor[:, :, 1::2, :] 66 | return y_anchor 67 | 68 | def ckbd_nonanchor_unsequeeze(nonanchor): 69 | B, C, H, W = nonanchor.shape 70 | y_nonanchor = torch.zeros([B, C, H, W * 2]).to(nonanchor.device) 71 | y_nonanchor[:, :, 0::2, 0::2] = nonanchor[:, :, 0::2, :] 72 | y_nonanchor[:, :, 1::2, 1::2] = nonanchor[:, :, 1::2, :] 73 | return y_nonanchor 74 | 75 | 76 | def compress_anchor(gaussian_conditional:EntropyModel, anchor, scales_anchor, means_anchor, symbols_list, indexes_list): 77 | # squeeze anchor to avoid non-anchor symbols 78 | anchor_squeeze = ckbd_anchor_sequeeze(anchor) 79 | scales_anchor_squeeze = ckbd_anchor_sequeeze(scales_anchor) 80 | means_anchor_squeeze = ckbd_anchor_sequeeze(means_anchor) 81 | indexes = gaussian_conditional.build_indexes(scales_anchor_squeeze) 82 | anchor_hat = gaussian_conditional.quantize(anchor_squeeze, "symbols", means_anchor_squeeze) 83 | symbols_list.extend(anchor_hat.reshape(-1).tolist()) 84 | indexes_list.extend(indexes.reshape(-1).tolist()) 85 | anchor_hat = ckbd_anchor_unsequeeze(anchor_hat + means_anchor_squeeze) 86 | return anchor_hat 87 | 88 | def compress_nonanchor(gaussian_conditional:EntropyModel, nonanchor, scales_nonanchor, means_nonanchor, symbols_list, indexes_list): 89 | nonanchor_squeeze = ckbd_nonanchor_sequeeze(nonanchor) 90 | scales_nonanchor_squeeze = ckbd_nonanchor_sequeeze(scales_nonanchor) 91 | means_nonanchor_squeeze = ckbd_nonanchor_sequeeze(means_nonanchor) 92 | indexes = gaussian_conditional.build_indexes(scales_nonanchor_squeeze) 93 | nonanchor_hat = gaussian_conditional.quantize(nonanchor_squeeze, "symbols", means_nonanchor_squeeze) 94 | symbols_list.extend(nonanchor_hat.reshape(-1).tolist()) 95 | indexes_list.extend(indexes.reshape(-1).tolist()) 96 | nonanchor_hat = ckbd_nonanchor_unsequeeze(nonanchor_hat + means_nonanchor_squeeze) 97 | return nonanchor_hat 98 | 99 | def decompress_anchor(gaussian_conditional:EntropyModel, scales_anchor, 
means_anchor, decoder, cdf, cdf_lengths, offsets): 100 | scales_anchor_squeeze = ckbd_anchor_sequeeze(scales_anchor) 101 | means_anchor_squeeze = ckbd_anchor_sequeeze(means_anchor) 102 | indexes = gaussian_conditional.build_indexes(scales_anchor_squeeze) 103 | anchor_hat = decoder.decode_stream(indexes.reshape(-1).tolist(), cdf, cdf_lengths, offsets) 104 | anchor_hat = torch.Tensor(anchor_hat).reshape(scales_anchor_squeeze.shape).to(scales_anchor.device) + means_anchor_squeeze 105 | anchor_hat = ckbd_anchor_unsequeeze(anchor_hat) 106 | return anchor_hat 107 | 108 | def decompress_nonanchor(gaussian_conditional:EntropyModel, scales_nonanchor, means_nonanchor, decoder, cdf, cdf_lengths, offsets): 109 | scales_nonanchor_squeeze = ckbd_nonanchor_sequeeze(scales_nonanchor) 110 | means_nonanchor_squeeze = ckbd_nonanchor_sequeeze(means_nonanchor) 111 | indexes = gaussian_conditional.build_indexes(scales_nonanchor_squeeze) 112 | nonanchor_hat = decoder.decode_stream(indexes.reshape(-1).tolist(), cdf, cdf_lengths, offsets) 113 | nonanchor_hat = torch.Tensor(nonanchor_hat).reshape(scales_nonanchor_squeeze.shape).to(scales_nonanchor.device) + means_nonanchor_squeeze 114 | nonanchor_hat = ckbd_nonanchor_unsequeeze(nonanchor_hat) 115 | return nonanchor_hat 116 | -------------------------------------------------------------------------------- /playground/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | import sys 5 | from PIL import ImageFile, Image 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data import DataLoader 13 | from torchvision import transforms 14 | from compressai.datasets import ImageFolder 15 | sys.path.append("..") 16 | from utils.logger import setup_logger 17 | from utils.utils import CustomDataParallel, save_checkpoint 18 | from utils.optimizers import configure_optimizers 19 | from utils.training import train_one_epoch 20 | from utils.testing import test_one_epoch 21 | from loss.rd_loss import RateDistortionLoss 22 | from config.args import train_options 23 | from config.config import model_config 24 | from models import * 25 | from utils.newtrain_data import TrainData,dataset_collate 26 | import random 27 | 28 | 29 | def main(): 30 | torch.backends.cudnn.benchmark = True 31 | ImageFile.LOAD_TRUNCATED_IMAGES = True 32 | Image.MAX_IMAGE_PIXELS = None 33 | 34 | args = train_options() 35 | config = model_config() 36 | 37 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 38 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" 39 | 40 | if args.seed is not None: 41 | seed = args.seed 42 | else: 43 | seed = 100 * random.random() 44 | torch.manual_seed(seed) 45 | random.seed(seed) 46 | 47 | if not os.path.exists(os.path.join('./experiments', args.experiment)): 48 | os.makedirs(os.path.join('./experiments', args.experiment)) 49 | 50 | setup_logger('train', os.path.join('./experiments', args.experiment), 'train_' + args.experiment, level=logging.INFO, 51 | screen=True, tofile=True) 52 | setup_logger('val', os.path.join('./experiments', args.experiment), 'val_' + args.experiment, level=logging.INFO, 53 | screen=True, tofile=True) 54 | 55 | logger_train = logging.getLogger('train') 56 | logger_val = logging.getLogger('val') 57 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + args.experiment) 58 | 59 | if not 
os.path.exists(os.path.join('./experiments', args.experiment, 'checkpoints')): 60 | os.makedirs(os.path.join('./experiments', args.experiment, 'checkpoints')) 61 | 62 | 63 | train_dataset = TrainData("../new_trainvallist.txt") 64 | val_dataset = TrainData("../new_testlist.txt") 65 | 66 | train_dataloader = DataLoader( 67 | train_dataset, 68 | batch_size=args.batch_size, 69 | num_workers=args.num_workers, 70 | shuffle=True, 71 | pin_memory=(device == "cuda"), 72 | collate_fn=dataset_collate 73 | ) 74 | 75 | test_dataloader = DataLoader( 76 | val_dataset, 77 | batch_size=args.test_batch_size, 78 | num_workers=args.num_workers, 79 | shuffle=False, 80 | pin_memory=(device == "cuda"), 81 | collate_fn=dataset_collate 82 | ) 83 | 84 | net = MLICPlusPlus(config=config) 85 | if args.cuda and torch.cuda.device_count() > 1: 86 | net = CustomDataParallel(net) 87 | net = net.to(device) 88 | optimizer, aux_optimizer = configure_optimizers(net, args) 89 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 100], gamma=0.1) 90 | criterion = RateDistortionLoss(lmbda=args.lmbda, metrics=args.metrics) 91 | 92 | if args.checkpoint != None: 93 | checkpoint = torch.load(args.checkpoint) 94 | # new_ckpt = modify_checkpoint(checkpoint['state_dict']) 95 | net.load_state_dict(checkpoint['state_dict']) 96 | start_epoch = 0 97 | best_loss = 1e10 98 | current_step = 0 99 | # optimizer.load_state_dict(checkpoint['optimizer']) 100 | # aux_optimizer.load_state_dict(checkpoint['aux_optimizer']) 101 | # lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 102 | # lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[450,550], gamma=0.1) 103 | # lr_scheduler._step_count = checkpoint['lr_scheduler']['_step_count'] 104 | # lr_scheduler.last_epoch = checkpoint['lr_scheduler']['last_epoch'] 105 | # print(lr_scheduler.state_dict()) 106 | # start_epoch = checkpoint['epoch'] 107 | # best_loss = checkpoint['loss'] 108 | # current_step = start_epoch * math.ceil(len(train_dataloader.dataset) / args.batch_size) 109 | checkpoint = None 110 | else: 111 | start_epoch = 0 112 | best_loss = 1e10 113 | current_step = 0 114 | 115 | # start_epoch = 0 116 | # best_loss = 1e10 117 | # current_step = 0 118 | 119 | logger_train.info(args) 120 | logger_train.info(config) 121 | logger_train.info(net) 122 | logger_train.info(optimizer) 123 | optimizer.param_groups[0]['lr'] = args.learning_rate 124 | for epoch in range(start_epoch, args.epochs): 125 | logger_train.info(f"Learning rate: {optimizer.param_groups[0]['lr']}") 126 | current_step = train_one_epoch( 127 | net, 128 | criterion, 129 | train_dataloader, 130 | optimizer, 131 | aux_optimizer, 132 | epoch, 133 | args.clip_max_norm, 134 | logger_train, 135 | tb_logger, 136 | current_step 137 | ) 138 | 139 | save_dir = os.path.join('./experiments', args.experiment, 'val_images', '%03d' % (epoch + 1)) 140 | loss = test_one_epoch(epoch, test_dataloader, net, criterion, save_dir, logger_val, tb_logger) 141 | 142 | lr_scheduler.step() 143 | is_best = loss < best_loss 144 | best_loss = min(loss, best_loss) 145 | 146 | net.update(force=True) 147 | if args.save: 148 | save_checkpoint( 149 | { 150 | "epoch": epoch + 1, 151 | "state_dict": net.state_dict(), 152 | "loss": loss, 153 | "optimizer": optimizer.state_dict(), 154 | "aux_optimizer": aux_optimizer.state_dict(), 155 | "lr_scheduler": lr_scheduler.state_dict(), 156 | }, 157 | is_best, 158 | os.path.join('./experiments', args.experiment, 'checkpoints', "checkpoint_%03d.pth.tar" % (epoch + 1)) 159 | ) 160 | if 
is_best: 161 | logger_val.info('best checkpoint saved.') 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /playground/warmup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | from PIL import ImageFile, Image 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch.utils.data import DataLoader 12 | from torchvision import transforms 13 | from transformers import get_linear_schedule_with_warmup 14 | from compressai.datasets import ImageFolder 15 | from utils.logger import setup_logger 16 | from utils.utils import CustomDataParallel, save_checkpoint 17 | from utils.optimizers import configure_optimizers 18 | from utils.training import warmup_one_epoch 19 | from utils.testing import test_one_epoch 20 | from loss.rd_loss import RateDistortionLoss 21 | from config.args import train_options 22 | from config.config import model_config 23 | from models import * 24 | import random 25 | 26 | 27 | def main(): 28 | torch.backends.cudnn.benchmark = True 29 | ImageFile.LOAD_TRUNCATED_IMAGES = True 30 | Image.MAX_IMAGE_PIXELS = None 31 | 32 | args = train_options() 33 | config = model_config() 34 | 35 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 36 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" 37 | 38 | if args.seed is not None: 39 | # seed = 100 * random.random() 40 | seed = args.seed 41 | torch.manual_seed(seed) 42 | random.seed(seed) 43 | 44 | if not os.path.exists(os.path.join('./experiments', args.experiment)): 45 | os.makedirs(os.path.join('./experiments', args.experiment)) 46 | 47 | setup_logger('train', os.path.join('./experiments', args.experiment), 'train_' + args.experiment, level=logging.INFO, 48 | screen=True, tofile=True) 49 | setup_logger('val', os.path.join('./experiments', args.experiment), 'val_' + args.experiment, level=logging.INFO, 50 | screen=True, tofile=True) 51 | 52 | logger_train = logging.getLogger('train') 53 | logger_val = logging.getLogger('val') 54 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + args.experiment) 55 | 56 | if not os.path.exists(os.path.join('./experiments', args.experiment, 'checkpoints')): 57 | os.makedirs(os.path.join('./experiments', args.experiment, 'checkpoints')) 58 | 59 | train_transforms = transforms.Compose( 60 | [transforms.RandomCrop(args.patch_size), transforms.ToTensor()] 61 | ) 62 | test_transforms = transforms.Compose( 63 | [transforms.ToTensor()] 64 | ) 65 | 66 | train_dataset = ImageFolder(args.dataset, split="train", transform=train_transforms) 67 | test_dataset = ImageFolder(args.dataset, split="test", transform=test_transforms) 68 | 69 | train_dataloader = DataLoader( 70 | train_dataset, 71 | batch_size=args.batch_size, 72 | num_workers=args.num_workers, 73 | shuffle=True, 74 | pin_memory=(device == "cuda"), 75 | ) 76 | 77 | test_dataloader = DataLoader( 78 | test_dataset, 79 | batch_size=args.test_batch_size, 80 | num_workers=args.num_workers, 81 | shuffle=False, 82 | pin_memory=(device == "cuda"), 83 | ) 84 | 85 | net = MLICPlusPlus(config=config) 86 | net = torch.compile(net) 87 | if args.cuda and torch.cuda.device_count() > 1: 88 | net = CustomDataParallel(net) 89 | net = net.to(device) 90 | optimizer, aux_optimizer = configure_optimizers(net, args) 91 | # lr_scheduler = 
optim.lr_scheduler.ReduceLROnPlateau(optimizer, "min") 92 | warmup_steps = len(train_dataloader) * 1 93 | total_steps = len(train_dataloader) * 150 94 | lr_scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps) 95 | criterion = RateDistortionLoss(lmbda=args.lmbda, metrics=args.metrics) 96 | 97 | if args.checkpoint != None: 98 | checkpoint = torch.load(args.checkpoint) 99 | net.load_state_dict(checkpoint["state_dict"]) 100 | optimizer.load_state_dict(checkpoint['optimizer']) 101 | aux_optimizer.load_state_dict(checkpoint['aux_optimizer']) 102 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 103 | # lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[450,550], gamma=0.1) 104 | # lr_scheduler._step_count = checkpoint['lr_scheduler']['_step_count'] 105 | # lr_scheduler.last_epoch = checkpoint['lr_scheduler']['last_epoch'] 106 | # print(lr_scheduler.state_dict()) 107 | start_epoch = checkpoint['epoch'] 108 | best_loss = checkpoint['loss'] 109 | current_step = start_epoch * math.ceil(len(train_dataloader.dataset) / args.batch_size) 110 | checkpoint = None 111 | else: 112 | start_epoch = 0 113 | best_loss = 1e10 114 | current_step = 0 115 | 116 | # start_epoch = 0 117 | # best_loss = 1e10 118 | # current_step = 0 119 | 120 | logger_train.info(args) 121 | logger_train.info(config) 122 | logger_train.info(net) 123 | logger_train.info(optimizer) 124 | for epoch in range(start_epoch, args.epochs): 125 | logger_train.info(f"Learning rate: {optimizer.param_groups[0]['lr']}") 126 | current_step = warmup_one_epoch( 127 | net, 128 | criterion, 129 | train_dataloader, 130 | optimizer, 131 | aux_optimizer, 132 | epoch, 133 | args.clip_max_norm, 134 | logger_train, 135 | tb_logger, 136 | current_step, 137 | lr_scheduler 138 | ) 139 | 140 | save_dir = os.path.join('./experiments', args.experiment, 'val_images', '%03d' % (epoch + 1)) 141 | loss = test_one_epoch(epoch, test_dataloader, net, criterion, save_dir, logger_val, tb_logger) 142 | 143 | is_best = loss < best_loss 144 | best_loss = min(loss, best_loss) 145 | 146 | net.update(force=True) 147 | if args.save: 148 | save_checkpoint( 149 | { 150 | "epoch": epoch + 1, 151 | "state_dict": net.state_dict(), 152 | "loss": loss, 153 | "optimizer": optimizer.state_dict(), 154 | "aux_optimizer": aux_optimizer.state_dict(), 155 | "lr_scheduler": lr_scheduler.state_dict(), 156 | }, 157 | is_best, 158 | os.path.join('./experiments', args.experiment, 'checkpoints', "checkpoint_%03d.pth.tar" % (epoch + 1)) 159 | ) 160 | if is_best: 161 | logger_val.info('best checkpoint saved.') 162 | 163 | if __name__ == '__main__': 164 | main() 165 | -------------------------------------------------------------------------------- /modules/layers/res_blk.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | from compressai.layers import GDN, subpel_conv3x3 5 | from modules.layers.conv import conv1x1, conv3x3 6 | 7 | 8 | class AttentionBlock(nn.Module): 9 | """Self attention block. 10 | 11 | Simplified variant from `"Learned Image Compression with 12 | Discretized Gaussian Mixture Likelihoods and Attention Modules" 13 | `_, by Zhengxue Cheng, Heming Sun, Masaru 14 | Takeuchi, Jiro Katto. 
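    The forward pass gates one residual branch with the other:
    out = x + conv_a(x) * sigmoid(conv_b(x)).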
15 | 16 | Args: 17 | N (int): Number of channels) 18 | """ 19 | 20 | def __init__(self, N: int): 21 | super().__init__() 22 | 23 | class ResidualUnit(nn.Module): 24 | """Simple residual unit.""" 25 | 26 | def __init__(self): 27 | super().__init__() 28 | self.conv = nn.Sequential( 29 | conv1x1(N, N // 2), 30 | nn.GELU(), 31 | conv3x3(N // 2, N // 2), 32 | nn.GELU(), 33 | conv1x1(N // 2, N), 34 | ) 35 | self.relu = nn.GELU() 36 | 37 | def forward(self, x: Tensor) -> Tensor: 38 | identity = x 39 | out = self.conv(x) 40 | out += identity 41 | out = self.relu(out) 42 | return out 43 | 44 | self.conv_a = nn.Sequential(ResidualUnit(), ResidualUnit(), ResidualUnit()) 45 | 46 | self.conv_b = nn.Sequential( 47 | ResidualUnit(), 48 | ResidualUnit(), 49 | ResidualUnit(), 50 | conv1x1(N, N), 51 | ) 52 | 53 | def forward(self, x: Tensor) -> Tensor: 54 | identity = x 55 | a = self.conv_a(x) 56 | b = self.conv_b(x) 57 | out = a * torch.sigmoid(b) 58 | out += identity 59 | return out 60 | 61 | 62 | class ResidualBlockWithStride(nn.Module): 63 | """Residual block with a stride on the first convolution. 64 | 65 | Args: 66 | in_ch (int): number of input channels 67 | out_ch (int): number of output channels 68 | stride (int): stride value (default: 2) 69 | """ 70 | 71 | def __init__(self, in_ch: int, out_ch: int, stride: int = 2): 72 | super().__init__() 73 | self.conv1 = conv3x3(in_ch, out_ch, stride=stride) 74 | self.act = nn.GELU() 75 | self.conv2 = conv3x3(out_ch, out_ch) 76 | self.gdn = GDN(out_ch) 77 | if stride != 1 or in_ch != out_ch: 78 | self.skip = conv1x1(in_ch, out_ch, stride=stride) 79 | else: 80 | self.skip = None 81 | 82 | def forward(self, x): 83 | identity = x 84 | out = self.conv1(x) 85 | out = self.act(out) 86 | out = self.conv2(out) 87 | out = self.gdn(out) 88 | 89 | if self.skip is not None: 90 | identity = self.skip(x) 91 | 92 | out += identity 93 | return out 94 | 95 | 96 | class ResidualBlockUpsample(nn.Module): 97 | """Residual block with sub-pixel upsampling on the last convolution. 98 | 99 | Args: 100 | in_ch (int): number of input channels 101 | out_ch (int): number of output channels 102 | upsample (int): upsampling factor (default: 2) 103 | """ 104 | 105 | def __init__(self, in_ch: int, out_ch: int, upsample: int = 2): 106 | super().__init__() 107 | self.subpel_conv = subpel_conv3x3(in_ch, out_ch, upsample) 108 | self.act = nn.GELU() 109 | self.conv = conv3x3(out_ch, out_ch) 110 | self.igdn = GDN(out_ch, inverse=True) 111 | self.upsample = subpel_conv3x3(in_ch, out_ch, upsample) 112 | 113 | def forward(self, x): 114 | identity = x 115 | out = self.subpel_conv(x) 116 | out = self.act(out) 117 | out = self.conv(out) 118 | out = self.igdn(out) 119 | identity = self.upsample(x) 120 | out += identity 121 | return out 122 | 123 | 124 | class ResidualBlock(nn.Module): 125 | """Simple residual block with two 3x3 convolutions. 
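    GELU activations follow each convolution; a 1x1 convolution projects the skip
    path when in_ch != out_ch.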
126 | 127 | Args: 128 | in_ch (int): number of input channels 129 | out_ch (int): number of output channels 130 | """ 131 | 132 | def __init__(self, in_ch: int, out_ch: int): 133 | super().__init__() 134 | self.conv1 = conv3x3(in_ch, out_ch) 135 | self.act = nn.GELU() 136 | self.conv2 = conv3x3(out_ch, out_ch) 137 | if in_ch != out_ch: 138 | self.skip = conv1x1(in_ch, out_ch) 139 | else: 140 | self.skip = None 141 | 142 | def forward(self, x): 143 | identity = x 144 | 145 | out = self.conv1(x) 146 | out = self.act(out) 147 | out = self.conv2(out) 148 | out = self.act(out) 149 | 150 | if self.skip is not None: 151 | identity = self.skip(x) 152 | 153 | out = out + identity 154 | return out 155 | 156 | 157 | class ResidualBottleneck(nn.Module): 158 | def __init__(self, N=192, act=nn.GELU, groups=1) -> None: 159 | super().__init__() 160 | self.branch = nn.Sequential( 161 | conv1x1(N, N // 2), 162 | act(), 163 | nn.Conv2d(N // 2, N // 2, kernel_size=3, stride=1, padding=1), 164 | act(), 165 | conv1x1(N // 2, N) 166 | ) 167 | 168 | def forward(self, x): 169 | out = x + self.branch(x) 170 | 171 | return out 172 | 173 | class PCDM(nn.Module): 174 | """Simple residual block with two 3x3 convolutions. 175 | 176 | Args: 177 | in_ch (int): number of input channels 178 | out_ch (int): number of output channels 179 | """ 180 | 181 | def __init__(self, in_ch: int, out_ch: int): 182 | super().__init__() 183 | self.conv1 = conv3x3(in_ch, out_ch) 184 | self.act = nn.GELU() 185 | self.conv2 = conv3x3(out_ch, out_ch) 186 | if in_ch != out_ch: 187 | self.skip = conv1x1(in_ch, out_ch) 188 | else: 189 | self.skip = None 190 | self._proj1 = nn.Sequential( 191 | nn.Linear(512, out_ch, bias=False) 192 | ) 193 | self._proj2 = nn.Sequential( 194 | nn.Linear(512, out_ch, bias=False) 195 | ) 196 | 197 | def forward(self, x): 198 | x, fourier_features_mlp = x 199 | identity = x 200 | 201 | out = self.conv1(x) 202 | out = self.act(out) 203 | 204 | proj1 = self._proj1(fourier_features_mlp) 205 | out += proj1[:, :, None, None] 206 | 207 | out = self.conv2(out) 208 | out = self.act(out) 209 | 210 | proj2 = self._proj2(fourier_features_mlp) 211 | out += proj2[:, :, None, None] 212 | 213 | 214 | if self.skip is not None: 215 | identity = self.skip(x) 216 | 217 | out = out + identity 218 | return out -------------------------------------------------------------------------------- /playground/train_mask.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import random 4 | import logging 5 | from PIL import ImageFile, Image 6 | import math 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.nn.functional as F 11 | from torch.utils.tensorboard import SummaryWriter 12 | from torch.utils.data import DataLoader 13 | from torchvision import transforms 14 | from compressai.datasets import ImageFolder 15 | import sys 16 | 17 | 18 | 19 | sys.path.append("..") 20 | from models.mlicpp_con import ConditionalMLICPlusPlus 21 | from clip.Closs import CLIPConvLoss 22 | from utils.logger import setup_logger 23 | from utils.newtrain_data import TrainData, dataset_collate, TrainDataWithMask 24 | from utils.utils import CustomDataParallel, save_checkpoint 25 | from utils.optimizers import configure_optimizers, configure_optimizer 26 | from utils.training import train_one_epoch, train_clip_one_epoch, train_conditional_one_epoch, train_mask_conditional_one_epoch 27 | from utils.testing import test_one_epoch, test_clip_one_epoch, 
test_conditional_one_epoch, newtest_conditional_one_epoch 28 | from loss.rd_loss import RateDistortionLoss 29 | from config.args import train_options 30 | from config.config import model_config 31 | from models import * 32 | import random 33 | 34 | 35 | def main(): 36 | torch.backends.cudnn.benchmark = True 37 | ImageFile.LOAD_TRUNCATED_IMAGES = True 38 | Image.MAX_IMAGE_PIXELS = None 39 | 40 | args = train_options() 41 | config = model_config() 42 | 43 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 44 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" 45 | 46 | if args.seed is not None: 47 | # seed = args.seed 48 | # else: 49 | seed = 100 * random.random() 50 | torch.manual_seed(seed) 51 | random.seed(seed) 52 | 53 | if not os.path.exists(os.path.join('./experiments', args.experiment)): 54 | os.makedirs(os.path.join('./experiments', args.experiment)) 55 | 56 | setup_logger('train', os.path.join('./experiments', args.experiment), 'train_' + args.experiment, level=logging.INFO, 57 | screen=True, tofile=True) 58 | setup_logger('val', os.path.join('./experiments', args.experiment), 'val_' + args.experiment, level=logging.INFO, 59 | screen=True, tofile=True) 60 | 61 | logger_train = logging.getLogger('train') 62 | logger_val = logging.getLogger('val') 63 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + args.experiment) 64 | 65 | if not os.path.exists(os.path.join('./experiments', args.experiment, 'checkpoints')): 66 | os.makedirs(os.path.join('./experiments', args.experiment, 'checkpoints')) 67 | 68 | train_dataset = TrainDataWithMask("../../../coco2voc/coco2voc/ImageSets/Main/trainval.txt") 69 | test_dataset = TrainData("../../../coco2voc/coco2voc/ImageSets/Main/test.txt") 70 | train_dataloader = DataLoader( 71 | train_dataset, 72 | batch_size=args.batch_size, 73 | num_workers=args.num_workers, 74 | shuffle=True, 75 | pin_memory=(device == "cuda"), 76 | ) 77 | 78 | test_dataloader = DataLoader( 79 | test_dataset, 80 | batch_size=args.test_batch_size, 81 | num_workers=args.num_workers, 82 | shuffle=False, 83 | pin_memory=(device == "cuda"), 84 | ) 85 | 86 | net = ConditionalMLICPlusPlus(config=config) 87 | if args.cuda and torch.cuda.device_count() > 1: 88 | net = CustomDataParallel(net) 89 | net = net.to(device) 90 | optimizer, aux_optimizer = configure_optimizers(net, args) 91 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 100], gamma=0.1) 92 | criterion = RateDistortionLoss(lmbda=args.lmbda, metrics=args.metrics) 93 | clip_loss = CLIPConvLoss(device) 94 | 95 | if args.checkpoint != None: 96 | if args.tune: 97 | checkpoint = torch.load(args.checkpoint) 98 | # new_ckpt = modify_checkpoint(checkpoint['state_dict']) 99 | net.load_state_dict(checkpoint['state_dict']) 100 | for name, param in net.named_parameters(): 101 | if not ('g_s' in name or 'global_cond' in name): 102 | param.requires_grad = False 103 | # 检查哪些参数被冻结,哪些参数是可训练的 104 | # for name, param in net.named_parameters(): 105 | # print(f"{name}: requires_grad={param.requires_grad}") 106 | optimizer = configure_optimizer(net, args) 107 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1) 108 | start_epoch = 0 109 | best_loss = 1e10 110 | current_step = 0 111 | else: 112 | start_epoch = 0 113 | best_loss = 1e10 114 | current_step = 0 115 | 116 | # start_epoch = 0 117 | # best_loss = 1e10 118 | # current_step = 0 119 | 120 | logger_train.info(args) 121 | logger_train.info(config) 122 | logger_train.info(net) 123 | logger_train.info(optimizer) 124 
| optimizer.param_groups[0]['lr'] = args.learning_rate 125 | save_dir = os.path.join('./experiments', args.experiment, 'val_images', '%03d' % (1 + 1)) 126 | for epoch in range(start_epoch, args.epochs): 127 | logger_train.info(f"Learning rate: {optimizer.param_groups[0]['lr']}") 128 | current_step = train_mask_conditional_one_epoch( 129 | net, 130 | criterion, 131 | train_dataloader, 132 | optimizer, 133 | aux_optimizer, 134 | args.tune, 135 | epoch, 136 | args.clip_max_norm, 137 | logger_train, 138 | tb_logger, 139 | current_step, 140 | clip_loss, 141 | args.lambda_clip, 142 | args.lmbda, 143 | args.lambda_beta1, 144 | ) 145 | 146 | 147 | loss = test_conditional_one_epoch(epoch, test_dataloader, net, criterion, save_dir, logger_val, tb_logger, clip_loss, args.lambda_clip) 148 | 149 | lr_scheduler.step() 150 | is_best = loss < best_loss 151 | best_loss = min(loss, best_loss) 152 | 153 | net.update(force=True) 154 | if args.save: 155 | save_checkpoint( 156 | { 157 | "epoch": epoch + 1, 158 | "state_dict": net.state_dict(), 159 | "loss": loss, 160 | "optimizer": optimizer.state_dict(), 161 | "aux_optimizer": aux_optimizer.state_dict(), 162 | "lr_scheduler": lr_scheduler.state_dict(), 163 | }, 164 | is_best, 165 | os.path.join('./experiments', args.experiment, 'checkpoints', "checkpoint_%03d.pth.tar" % (epoch + 1)) 166 | ) 167 | if is_best: 168 | logger_val.info('best checkpoint saved.') 169 | 170 | if __name__ == '__main__': 171 | main() -------------------------------------------------------------------------------- /utils/train_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | paper: GridDehazeNet: Attention-Based Multi-Scale Network for Image Dehazing 3 | file: train_data.py 4 | about: build the training dataset 5 | author: Xiaohong Liu 6 | date: 01/08/19 7 | """ 8 | 9 | # --- Imports --- # 10 | import os 11 | import random 12 | 13 | import numpy as np 14 | import torch 15 | import torch.utils.data as data 16 | from PIL import Image 17 | import torch.nn.functional as F 18 | 19 | from torchvision.transforms import Compose, ToTensor, Resize, RandomCrop 20 | 21 | 22 | def compute_padding(in_h: int, in_w: int, *, out_h=None, out_w=None, min_div=1): 23 | """Returns tuples for padding and unpadding. 24 | 25 | Args: 26 | in_h: Input height. 27 | in_w: Input width. 28 | out_h: Output height. 29 | out_w: Output width. 30 | min_div: Length that output dimensions should be divisible by. 31 | """ 32 | if out_h is None: 33 | out_h = (in_h + min_div - 1) // min_div * min_div 34 | if out_w is None: 35 | out_w = (in_w + min_div - 1) // min_div * min_div 36 | 37 | if out_h % min_div != 0 or out_w % min_div != 0: 38 | raise ValueError( 39 | f"Padded output height and width are not divisible by min_div={min_div}." 
40 | ) 41 | 42 | left = (out_w - in_w) // 2 43 | right = out_w - in_w - left 44 | top = (out_h - in_h) // 2 45 | bottom = out_h - in_h - top 46 | 47 | pad = (left, right, top, bottom) 48 | unpad = (-left, -right, -top, -bottom) 49 | 50 | return pad, unpad 51 | 52 | def pad(x, p=2**6): 53 | h, w = x.size(1), x.size(2) 54 | pad, _ = compute_padding(h, w, min_div=p) 55 | return F.pad(x, pad, mode="constant", value=0) 56 | 57 | 58 | 59 | class TestData(data.Dataset): 60 | def __init__(self, train_data_dir): 61 | super().__init__() 62 | train_list = train_data_dir 63 | with open(train_list) as f: 64 | txt = f.readlines() 65 | # annotations = [line.strip() for line in txt if len(line.strip().split()[1:]) != 0] 66 | annotations = [line.strip() for line in txt] 67 | 68 | self.annotations = annotations 69 | self.train_data_dir = train_data_dir 70 | 71 | def get_images(self, index): 72 | line = self.annotations[index].split() 73 | image_path = line[0] 74 | # print(image_path) 75 | img_name = image_path.split('/')[-1] 76 | # print(img_name) 77 | image_name = img_name.split('.')[0] 78 | # print(image_name) 79 | gt_name = image_name 80 | pgt_name = image_name 81 | 82 | # --- Transform to tensor --- # 83 | transform = Compose([ToTensor()]) 84 | 85 | gts = [] 86 | haze_img = Image.open("../../yolov3/" + 'VOCdevkit/VOC2007/JPEGImages/' + gt_name +'.jpg' ) 87 | gt = transform(haze_img) 88 | # --- Check the channel is 3 or not --- # 89 | if list(gt.shape)[0] is not 3 : 90 | raise Exception('Bad image channel: {}'.format(gt_name)) 91 | gts.append(gt) 92 | 93 | names = [] 94 | names.append(pgt_name) 95 | 96 | return gts, names 97 | 98 | def __getitem__(self, index): 99 | res = self.get_images(index) 100 | return res 101 | 102 | def __len__(self): 103 | return len(self.annotations) 104 | 105 | 106 | class TestImagenet(data.Dataset): 107 | def __init__(self, train_data_dir): 108 | super().__init__() 109 | train_list = train_data_dir 110 | with open(train_list) as f: 111 | txt = f.readlines() 112 | # annotations = [line.strip() for line in txt if len(line.strip().split()[1:]) != 0] 113 | annotations = [line.strip() for line in txt] 114 | 115 | self.annotations = annotations 116 | self.train_data_dir = train_data_dir 117 | 118 | def get_images(self, index): 119 | image_path = self.annotations[index] 120 | # print(image_path) 121 | img_name = image_path.split('/')[-1] 122 | # print(img_name) 123 | image_name = img_name.split('.')[0] 124 | # print(image_name) 125 | gt_name = image_name 126 | pgt_name = image_name 127 | 128 | # --- Transform to tensor --- # 129 | transform = Compose([ToTensor()]) 130 | 131 | gts = [] 132 | haze_img = Image.open("../../data/liuquan/Imagenet_val25k/" + image_path) 133 | haze_img = haze_img.convert('RGB') 134 | gt = transform(haze_img) 135 | # --- Check the channel is 3 or not --- # 136 | if list(gt.shape)[0] is not 3 : 137 | raise Exception('Bad image channel: {}'.format(gt_name)) 138 | gts.append(gt) 139 | 140 | names = [] 141 | names.append(pgt_name) 142 | 143 | return gts, names 144 | 145 | def __getitem__(self, index): 146 | res = self.get_images(index) 147 | return res 148 | 149 | def __len__(self): 150 | return len(self.annotations) 151 | 152 | class NewTestData(data.Dataset): 153 | def __init__(self, img_dir): 154 | super().__init__() 155 | self.img_dir = img_dir 156 | self.img_labels = [f for f in os.listdir(img_dir) if f.endswith(".JPEG") or f.endswith(".png") or f.endswith(".jpg")] 157 | 158 | def get_images(self, index): 159 | img_label=self.img_labels[index] 160 | 
image_name = img_label.split('.')[0] 161 | img_path = os.path.join(self.img_dir, img_label) 162 | 163 | # --- Transform to tensor --- # 164 | transform = Compose([ToTensor()]) 165 | 166 | gts = [] 167 | haze_img = Image.open(img_path) 168 | haze_img = haze_img.convert('RGB') 169 | gt = transform(haze_img) 170 | # --- Check the channel is 3 or not --- # 171 | if list(gt.shape)[0] is not 3 : 172 | raise Exception('Bad image channel: {}'.format(image_name)) 173 | gts.append(gt) 174 | 175 | names = [] 176 | names.append(image_name) 177 | 178 | return gts, names 179 | 180 | def __getitem__(self, index): 181 | res = self.get_images(index) 182 | return res 183 | 184 | def __len__(self): 185 | return len(self.img_labels) 186 | 187 | 188 | # DataLoader中collate_fn使用 189 | def dataset_collate(batch): 190 | gts = [] 191 | pgts = [] 192 | for gt, pgt in batch: 193 | gts.append(gt[0]) 194 | pgts.append(pgt[0]) 195 | gts = torch.from_numpy(np.array([item.numpy() for item in gts])).type(torch.FloatTensor) 196 | pgts = torch.from_numpy(np.array([item.numpy() for item in pgts])).type(torch.FloatTensor) 197 | 198 | return gts, pgts 199 | 200 | def testDataset_collate(batch): 201 | gts = [] 202 | names = [] 203 | for gt, name in batch: 204 | gts.append(gt[0]) 205 | names.append(name[0]) 206 | gts = torch.from_numpy(np.array([item.numpy() for item in gts])).type(torch.FloatTensor) 207 | return gts, names -------------------------------------------------------------------------------- /playground/train_condi.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import logging 4 | from PIL import ImageFile, Image 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | from torch.utils.tensorboard import SummaryWriter 11 | from torch.utils.data import DataLoader 12 | from torchvision import transforms 13 | from compressai.datasets import ImageFolder 14 | import sys 15 | 16 | 17 | 18 | sys.path.append("..") 19 | from models.mlicpp_con import ConditionalMLICPlusPlus 20 | from clip.Closs import CLIPConvLoss 21 | from utils.logger import setup_logger 22 | from utils.newtrain_data import TrainData, dataset_collate 23 | from utils.utils import CustomDataParallel, save_checkpoint 24 | from utils.optimizers import configure_optimizers, configure_optimizer 25 | from utils.training import train_one_epoch, train_clip_one_epoch, train_conditional_one_epoch 26 | from utils.testing import test_one_epoch, test_clip_one_epoch, test_conditional_one_epoch, newtest_conditional_one_epoch 27 | from loss.rd_loss import RateDistortionLoss 28 | from config.args import train_options 29 | from config.config import model_config 30 | from models import * 31 | import random 32 | 33 | 34 | def main(): 35 | torch.backends.cudnn.benchmark = True 36 | ImageFile.LOAD_TRUNCATED_IMAGES = True 37 | Image.MAX_IMAGE_PIXELS = None 38 | 39 | args = train_options() 40 | config = model_config() 41 | 42 | os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id) 43 | device = "cuda" if args.cuda and torch.cuda.is_available() else "cpu" 44 | 45 | if args.seed is not None: 46 | # seed = args.seed 47 | # else: 48 | seed = 100 * random.random() 49 | torch.manual_seed(seed) 50 | random.seed(seed) 51 | 52 | if not os.path.exists(os.path.join('./experiments', args.experiment)): 53 | os.makedirs(os.path.join('./experiments', args.experiment)) 54 | 55 | setup_logger('train', os.path.join('./experiments', args.experiment), 
'train_' + args.experiment, level=logging.INFO, 56 | screen=True, tofile=True) 57 | setup_logger('val', os.path.join('./experiments', args.experiment), 'val_' + args.experiment, level=logging.INFO, 58 | screen=True, tofile=True) 59 | 60 | logger_train = logging.getLogger('train') 61 | logger_val = logging.getLogger('val') 62 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + args.experiment) 63 | 64 | if not os.path.exists(os.path.join('./experiments', args.experiment, 'checkpoints')): 65 | os.makedirs(os.path.join('./experiments', args.experiment, 'checkpoints')) 66 | 67 | train_transforms = transforms.Compose( 68 | [transforms.RandomCrop(args.patch_size), transforms.ToTensor()] 69 | ) 70 | test_transforms = transforms.Compose( 71 | [transforms.ToTensor()] 72 | ) 73 | 74 | train_dataset = TrainData("../../../coco2voc/coco2voc/ImageSets/Main/trainval.txt") 75 | test_dataset = TrainData("../../../coco2voc/coco2voc/ImageSets/Main/test.txt") 76 | train_dataloader = DataLoader( 77 | train_dataset, 78 | batch_size=args.batch_size, 79 | num_workers=args.num_workers, 80 | shuffle=True, 81 | pin_memory=(device == "cuda"), 82 | ) 83 | 84 | test_dataloader = DataLoader( 85 | test_dataset, 86 | batch_size=args.test_batch_size, 87 | num_workers=args.num_workers, 88 | shuffle=False, 89 | pin_memory=(device == "cuda"), 90 | ) 91 | 92 | net = ConditionalMLICPlusPlus(config=config) 93 | if args.cuda and torch.cuda.device_count() > 1: 94 | net = CustomDataParallel(net) 95 | net = net.to(device) 96 | optimizer, aux_optimizer = configure_optimizers(net, args) 97 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 100], gamma=0.1) 98 | criterion = RateDistortionLoss(lmbda=args.lmbda, metrics=args.metrics) 99 | clip_loss = CLIPConvLoss(device) 100 | 101 | if args.checkpoint != None: 102 | if args.tune: 103 | checkpoint = torch.load(args.checkpoint) 104 | # new_ckpt = modify_checkpoint(checkpoint['state_dict']) 105 | net.load_state_dict(checkpoint['state_dict']) 106 | for name, param in net.named_parameters(): 107 | if not ('g_s' in name or 'global_cond' in name): 108 | param.requires_grad = False 109 | # 检查哪些参数被冻结,哪些参数是可训练的 110 | # for name, param in net.named_parameters(): 111 | # print(f"{name}: requires_grad={param.requires_grad}") 112 | optimizer = configure_optimizer(net, args) 113 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10], gamma=0.1) 114 | start_epoch = 0 115 | best_loss = 1e10 116 | current_step = 0 117 | else: 118 | checkpoint = torch.load(args.checkpoint) 119 | # new_ckpt = modify_checkpoint(checkpoint['state_dict']) 120 | net.load_state_dict(checkpoint['state_dict']) 121 | optimizer.load_state_dict(checkpoint['optimizer']) 122 | aux_optimizer.load_state_dict(checkpoint['aux_optimizer']) 123 | # lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 124 | lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[450, 550], gamma=0.1) 125 | # lr_scheduler._step_count = checkpoint['lr_scheduler']['_step_count'] 126 | # lr_scheduler.last_epoch = checkpoint['lr_scheduler']['last_epoch'] 127 | # print(lr_scheduler.state_dict()) 128 | start_epoch = checkpoint['epoch'] 129 | best_loss = checkpoint['loss'] 130 | current_step = start_epoch * math.ceil(len(train_dataloader.dataset) / args.batch_size) 131 | checkpoint = None 132 | 133 | else: 134 | start_epoch = 0 135 | best_loss = 1e10 136 | current_step = 0 137 | 138 | # start_epoch = 0 139 | # best_loss = 1e10 140 | # current_step = 0 141 | 142 | logger_train.info(args) 143 | 
logger_train.info(config) 144 | logger_train.info(net) 145 | logger_train.info(optimizer) 146 | optimizer.param_groups[0]['lr'] = args.learning_rate 147 | save_dir = os.path.join('./experiments', args.experiment, 'val_images', '%03d' % (1 + 1)) 148 | for epoch in range(start_epoch, args.epochs): 149 | logger_train.info(f"Learning rate: {optimizer.param_groups[0]['lr']}") 150 | current_step = train_conditional_one_epoch( 151 | net, 152 | criterion, 153 | train_dataloader, 154 | optimizer, 155 | aux_optimizer, 156 | args.tune, 157 | epoch, 158 | args.clip_max_norm, 159 | logger_train, 160 | tb_logger, 161 | current_step, 162 | clip_loss, 163 | args.lambda_clip, 164 | args.lmbda, 165 | args.lambda_beta1, 166 | args.lambda_beta2 167 | ) 168 | 169 | 170 | loss = test_conditional_one_epoch(epoch, test_dataloader, net, criterion, save_dir, logger_val, tb_logger, clip_loss, args.lambda_clip) 171 | 172 | lr_scheduler.step() 173 | is_best = loss < best_loss 174 | best_loss = min(loss, best_loss) 175 | 176 | net.update(force=True) 177 | if args.save: 178 | save_checkpoint( 179 | { 180 | "epoch": epoch + 1, 181 | "state_dict": net.state_dict(), 182 | "loss": loss, 183 | "optimizer": optimizer.state_dict(), 184 | "aux_optimizer": aux_optimizer.state_dict(), 185 | "lr_scheduler": lr_scheduler.state_dict(), 186 | }, 187 | is_best, 188 | os.path.join('./experiments', args.experiment, 'checkpoints', "checkpoint_%03d.pth.tar" % (epoch + 1)) 189 | ) 190 | if is_best: 191 | logger_val.info('best checkpoint saved.') 192 | 193 | if __name__ == '__main__': 194 | main() -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | import ssl 16 | ssl._create_default_https_context = ssl._create_unverified_context 17 | 18 | try: 19 | from torchvision.transforms import InterpolationMode 20 | BICUBIC = InterpolationMode.BICUBIC 21 | except ImportError: 22 | BICUBIC = Image.BICUBIC 23 | 24 | 25 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 26 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 27 | 28 | 29 | __all__ = ["available_models", "load", "tokenize"] 30 | _tokenizer = _Tokenizer() 31 | 32 | _MODELS = { 33 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 34 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 35 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 36 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 37 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 38 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 
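    # the second-to-last path segment of each URL is the SHA256 checksum verified by _download()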
39 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 40 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 41 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 42 | } 43 | 44 | 45 | def _download(url: str, root: str): 46 | os.makedirs(root, exist_ok=True) 47 | filename = os.path.basename(url) 48 | 49 | expected_sha256 = url.split("/")[-2] 50 | download_target = os.path.join(root, filename) 51 | 52 | if os.path.exists(download_target) and not os.path.isfile(download_target): 53 | raise RuntimeError(f"{download_target} exists and is not a regular file") 54 | 55 | if os.path.isfile(download_target): 56 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 57 | return download_target 58 | else: 59 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 60 | 61 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 62 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 63 | while True: 64 | buffer = source.read(8192) 65 | if not buffer: 66 | break 67 | 68 | output.write(buffer) 69 | loop.update(len(buffer)) 70 | 71 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 72 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 73 | 74 | return download_target 75 | 76 | 77 | def _convert_image_to_rgb(image): 78 | return image.convert("RGB") 79 | 80 | 81 | def _transform(n_px): 82 | return Compose([ 83 | Resize(n_px, interpolation=BICUBIC), 84 | CenterCrop(n_px), 85 | _convert_image_to_rgb, 86 | ToTensor(), 87 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 88 | ]) 89 | 90 | 91 | def available_models() -> List[str]: 92 | """Returns the names of available CLIP models""" 93 | return list(_MODELS.keys()) 94 | 95 | 96 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 97 | """Load a CLIP model 98 | 99 | Parameters 100 | ---------- 101 | name : str 102 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 103 | 104 | device : Union[str, torch.device] 105 | The device to put the loaded model 106 | 107 | jit : bool 108 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
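        (JIT archives carry device and dtype constants in their traced graphs; the patch_device
        and patch_float helpers below rewrite these when loading on another device or on CPU.)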
109 | 110 | download_root: str 111 | path to download the model files; by default, it uses "~/.cache/clip" 112 | 113 | Returns 114 | ------- 115 | model : torch.nn.Module 116 | The CLIP model 117 | 118 | preprocess : Callable[[PIL.Image], torch.Tensor] 119 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 120 | """ 121 | if name in _MODELS: 122 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 123 | elif os.path.isfile(name): 124 | model_path = name 125 | else: 126 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 127 | 128 | with open(model_path, 'rb') as opened_file: 129 | try: 130 | # loading JIT archive 131 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 132 | state_dict = None 133 | except RuntimeError: 134 | # loading saved state dict 135 | if jit: 136 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 137 | jit = False 138 | state_dict = torch.load(opened_file, map_location="cpu") 139 | 140 | if not jit: 141 | model = build_model(state_dict or model.state_dict()).to(device) 142 | if str(device) == "cpu": 143 | model.float() 144 | return model, _transform(model.visual.input_resolution) 145 | 146 | # patch the device names 147 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 148 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 149 | 150 | def _node_get(node: torch._C.Node, key: str): 151 | """Gets attributes of a node which is polymorphic over return type. 152 | 153 | From https://github.com/pytorch/pytorch/pull/82628 154 | """ 155 | sel = node.kindOf(key) 156 | return getattr(node, sel)(key) 157 | 158 | def patch_device(module): 159 | try: 160 | graphs = [module.graph] if hasattr(module, "graph") else [] 161 | except RuntimeError: 162 | graphs = [] 163 | 164 | if hasattr(module, "forward1"): 165 | graphs.append(module.forward1.graph) 166 | 167 | for graph in graphs: 168 | for node in graph.findAllNodes("prim::Constant"): 169 | if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): 170 | node.copyAttributes(device_node) 171 | 172 | model.apply(patch_device) 173 | patch_device(model.encode_image) 174 | patch_device(model.encode_text) 175 | 176 | # patch dtype to float32 on CPU 177 | if str(device) == "cpu": 178 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 179 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 180 | float_node = float_input.node() 181 | 182 | def patch_float(module): 183 | try: 184 | graphs = [module.graph] if hasattr(module, "graph") else [] 185 | except RuntimeError: 186 | graphs = [] 187 | 188 | if hasattr(module, "forward1"): 189 | graphs.append(module.forward1.graph) 190 | 191 | for graph in graphs: 192 | for node in graph.findAllNodes("aten::to"): 193 | inputs = list(node.inputs()) 194 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 195 | if _node_get(inputs[i].node(), "value") == 5: 196 | inputs[i].node().copyAttributes(float_node) 197 | 198 | model.apply(patch_float) 199 | patch_float(model.encode_image) 200 | patch_float(model.encode_text) 201 | 202 | model.float() 203 | 204 | return model, _transform(model.input_resolution.item()) 205 | 206 | 207 | def tokenize(texts: Union[str, List[str]], context_length: int 
= 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 208 | """ 209 | Returns the tokenized representation of given input string(s) 210 | 211 | Parameters 212 | ---------- 213 | texts : Union[str, List[str]] 214 | An input string or a list of input strings to tokenize 215 | 216 | context_length : int 217 | The context length to use; all CLIP models use 77 as the context length 218 | 219 | truncate: bool 220 | Whether to truncate the text in case its encoding is longer than the context length 221 | 222 | Returns 223 | ------- 224 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 225 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 226 | """ 227 | if isinstance(texts, str): 228 | texts = [texts] 229 | 230 | sot_token = _tokenizer.encoder["<|startoftext|>"] 231 | eot_token = _tokenizer.encoder["<|endoftext|>"] 232 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 233 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 234 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 235 | else: 236 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 237 | 238 | for i, tokens in enumerate(all_tokens): 239 | if len(tokens) > context_length: 240 | if truncate: 241 | tokens = tokens[:context_length] 242 | tokens[-1] = eot_token 243 | else: 244 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 245 | result[i, :len(tokens)] = torch.tensor(tokens) 246 | 247 | return result 248 | -------------------------------------------------------------------------------- /clip/Closs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import collections 3 | import torch.nn as nn 4 | 5 | sys.path.append("..") 6 | import clip.clip as clip 7 | import torch 8 | from torchvision import models, transforms 9 | 10 | class CLIPLoss(torch.nn.Module): 11 | def __init__(self, device, num, affine): 12 | super(CLIPLoss, self).__init__() 13 | 14 | self.model, clip_preprocess = clip.load( 15 | 'ViT-B/32', device, jit=False) 16 | self.model.eval() 17 | self.preprocess = transforms.Compose( 18 | [clip_preprocess.transforms[-1], clip_preprocess.transforms[0]]) # clip normalisation 19 | self.device = device 20 | self.NUM_AUGS = num 21 | self.mse = nn.MSELoss() 22 | augemntations = [] 23 | if affine: 24 | # augemntations.append(transforms.RandomPerspective( 25 | # fill=0, p=1.0, distortion_scale=0.5)) 26 | # augemntations.append(transforms.Resize([224, 224])) 27 | augemntations.append(transforms.Resize([224, 224])) 28 | augemntations.append( 29 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))) 30 | self.augment_trans = transforms.Compose(augemntations) 31 | 32 | self.calc_target = True 33 | self.counter = 0 34 | 35 | def forward(self, sketches, targets, mode="train"): 36 | if self.calc_target: 37 | targets_ = self.preprocess(targets).to(self.device) 38 | self.targets_features = self.model.encode_image(targets_).detach() 39 | self.calc_target = False 40 | 41 | if mode == "eval": 42 | # for regular clip distance, no augmentations 43 | with torch.no_grad(): 44 | sketches = self.preprocess(sketches).to(self.device) 45 | sketches_features = self.model.encode_image(sketches) 46 | return 1. 
- torch.cosine_similarity(sketches_features, self.targets_features) 47 | 48 | loss_clip = 0 49 | sketch_augs = [] 50 | img_augs = [] 51 | for n in range(self.NUM_AUGS): 52 | # augmented_pair = self.augment_trans(torch.cat([sketches, targets])) 53 | augmented_pair = self.augment_trans(sketches) 54 | sketch_augs.append(augmented_pair[0].unsqueeze(0)) 55 | sketch_batch = torch.cat(sketch_augs) 56 | # sketch_utils.plot_batch(img_batch, sketch_batch, self.args, self.counter, use_wandb=False, title="fc_aug{}_iter{}_{}.jpg".format(1, self.counter, mode)) 57 | # if self.counter % 100 == 0: 58 | # sketch_utils.plot_batch(img_batch, sketch_batch, self.args, self.counter, use_wandb=False, title="aug{}_iter{}_{}.jpg".format(1, self.counter, mode)) 59 | 60 | sketch_features = self.model.encode_image(sketch_batch) 61 | 62 | for n in range(self.NUM_AUGS): 63 | loss_clip += (1. - torch.cosine_similarity( 64 | sketch_features[n:n+1], self.targets_features, dim=1)) 65 | # loss_clip += self.mse(sketch_features[n:n+1], self.targets_features) 66 | self.counter += 1 67 | return loss_clip 68 | # return 1. - torch.cosine_similarity(sketches_features, self.targets_features) 69 | 70 | 71 | class CLIPVisualEncoder(nn.Module): 72 | def __init__(self, clip_model): 73 | super().__init__() 74 | self.clip_model = clip_model 75 | self.featuremaps = None 76 | 77 | for i in range(12): # 12 resblocks in VIT visual transformer 78 | self.clip_model.visual.transformer.resblocks[i].register_forward_hook( 79 | self.make_hook(i)) 80 | 81 | def make_hook(self, name): 82 | def hook(module, input, output): 83 | if len(output.shape) == 3: 84 | self.featuremaps[name] = output.permute( 85 | 1, 0, 2) # LND -> NLD bs, smth, 768 86 | else: 87 | self.featuremaps[name] = output 88 | 89 | return hook 90 | 91 | def forward(self, x): 92 | self.featuremaps = collections.OrderedDict() 93 | fc_features = self.clip_model.encode_image(x).float() 94 | featuremaps = [self.featuremaps[k] for k in range(12)] 95 | 96 | return fc_features, featuremaps 97 | 98 | 99 | def l2_layers(xs_conv_features, ys_conv_features, clip_model_name): 100 | return [torch.square(x_conv - y_conv).mean() for x_conv, y_conv in 101 | zip(xs_conv_features, ys_conv_features)] 102 | 103 | 104 | def l1_layers(xs_conv_features, ys_conv_features, clip_model_name): 105 | return [torch.abs(x_conv - y_conv).mean() for x_conv, y_conv in 106 | zip(xs_conv_features, ys_conv_features)] 107 | 108 | 109 | def cos_layers(xs_conv_features, ys_conv_features, clip_model_name): 110 | if "RN" in clip_model_name: 111 | return [torch.square(x_conv, y_conv, dim=1).mean() for x_conv, y_conv in 112 | zip(xs_conv_features, ys_conv_features)] 113 | return [(1 - torch.cosine_similarity(x_conv, y_conv, dim=1)).mean() for x_conv, y_conv in 114 | zip(xs_conv_features, ys_conv_features)] 115 | 116 | 117 | class CLIPConvLoss(torch.nn.Module): 118 | def __init__(self, device): 119 | super(CLIPConvLoss, self).__init__() 120 | self.clip_model_name = "ViT-B/32" 121 | assert self.clip_model_name in [ 122 | "RN50", 123 | "RN101", 124 | "RN50x4", 125 | "RN50x16", 126 | "ViT-B/32", 127 | "ViT-B/16", 128 | ] 129 | 130 | self.clip_conv_loss_type = "L2" 131 | self.clip_fc_loss_type = "Cos" # args.clip_fc_loss_type 132 | assert self.clip_conv_loss_type in [ 133 | "L2", "Cos", "L1", 134 | ] 135 | assert self.clip_fc_loss_type in [ 136 | "L2", "Cos", "L1", 137 | ] 138 | 139 | self.distance_metrics = \ 140 | { 141 | "L2": l2_layers, 142 | "L1": l1_layers, 143 | "Cos": cos_layers 144 | } 145 | 146 | self.model, clip_preprocess 
= clip.load( 147 | self.clip_model_name, device, jit=False, download_root=".") 148 | 149 | if self.clip_model_name.startswith("ViT"): 150 | self.visual_encoder = CLIPVisualEncoder(self.model) 151 | 152 | else: 153 | self.visual_model = self.model.visual 154 | layers = list(self.model.visual.children()) 155 | init_layers = torch.nn.Sequential(*layers)[:8] 156 | self.layer1 = layers[8] 157 | self.layer2 = layers[9] 158 | self.layer3 = layers[10] 159 | self.layer4 = layers[11] 160 | self.att_pool2d = layers[12] 161 | 162 | 163 | self.img_size = clip_preprocess.transforms[1].size 164 | self.model.eval() 165 | self.target_transform = transforms.Compose([ 166 | transforms.ToTensor(), 167 | ]) # clip normalisation 168 | self.normalize_transform = transforms.Compose([ 169 | clip_preprocess.transforms[0], # Resize 170 | clip_preprocess.transforms[1], # CenterCrop 171 | clip_preprocess.transforms[-1], # Normalize 172 | ]) 173 | 174 | self.model.eval() 175 | self.device = device 176 | # self.num_augs = num 177 | 178 | augemntations = [] 179 | # if affine: 180 | # augemntations.append(transforms.RandomPerspective( 181 | # fill=0, p=1.0, distortion_scale=0.5)) 182 | # augemntations.append(transforms.RandomResizedCrop( 183 | # 224, scale=(0.8, 0.8), ratio=(1.0, 1.0))) 184 | # else: 185 | augemntations.append(transforms.Resize([224, 224])) 186 | augemntations.append( 187 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))) 188 | self.augment_trans = transforms.Compose(augemntations) 189 | 190 | self.clip_fc_layer_dims = None # self.args.clip_fc_layer_dims 191 | self.clip_conv_layer_dims = None # self.args.clip_conv_layer_dims 192 | self.clip_fc_loss_weight = 0.1 193 | self.counter = 0 194 | self.clip_conv_layer_weights = "0,0,0,0,0" 195 | self.clip_conv_layer_weights = [ 196 | float(item) for item in self.clip_conv_layer_weights.split(',')] 197 | 198 | def forward(self, sketch, target, mode="train"): 199 | """ 200 | Parameters 201 | ---------- 202 | sketch: Torch Tensor [1, C, H, W] 203 | target: Torch Tensor [1, C, H, W] 204 | """ 205 | # y = self.target_transform(target).to(self.args.device) 206 | conv_loss_dict = {} 207 | x = sketch.to(self.device) 208 | y = target.to(self.device) 209 | sketch_augs, img_augs = [self.normalize_transform(x)], [ 210 | self.normalize_transform(y)] 211 | # if mode == "train": 212 | # for n in range(self.num_augs): 213 | # augmented_pair = self.augment_trans(torch.cat([x, y])) 214 | # sketch_augs.append(augmented_pair[0].unsqueeze(0)) 215 | # img_augs.append(augmented_pair[1].unsqueeze(0)) 216 | 217 | xs = torch.cat(sketch_augs, dim=0).to(self.device) 218 | ys = torch.cat(img_augs, dim=0).to(self.device) 219 | 220 | if self.clip_model_name.startswith("RN"): 221 | xs_fc_features, xs_conv_features = self.forward_inspection_clip_resnet( 222 | xs.contiguous()) 223 | ys_fc_features, ys_conv_features = self.forward_inspection_clip_resnet( 224 | ys.detach()) 225 | 226 | else: 227 | xs_fc_features, xs_conv_features = self.visual_encoder(xs) 228 | ys_fc_features, ys_conv_features = self.visual_encoder(ys) 229 | 230 | conv_loss = self.distance_metrics[self.clip_conv_loss_type]( 231 | xs_conv_features, ys_conv_features, self.clip_model_name) 232 | conv_loss_total=0 233 | for layer, w in enumerate(self.clip_conv_layer_weights): 234 | if w: 235 | conv_loss_dict[f"clip_conv_loss_layer{layer}"] = conv_loss[layer] * w 236 | conv_loss_total = conv_loss_total + conv_loss_dict[f"clip_conv_loss_layer{layer}"] 237 | if self.clip_fc_loss_weight: 238 
| # fc distance is always cos 239 | fc_loss = (1 - torch.cosine_similarity(xs_fc_features, 240 | ys_fc_features, dim=1)).mean() 241 | conv_loss_dict["fc"] = fc_loss * self.clip_fc_loss_weight 242 | conv_loss_total += conv_loss_dict["fc"] 243 | self.counter += 1 244 | return conv_loss_total 245 | 246 | def forward_inspection_clip_resnet(self, x): 247 | def stem(m, x): 248 | x = m.relu1(m.bn1(m.conv1(x))) 249 | x = m.relu2(m.bn2(m.conv2(x))) 250 | x = m.relu3(m.bn3(m.conv3(x))) 251 | x = m.avgpool(x) 252 | return x 253 | 254 | x = x.type(self.visual_model.conv1.weight.dtype) 255 | x = stem(self.visual_model, x) 256 | x1 = self.layer1(x) 257 | x2 = self.layer2(x1) 258 | x3 = self.layer3(x2) 259 | x4 = self.layer4(x3) 260 | y = self.att_pool2d(x4) 261 | return y, [x, x1, x2, x3, x4] 262 | -------------------------------------------------------------------------------- /modules/transform/context.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.init import trunc_normal_ 5 | from einops import rearrange 6 | from modules.layers import MLP, build_position_index 7 | from modules.layers import conv, deconv 8 | from utils.ckbd import * 9 | 10 | 11 | class LocalContext(nn.Module): 12 | def __init__(self, 13 | dim=32, 14 | window_size=5, 15 | mlp_ratio=2., 16 | num_heads=2, 17 | qkv_bias=True, 18 | qk_scale=None 19 | ) -> None: 20 | super().__init__() 21 | self.H = -1 22 | self.W = -1 23 | self.dim = dim 24 | self.window_size = window_size 25 | self.num_heads = num_heads 26 | self.head_dim = dim // num_heads 27 | self.scale = qk_scale or self.head_dim ** -0.5 28 | self.qkv_proj = nn.Linear(dim, dim * 3, bias=qkv_bias) 29 | self.unfold = nn.Unfold(kernel_size=window_size, stride=1, padding=(window_size - 1) // 2) 30 | self.relative_position_table = nn.Parameter( 31 | torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads) 32 | ) 33 | trunc_normal_(self.relative_position_table, std=0.02) 34 | self.softmax = nn.Softmax(dim=-1) 35 | self.proj = nn.Linear(dim * 2, dim * 2) 36 | self.mlp = MLP(in_dim=dim * 2, hidden_dim=int(dim * 2 * mlp_ratio), out_dim=dim * 2) 37 | self.norm1 = nn.LayerNorm(dim) 38 | self.norm2 = nn.LayerNorm(dim * 2) 39 | self.register_buffer("relative_position_index", build_position_index((window_size, window_size))) 40 | self.attn_mask = None 41 | self.fusion = nn.Conv2d(dim, dim * 2, kernel_size=window_size) 42 | 43 | def update_resolution(self, H, W, device, mask=None): 44 | updated=False 45 | if self.H != H or self.W != W: 46 | self.H = H 47 | self.W = W 48 | if mask is not None: 49 | self.attn_mask = mask.to(device) 50 | updated=True 51 | return updated 52 | ckbd = torch.zeros((1, 2, H, W), requires_grad=False) 53 | # anchor 54 | ckbd[:, :, 0::2, 1::2] = 1 55 | ckbd[:, :, 1::2, 0::2] = 1 56 | qk_windows = self.unfold(ckbd).permute(0, 2, 1) 57 | qk_windows = qk_windows.view(1, H * W, 2, 1, self.window_size, self.window_size).permute(2, 0, 1, 3, 4, 5) 58 | q_windows, k_windows = qk_windows[0], qk_windows[1] 59 | q = q_windows.reshape(1, H * W, 1, self.window_size * self.window_size).permute(0, 1, 3, 2) 60 | k = k_windows.reshape(1, H * W, 1, self.window_size * self.window_size).permute(0, 1, 3, 2) 61 | attn_mask = (q @ k.transpose(-2, -1)) 62 | attn_mask = attn_mask.masked_fill(attn_mask == 0., float(-100.0)).masked_fill(attn_mask == 1, float(0.0)) 63 | self.attn_mask = attn_mask[0].to(device).detach() 64 | updated=True 65 | return updated 
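    # A minimal sketch of the checkerboard pattern built in update_resolution
    # above (1 = anchor, 0 = non-anchor), shown for a hypothetical 4x4 latent:
    #
    #     ckbd = torch.zeros(1, 2, 4, 4)
    #     ckbd[:, :, 0::2, 1::2] = 1
    #     ckbd[:, :, 1::2, 0::2] = 1
    #     # each channel of ckbd[0]:
    #     # [[0, 1, 0, 1],
    #     #  [1, 0, 1, 0],
    #     #  [0, 1, 0, 1],
    #     #  [1, 0, 1, 0]]
    #
    # The derived attn_mask keeps an attention score (adds 0) only where both the
    # query and the key position inside a window fall on anchors; every other pair
    # receives -100 and is effectively suppressed by the softmax in forward().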
66 | 67 | def forward(self, x): 68 | B, C, H, W = x.shape 69 | L = H * W 70 | self.update_resolution(H, W, x.device) 71 | # [B, L, C] 72 | x = x.reshape(B, C, L).permute(0, 2, 1) 73 | x = self.norm1(x) 74 | 75 | # [3, B, C, H, W] 76 | qkv = self.qkv_proj(x).reshape(B, H, W, 3, C).permute(3, 0, 4, 1, 2) 77 | 78 | # window partition 79 | q, k, v = qkv[0], qkv[1], qkv[2] 80 | qkv = torch.cat([q, k, v], dim=1) 81 | qkv_windows = self.unfold(qkv).permute(0, 2, 1) 82 | qkv_windows = qkv_windows.view(B, L, 3, C, self.window_size, self.window_size).permute(2, 0, 1, 3, 4, 5) 83 | # [B, L, C, window_size, window_size] 84 | q_windows, k_windows, v_windows = qkv_windows[0], qkv_windows[1], qkv_windows[2] 85 | 86 | # [B, L, num_heads, window_size * window_size, head_dim] 87 | q = q_windows.reshape(B, L, self.head_dim, self.num_heads, self.window_size * self.window_size).permute(0, 1, 3, 4, 2) 88 | k = k_windows.reshape(B, L, self.head_dim, self.num_heads, self.window_size * self.window_size).permute(0, 1, 3, 4, 2) 89 | v = v_windows.reshape(B, L, self.head_dim, self.num_heads, self.window_size * self.window_size).permute(0, 1, 3, 4, 2) 90 | 91 | q = q * self.scale 92 | # [B, L, num_heads, window_size * window_size, window_size * window_size] 93 | attn = (q @ k.transpose(-2, -1)) 94 | 95 | relative_position_bias = self.relative_position_table[self.relative_position_index.view(-1)].view( 96 | self.window_size * self.window_size, self.window_size * self.window_size, -1 97 | ) 98 | # [num_heads, window_size * window_size, window_size * window_size] 99 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() 100 | attn = attn + relative_position_bias.unsqueeze(0).unsqueeze(1) 101 | 102 | attn = attn + self.attn_mask.unsqueeze(0).unsqueeze(2) 103 | 104 | attn = self.softmax(attn) 105 | 106 | x = (attn @ v).reshape(B, L, self.num_heads, self.window_size, self.window_size, self.head_dim).permute(0, 1, 3, 4, 2, 5) 107 | x = x.reshape(B * L, self.window_size, self.window_size, C).permute(0, 3, 1, 2) 108 | x = self.fusion(x).reshape(B, L, C * 2) 109 | x = self.proj(x) 110 | x = x + self.mlp(self.norm2(x)) 111 | x = x.permute(0, 2, 1).reshape(B, C * 2, H, W) 112 | return x 113 | 114 | 115 | class ChannelContext(nn.Module): 116 | def __init__(self, in_dim, out_dim) -> None: 117 | super().__init__() 118 | self.fushion = nn.Sequential( 119 | nn.Conv2d(in_dim, 192, kernel_size=3, stride=1, padding=1), 120 | nn.GELU(), 121 | nn.Conv2d(192, 128, kernel_size=3, stride=1, padding=1), 122 | nn.GELU(), 123 | nn.Conv2d(128, out_dim * 4, kernel_size=3, stride=1, padding=1) 124 | ) 125 | 126 | def forward(self, channel_params): 127 | """ 128 | Args: 129 | channel_params(Tensor): [B, C * K, H, W] 130 | return: 131 | channel_params(Tensor): [B, C * 2, H, W] 132 | """ 133 | channel_params = self.fushion(channel_params) 134 | 135 | return channel_params 136 | 137 | 138 | class ChannelContextEX(nn.Module): 139 | def __init__(self, in_dim, out_dim, act=nn.GELU) -> None: 140 | super().__init__() 141 | self.fushion = nn.Sequential( 142 | nn.Conv2d(in_dim, 224, kernel_size=3, stride=1, padding=1), 143 | act(), 144 | nn.Conv2d(224, 128, kernel_size=3, stride=1, padding=1), 145 | act(), 146 | nn.Conv2d(128, out_dim, kernel_size=3, stride=1, padding=1) 147 | ) 148 | 149 | def forward(self, channel_params): 150 | """ 151 | Args: 152 | channel_params(Tensor): [B, C * K, H, W] 153 | return: 154 | channel_params(Tensor): [B, C * 2, H, W] 155 | """ 156 | channel_params = self.fushion(channel_params) 157 | 158 | return 
channel_params 159 | 160 | class LinearGlobalIntraContext(nn.Module): 161 | def __init__( 162 | self, 163 | dim=32, 164 | num_heads=2) -> None: 165 | super().__init__() 166 | self.dim = dim 167 | self.num_heads = num_heads 168 | self.keys = nn.Sequential( 169 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0), 170 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) 171 | ) 172 | self.queries = nn.Sequential( 173 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0), 174 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) 175 | ) 176 | self.values = nn.Sequential( 177 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0), 178 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) 179 | ) 180 | self.reprojection = nn.Conv2d(dim, dim * 2, kernel_size=5, stride=1, padding=2) 181 | self.mlp = nn.Sequential( 182 | nn.Conv2d(dim * 2, dim * 4, kernel_size=1, stride=1), 183 | nn.GELU(), 184 | nn.Conv2d(dim * 4, dim * 4, kernel_size=3, stride=1, padding=1, groups=dim * 4), 185 | nn.GELU(), 186 | nn.Conv2d(dim * 4, dim * 2, kernel_size=1, stride=1) 187 | ) 188 | 189 | def forward(self, x1, x2): 190 | B, C, H, W = x1.shape 191 | x1_ac = ckbd_anchor(x1) 192 | x1_na = ckbd_nonanchor(x1) 193 | queries = ckbd_nonanchor_sequeeze(self.queries(x1_na)).reshape(B, self.dim, H * W//2) 194 | keys = ckbd_anchor_sequeeze(self.keys(x1_ac)).reshape(B, self.dim, H * W//2) 195 | values = ckbd_anchor_sequeeze(self.values(x2)).reshape(B, self.dim, H * W//2) 196 | head_dim = self.dim // self.num_heads 197 | 198 | attended_values = [] 199 | for i in range(self.num_heads): 200 | key = F.softmax(keys[:, i * head_dim: (i + 1) * head_dim, :], dim=2) 201 | query = F.softmax(queries[:, i * head_dim: (i + 1) * head_dim, :], dim=1) 202 | value = values[:, i * head_dim: (i + 1) * head_dim, :] 203 | key = ckbd_anchor_unsequeeze(key.reshape(B, head_dim, H, W //2)).reshape(B, head_dim, H * W) 204 | value = ckbd_anchor_unsequeeze(value.reshape(B, head_dim, H, W //2)).reshape(B, head_dim, H * W) 205 | query = ckbd_nonanchor_unsequeeze(query.reshape(B, head_dim, H, W //2)).reshape(B, head_dim, H * W) 206 | context = key @ value.transpose(1, 2) 207 | attended_value = (context.transpose(1, 2) @ query).reshape(B, head_dim, H, W) 208 | attended_values.append(attended_value) 209 | 210 | aggregated_values = torch.cat(attended_values, dim=1) 211 | attention = self.reprojection(aggregated_values) 212 | 213 | return attention + self.mlp(attention) 214 | 215 | class LinearGlobalInterContext(nn.Module): 216 | def __init__( 217 | self, 218 | dim=32, 219 | out_dim=64, 220 | num_heads=2) -> None: 221 | super().__init__() 222 | self.dim = dim 223 | self.num_heads = num_heads 224 | self.keys = nn.Sequential( 225 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0), 226 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) 227 | ) 228 | self.queries = nn.Sequential( 229 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0), 230 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) 231 | ) 232 | self.values = nn.Sequential( 233 | nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0), 234 | nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) 235 | ) 236 | self.reprojection = nn.Conv2d(dim, out_dim * 3 // 2, kernel_size=5, stride=1, padding=2) 237 | self.mlp = nn.Sequential( 238 | nn.Conv2d(out_dim * 3 // 2, out_dim * 2, kernel_size=1, stride=1), 239 | nn.GELU(), 240 | nn.Conv2d(out_dim * 2, out_dim * 2, kernel_size=3, stride=1, 
padding=1, groups=out_dim * 2), 241 | nn.GELU(), 242 | nn.Conv2d(out_dim * 2, out_dim, kernel_size=1, stride=1) 243 | ) 244 | self.skip = nn.Conv2d(out_dim * 3 // 2, out_dim, kernel_size=1, stride=1, padding=0) 245 | 246 | def forward(self, x1): 247 | B, C, H, W = x1.shape 248 | queries = self.queries(x1).reshape(B, self.dim, H * W) 249 | keys = self.keys(x1).reshape(B, self.dim, H * W) 250 | values = self.values(x1).reshape(B, self.dim, H * W) 251 | head_dim = self.dim // self.num_heads 252 | 253 | attended_values = [] 254 | for i in range(self.num_heads): 255 | key = F.softmax(keys[:, i * head_dim: (i + 1) * head_dim, :], dim=2) 256 | query = F.softmax(queries[:, i * head_dim: (i + 1) * head_dim, :], dim=1) 257 | value = values[:, i * head_dim: (i + 1) * head_dim, :] 258 | context = key @ value.transpose(1, 2) 259 | attended_value = (context.transpose(1, 2) @ query).reshape(B, head_dim, H, W) 260 | attended_values.append(attended_value) 261 | 262 | aggregated_values = torch.cat(attended_values, dim=1) 263 | attention = self.reprojection(aggregated_values) 264 | 265 | return self.skip(attention) + self.mlp(attention) 266 | -------------------------------------------------------------------------------- /utils/testing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from utils.metrics import compute_metrics, compute_metrics2 6 | from utils.utils import * 7 | 8 | 9 | def test_one_epoch(epoch, test_dataloader, model, criterion, save_dir, logger_val, tb_logger): 10 | model.eval() 11 | device = next(model.parameters()).device 12 | 13 | loss = AverageMeter() 14 | bpp_loss = AverageMeter() 15 | mse_loss = AverageMeter() 16 | ms_ssim_loss = AverageMeter() 17 | aux_loss = AverageMeter() 18 | psnr = AverageMeter() 19 | ms_ssim = AverageMeter() 20 | 21 | with torch.no_grad(): 22 | for i, d in enumerate(test_dataloader): 23 | d = d.to(device) 24 | out_net = model(d) 25 | out_criterion = criterion(out_net, d) 26 | 27 | aux_loss.update(model.aux_loss()) 28 | bpp_loss.update(out_criterion["bpp_loss"]) 29 | loss.update(out_criterion["loss"]) 30 | if out_criterion["mse_loss"] is not None: 31 | mse_loss.update(out_criterion["mse_loss"]) 32 | if out_criterion["ms_ssim_loss"] is not None: 33 | ms_ssim_loss.update(out_criterion["ms_ssim_loss"]) 34 | 35 | rec = torch2img(out_net['x_hat'][0]) 36 | img = torch2img(d[0]) 37 | p, m = compute_metrics(rec, img) 38 | psnr.update(p) 39 | ms_ssim.update(m) 40 | 41 | # if not os.path.exists(save_dir): 42 | # os.makedirs(save_dir) 43 | # rec.save(os.path.join(save_dir, '%03d_rec.png' % i)) 44 | # img.save(os.path.join(save_dir, '%03d_gt.png' % i)) 45 | 46 | tb_logger.add_scalar('{}'.format('[val]: loss'), loss.avg, epoch + 1) 47 | tb_logger.add_scalar('{}'.format('[val]: bpp_loss'), bpp_loss.avg, epoch + 1) 48 | tb_logger.add_scalar('{}'.format('[val]: psnr'), psnr.avg, epoch + 1) 49 | tb_logger.add_scalar('{}'.format('[val]: ms-ssim'), ms_ssim.avg, epoch + 1) 50 | 51 | if out_criterion["mse_loss"] is not None: 52 | logger_val.info( 53 | f"Test epoch {epoch}: Average losses: " 54 | f"Loss: {loss.avg:.4f} | " 55 | f"MSE loss: {mse_loss.avg:.6f} | " 56 | f"Bpp loss: {bpp_loss.avg:.4f} | " 57 | f"Aux loss: {aux_loss.avg:.2f} | " 58 | f"PSNR: {psnr.avg:.6f} | " 59 | f"MS-SSIM: {ms_ssim.avg:.6f}" 60 | ) 61 | tb_logger.add_scalar('{}'.format('[val]: mse_loss'), mse_loss.avg, epoch + 1) 62 | if out_criterion["ms_ssim_loss"] is not None: 63 | 
logger_val.info( 64 | f"Test epoch {epoch}: Average losses: " 65 | f"Loss: {loss.avg:.4f} | " 66 | f"MS-SSIM loss: {ms_ssim_loss.avg:.6f} | " 67 | f"Bpp loss: {bpp_loss.avg:.4f} | " 68 | f"Aux loss: {aux_loss.avg:.2f} | " 69 | f"PSNR: {psnr.avg:.6f} | " 70 | f"MS-SSIM: {ms_ssim.avg:.6f}" 71 | ) 72 | tb_logger.add_scalar('{}'.format('[val]: ms_ssim_loss'), ms_ssim_loss.avg, epoch + 1) 73 | 74 | return loss.avg 75 | 76 | def test_clip_one_epoch(epoch, test_dataloader, model, criterion, save_dir, logger_val, tb_logger, clip_loss, lambda_clip): 77 | model.eval() 78 | device = next(model.parameters()).device 79 | 80 | loss = AverageMeter() 81 | bpp_loss = AverageMeter() 82 | mse_loss = AverageMeter() 83 | ms_ssim_loss = AverageMeter() 84 | aux_loss = AverageMeter() 85 | psnr = AverageMeter() 86 | ms_ssim = AverageMeter() 87 | cliploss = AverageMeter() 88 | 89 | with torch.no_grad(): 90 | for i, d in enumerate(test_dataloader): 91 | d = d.to(device) 92 | out_net = model(d) 93 | out_criterion = criterion(out_net, d) 94 | closs = clip_loss(out_net["x_hat"], d, "eval") 95 | aux_loss.update(model.aux_loss()) 96 | bpp_loss.update(out_criterion["bpp_loss"]) 97 | cliploss.update(closs) 98 | loss.update(out_criterion["loss"]+(lambda_clip * closs)) 99 | if out_criterion["mse_loss"] is not None: 100 | mse_loss.update(out_criterion["mse_loss"]) 101 | if out_criterion["ms_ssim_loss"] is not None: 102 | ms_ssim_loss.update(out_criterion["ms_ssim_loss"]) 103 | 104 | 105 | 106 | rec = torch2img(out_net['x_hat']) 107 | img = torch2img(d) 108 | p, m = compute_metrics(rec, img) 109 | psnr.update(p) 110 | ms_ssim.update(m) 111 | 112 | 113 | if not os.path.exists(save_dir): 114 | os.makedirs(save_dir) 115 | rec.save(os.path.join(save_dir, '%03d_rec.png' % i)) 116 | img.save(os.path.join(save_dir, '%03d_gt.png' % i)) 117 | 118 | tb_logger.add_scalar('{}'.format('[val]: loss'), loss.avg, epoch + 1) 119 | tb_logger.add_scalar('{}'.format('[val]: bpp_loss'), bpp_loss.avg, epoch + 1) 120 | tb_logger.add_scalar('{}'.format('[val]: cliploss'), cliploss.avg, epoch + 1) 121 | tb_logger.add_scalar('{}'.format('[val]: psnr'), psnr.avg, epoch + 1) 122 | tb_logger.add_scalar('{}'.format('[val]: ms-ssim'), ms_ssim.avg, epoch + 1) 123 | 124 | 125 | if out_criterion["mse_loss"] is not None: 126 | logger_val.info( 127 | f"Test epoch {epoch}: Average losses: " 128 | f"Loss: {loss.avg:.4f} | " 129 | f"MSE loss: {mse_loss.avg:.6f} | " 130 | f"CLIP loss: {cliploss.avg:.6f} | " 131 | f"Bpp loss: {bpp_loss.avg:.4f} | " 132 | f"Aux loss: {aux_loss.avg:.2f} | " 133 | f"PSNR: {psnr.avg:.6f} | " 134 | f"MS-SSIM: {ms_ssim.avg:.6f}" 135 | ) 136 | tb_logger.add_scalar('{}'.format('[val]: mse_loss'), mse_loss.avg, epoch + 1) 137 | if out_criterion["ms_ssim_loss"] is not None: 138 | logger_val.info( 139 | f"Test epoch {epoch}: Average losses: " 140 | f"Loss: {loss.avg:.4f} | " 141 | f"MS-SSIM loss: {ms_ssim_loss.avg:.6f} | " 142 | f"CLIP loss: {cliploss.avg:.6f} | " 143 | f"Bpp loss: {bpp_loss.avg:.4f} | " 144 | f"Aux loss: {aux_loss.avg:.2f} | " 145 | f"PSNR: {psnr.avg:.6f} | " 146 | f"MS-SSIM: {ms_ssim.avg:.6f}" 147 | ) 148 | tb_logger.add_scalar('{}'.format('[val]: ms_ssim_loss'), ms_ssim_loss.avg, epoch + 1) 149 | 150 | return loss.avg 151 | 152 | def test_conditional_one_epoch(epoch, test_dataloader, model, criterion, save_dir, logger_val, tb_logger, clip_loss, lambda_clip): 153 | model.eval() 154 | device = next(model.parameters()).device 155 | 156 | loss = AverageMeter() 157 | bpp_loss = AverageMeter() 158 | mse_loss = AverageMeter() 159 | 
ms_ssim_loss = AverageMeter() 160 | aux_loss = AverageMeter() 161 | psnr = AverageMeter() 162 | ms_ssim = AverageMeter() 163 | cliploss = AverageMeter() 164 | 165 | with torch.no_grad(): 166 | for i, d in enumerate(test_dataloader): 167 | bs = d.size(0) 168 | betas = torch.randint(0, 2, (bs,)) 169 | 170 | 171 | 172 | d = d.to(device) 173 | betas = betas.to(device) 174 | 175 | out_net = model((d, betas)) 176 | out_criterion = criterion(out_net, d) 177 | closs = clip_loss(out_net["x_hat"], d, "eval") 178 | aux_loss.update(model.aux_loss()) 179 | bpp_loss.update(out_criterion["bpp_loss"]) 180 | cliploss.update(closs) 181 | loss.update(out_criterion["loss"]+(lambda_clip * closs)) 182 | if out_criterion["mse_loss"] is not None: 183 | mse_loss.update(out_criterion["mse_loss"]) 184 | if out_criterion["ms_ssim_loss"] is not None: 185 | ms_ssim_loss.update(out_criterion["ms_ssim_loss"]) 186 | 187 | 188 | 189 | rec = torch2img(out_net['x_hat'][0]) 190 | img = torch2img(d[0]) 191 | p, m = compute_metrics(rec, img) 192 | psnr.update(p) 193 | ms_ssim.update(m) 194 | 195 | 196 | # if not os.path.exists(save_dir): 197 | # os.makedirs(save_dir) 198 | # rec.save(os.path.join(save_dir, '%03d_rec.png' % i)) 199 | # img.save(os.path.join(save_dir, '%03d_gt.png' % i)) 200 | 201 | tb_logger.add_scalar('{}'.format('[val]: loss'), loss.avg, epoch + 1) 202 | tb_logger.add_scalar('{}'.format('[val]: bpp_loss'), bpp_loss.avg, epoch + 1) 203 | tb_logger.add_scalar('{}'.format('[val]: cliploss'), cliploss.avg, epoch + 1) 204 | tb_logger.add_scalar('{}'.format('[val]: psnr'), psnr.avg, epoch + 1) 205 | tb_logger.add_scalar('{}'.format('[val]: ms-ssim'), ms_ssim.avg, epoch + 1) 206 | 207 | 208 | if out_criterion["mse_loss"] is not None: 209 | logger_val.info( 210 | f"Test epoch {epoch}: Average losses: " 211 | f"Loss: {loss.avg:.4f} | " 212 | f"MSE loss: {mse_loss.avg:.6f} | " 213 | f"CLIP loss: {cliploss.avg:.6f} | " 214 | f"Bpp loss: {bpp_loss.avg:.4f} | " 215 | f"Aux loss: {aux_loss.avg:.2f} | " 216 | f"PSNR: {psnr.avg:.6f} | " 217 | f"MS-SSIM: {ms_ssim.avg:.6f}" 218 | ) 219 | tb_logger.add_scalar('{}'.format('[val]: mse_loss'), mse_loss.avg, epoch + 1) 220 | if out_criterion["ms_ssim_loss"] is not None: 221 | logger_val.info( 222 | f"Test epoch {epoch}: Average losses: " 223 | f"Loss: {loss.avg:.4f} | " 224 | f"MS-SSIM loss: {ms_ssim_loss.avg:.6f} | " 225 | f"CLIP loss: {cliploss.avg:.6f} | " 226 | f"Bpp loss: {bpp_loss.avg:.4f} | " 227 | f"Aux loss: {aux_loss.avg:.2f} | " 228 | f"PSNR: {psnr.avg:.6f} | " 229 | f"MS-SSIM: {ms_ssim.avg:.6f}" 230 | ) 231 | tb_logger.add_scalar('{}'.format('[val]: ms_ssim_loss'), ms_ssim_loss.avg, epoch + 1) 232 | 233 | return loss.avg 234 | 235 | def newtest_conditional_one_epoch(epoch, test_dataloader, model, criterion, save_dir, logger_val, tb_logger, clip_loss, lambda_clip): 236 | model.eval() 237 | device = next(model.parameters()).device 238 | 239 | loss = AverageMeter() 240 | bpp_loss = AverageMeter() 241 | mse_loss = AverageMeter() 242 | ms_ssim_loss = AverageMeter() 243 | aux_loss = AverageMeter() 244 | psnr = AverageMeter() 245 | ms_ssim = AverageMeter() 246 | cliploss = AverageMeter() 247 | 248 | with torch.no_grad(): 249 | for i, d in enumerate(test_dataloader): 250 | bs = d.size(0) 251 | betas = torch.randint(0, 2, (bs,)) 252 | 253 | 254 | 255 | d = d.to(device) 256 | betas = betas.to(device) 257 | 258 | out_net = model((d, betas)) 259 | out_criterion = criterion(out_net, d) 260 | clip_out = clip_loss(out_net["x_hat"], d, "eval") 261 | closs = clip_out["fc"] + 
clip_out["clip_conv_all"] 262 | closs = closs.mean() 263 | aux_loss.update(model.aux_loss()) 264 | bpp_loss.update(out_criterion["bpp_loss"]) 265 | cliploss.update(closs) 266 | loss.update(out_criterion["loss"]+(lambda_clip * closs)) 267 | if out_criterion["mse_loss"] is not None: 268 | mse_loss.update(out_criterion["mse_loss"]) 269 | if out_criterion["ms_ssim_loss"] is not None: 270 | ms_ssim_loss.update(out_criterion["ms_ssim_loss"]) 271 | 272 | 273 | 274 | rec = torch2img(out_net['x_hat'][0]) 275 | img = torch2img(d[0]) 276 | p, m = compute_metrics(rec, img) 277 | psnr.update(p) 278 | ms_ssim.update(m) 279 | 280 | 281 | # if not os.path.exists(save_dir): 282 | # os.makedirs(save_dir) 283 | # rec.save(os.path.join(save_dir, '%03d_rec.png' % i)) 284 | # img.save(os.path.join(save_dir, '%03d_gt.png' % i)) 285 | 286 | tb_logger.add_scalar('{}'.format('[val]: loss'), loss.avg, epoch + 1) 287 | tb_logger.add_scalar('{}'.format('[val]: bpp_loss'), bpp_loss.avg, epoch + 1) 288 | tb_logger.add_scalar('{}'.format('[val]: cliploss'), cliploss.avg, epoch + 1) 289 | tb_logger.add_scalar('{}'.format('[val]: psnr'), psnr.avg, epoch + 1) 290 | tb_logger.add_scalar('{}'.format('[val]: ms-ssim'), ms_ssim.avg, epoch + 1) 291 | 292 | 293 | if out_criterion["mse_loss"] is not None: 294 | logger_val.info( 295 | f"Test epoch {epoch}: Average losses: " 296 | f"Loss: {loss.avg:.4f} | " 297 | f"MSE loss: {mse_loss.avg:.6f} | " 298 | f"CLIP loss: {cliploss.avg:.6f} | " 299 | f"Bpp loss: {bpp_loss.avg:.4f} | " 300 | f"Aux loss: {aux_loss.avg:.2f} | " 301 | f"PSNR: {psnr.avg:.6f} | " 302 | f"MS-SSIM: {ms_ssim.avg:.6f}" 303 | ) 304 | tb_logger.add_scalar('{}'.format('[val]: mse_loss'), mse_loss.avg, epoch + 1) 305 | if out_criterion["ms_ssim_loss"] is not None: 306 | logger_val.info( 307 | f"Test epoch {epoch}: Average losses: " 308 | f"Loss: {loss.avg:.4f} | " 309 | f"MS-SSIM loss: {ms_ssim_loss.avg:.6f} | " 310 | f"CLIP loss: {cliploss.avg:.6f} | " 311 | f"Bpp loss: {bpp_loss.avg:.4f} | " 312 | f"Aux loss: {aux_loss.avg:.2f} | " 313 | f"PSNR: {psnr.avg:.6f} | " 314 | f"MS-SSIM: {ms_ssim.avg:.6f}" 315 | ) 316 | tb_logger.add_scalar('{}'.format('[val]: ms_ssim_loss'), ms_ssim_loss.avg, epoch + 1) 317 | 318 | return loss.avg 319 | 320 | def compress_one_image(model, x, stream_path, H, W, img_name): 321 | with torch.no_grad(): 322 | out = model.compress(x) 323 | 324 | shape = out["shape"] 325 | output = os.path.join(stream_path, img_name) 326 | with Path(output).open("wb") as f: 327 | write_uints(f, (H, W)) 328 | write_body(f, shape, out["strings"]) 329 | 330 | size = filesize(output) 331 | bpp = float(size) * 8 / (H * W) 332 | return bpp, out["cost_time"] 333 | 334 | 335 | def decompress_one_image(model, stream_path, img_name): 336 | output = os.path.join(stream_path, img_name) 337 | with Path(output).open("rb") as f: 338 | original_size = read_uints(f, 2) 339 | strings, shape = read_body(f) 340 | 341 | with torch.no_grad(): 342 | out = model.decompress(strings, shape) 343 | 344 | x_hat = out["x_hat"] 345 | x_hat = x_hat[:, :, 0 : original_size[0], 0 : original_size[1]] 346 | cost_time = out["cost_time"] 347 | return x_hat, cost_time 348 | 349 | 350 | 351 | def test_model(test_dataloader, net, logger_test, save_dir, epoch): 352 | net.eval() 353 | device = next(net.parameters()).device 354 | 355 | avg_psnr = AverageMeter() 356 | avg_ms_ssim = AverageMeter() 357 | avg_bpp = AverageMeter() 358 | avg_enc_time = AverageMeter() 359 | avg_dec_time = AverageMeter() 360 | 361 | with torch.no_grad(): 362 | for i, data in 
enumerate(test_dataloader): 363 | img, names = data 364 | img = img.to(device) 365 | B, C, H, W = img.shape 366 | if H >1000 or W>1000: 367 | img = F.interpolate(img, size=512, mode='bilinear', align_corners=False) 368 | B, C, H, W = img.shape 369 | pad_h = 0 370 | pad_w = 0 371 | if H % 64 != 0: 372 | pad_h = 64 * (H // 64 + 1) - H 373 | if W % 64 != 0: 374 | pad_w = 64 * (W // 64 + 1) - W 375 | img_pad = F.pad(img, (0, pad_w, 0, pad_h), mode='constant', value=0) 376 | # warmup GPU 377 | if i == 0: 378 | bpp, enc_time = compress_one_image(model=net, x=img_pad, stream_path=save_dir, H=H, W=W, img_name=str(i)) 379 | # avoid resolution leakage 380 | net.update_resolutions(16, 16) 381 | bpp, enc_time = compress_one_image(model=net, x=img_pad, stream_path=save_dir, H=H, W=W, img_name=str(i)) 382 | # avoid resolution leakage 383 | net.update_resolutions(16, 16) 384 | x_hat, dec_time = decompress_one_image(model=net, stream_path=save_dir, img_name=str(i)) 385 | rec = torch2img(x_hat) 386 | img = torch2img(img) 387 | # img.save(os.path.join(save_dir, '%03d_gt.png' % i)) 388 | rec.save(os.path.join(save_dir, names[0]) + ".png") 389 | p, m = compute_metrics2(rec, img) 390 | avg_psnr.update(p) 391 | avg_ms_ssim.update(m) 392 | avg_bpp.update(bpp) 393 | avg_enc_time.update(enc_time) 394 | avg_dec_time.update(dec_time) 395 | logger_test.info( 396 | f"Image[{i}] | " 397 | f"Bpp loss: {bpp:.2f} | " 398 | f"PSNR: {p:.4f} | " 399 | f"MS-SSIM: {m:.4f} | " 400 | f"Encoding Latency: {enc_time:.4f} | " 401 | f"Decoding Latency: {dec_time:.4f}" 402 | ) 403 | logger_test.info( 404 | f"Epoch:[{epoch}] | " 405 | f"Avg Bpp: {avg_bpp.avg:.4f} | " 406 | f"Avg PSNR: {avg_psnr.avg:.4f} | " 407 | f"Avg MS-SSIM: {avg_ms_ssim.avg:.4f} | " 408 | f"Avg Encoding Latency:: {avg_enc_time.avg:.4f} | " 409 | f"Avg decoding Latency:: {avg_dec_time.avg:.4f}" 410 | ) 411 | 412 | 413 | -------------------------------------------------------------------------------- /clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x[:1], key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | return x.squeeze(0) 92 | 93 | 94 | class ModifiedResNet(nn.Module): 95 | """ 96 | A ResNet class that is similar to torchvision's but contains the following changes: 97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 99 | - The final pooling layer is a QKV attention instead of an average pool 100 | """ 101 | 102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 103 | super().__init__() 104 | self.output_dim = output_dim 105 | self.input_resolution = input_resolution 106 | 107 | # the 3-layer stem 108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(width // 2) 110 | self.relu1 = nn.ReLU(inplace=True) 111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(width // 2) 113 | self.relu2 = nn.ReLU(inplace=True) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.relu3 = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(2) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | def stem(x): 140 | x = self.relu1(self.bn1(self.conv1(x))) 141 | x = self.relu2(self.bn2(self.conv2(x))) 142 | x = self.relu3(self.bn3(self.conv3(x))) 143 | x = self.avgpool(x) 144 | return x 145 | 146 | x = x.type(self.conv1.weight.dtype) 147 | x = stem(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.attnpool(x) 153 | 154 | return x 155 | 156 | 157 | class LayerNorm(nn.LayerNorm): 158 | """Subclass torch's LayerNorm to handle fp16.""" 159 | 160 | def forward(self, x: torch.Tensor): 161 | orig_type = x.dtype 162 | ret = super().forward(x.type(torch.float32)) 163 | return ret.type(orig_type) 164 | 165 | 166 | class QuickGELU(nn.Module): 167 | def forward(self, x: torch.Tensor): 168 | return x * torch.sigmoid(1.702 * x) 169 | 170 | 171 | class ResidualAttentionBlock(nn.Module): 172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 173 | super().__init__() 174 | 175 | self.attn = nn.MultiheadAttention(d_model, n_head) 176 | self.ln_1 = LayerNorm(d_model) 177 | self.mlp = nn.Sequential(OrderedDict([ 178 | ("c_fc", nn.Linear(d_model, d_model * 4)), 179 | ("gelu", QuickGELU()), 180 | ("c_proj", nn.Linear(d_model * 4, d_model)) 181 | ])) 182 | self.ln_2 = LayerNorm(d_model) 183 | self.attn_mask = attn_mask 184 | 185 | def attention(self, x: torch.Tensor): 186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 188 | 189 | def forward(self, x: torch.Tensor): 190 | x = x + self.attention(self.ln_1(x)) 191 | x = x + 
self.mlp(self.ln_2(x)) 192 | return x 193 | 194 | 195 | class Transformer(nn.Module): 196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 197 | super().__init__() 198 | self.width = width 199 | self.layers = layers 200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 201 | 202 | def forward(self, x: torch.Tensor): 203 | return self.resblocks(x) 204 | 205 | 206 | class VisionTransformer(nn.Module): 207 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 208 | super().__init__() 209 | self.input_resolution = input_resolution 210 | self.output_dim = output_dim 211 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 212 | 213 | scale = width ** -0.5 214 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 215 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 216 | self.ln_pre = LayerNorm(width) 217 | 218 | self.transformer = Transformer(width, layers, heads) 219 | 220 | self.ln_post = LayerNorm(width) 221 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 222 | 223 | def forward(self, x: torch.Tensor): 224 | x = self.conv1(x) # shape = [*, width, grid, grid] 225 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 226 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 227 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 228 | x = x + self.positional_embedding.to(x.dtype) 229 | x = self.ln_pre(x) 230 | 231 | x = x.permute(1, 0, 2) # NLD -> LND 232 | x = self.transformer(x) 233 | x = x.permute(1, 0, 2) # LND -> NLD 234 | 235 | x = self.ln_post(x[:, 0, :]) 236 | 237 | if self.proj is not None: 238 | x = x @ self.proj 239 | 240 | return x 241 | 242 | 243 | class CLIP(nn.Module): 244 | def __init__(self, 245 | embed_dim: int, 246 | # vision 247 | image_resolution: int, 248 | vision_layers: Union[Tuple[int, int, int, int], int], 249 | vision_width: int, 250 | vision_patch_size: int, 251 | # text 252 | context_length: int, 253 | vocab_size: int, 254 | transformer_width: int, 255 | transformer_heads: int, 256 | transformer_layers: int 257 | ): 258 | super().__init__() 259 | 260 | self.context_length = context_length 261 | 262 | if isinstance(vision_layers, (tuple, list)): 263 | vision_heads = vision_width * 32 // 64 264 | self.visual = ModifiedResNet( 265 | layers=vision_layers, 266 | output_dim=embed_dim, 267 | heads=vision_heads, 268 | input_resolution=image_resolution, 269 | width=vision_width 270 | ) 271 | else: 272 | vision_heads = vision_width // 64 273 | self.visual = VisionTransformer( 274 | input_resolution=image_resolution, 275 | patch_size=vision_patch_size, 276 | width=vision_width, 277 | layers=vision_layers, 278 | heads=vision_heads, 279 | output_dim=embed_dim 280 | ) 281 | 282 | self.transformer = Transformer( 283 | width=transformer_width, 284 | layers=transformer_layers, 285 | heads=transformer_heads, 286 | attn_mask=self.build_attention_mask() 287 | ) 288 | 289 | self.vocab_size = vocab_size 290 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 291 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 292 | self.ln_final = LayerNorm(transformer_width) 293 | 
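        # text_projection maps the final token representation into the shared
        # image-text embedding space; logit_scale is a learnable log inverse
        # temperature, initialised to log(1/0.07) and applied as exp(logit_scale)
        # to scale the cosine-similarity logits in forward().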
294 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 295 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 296 | 297 | self.initialize_parameters() 298 | 299 | def initialize_parameters(self): 300 | nn.init.normal_(self.token_embedding.weight, std=0.02) 301 | nn.init.normal_(self.positional_embedding, std=0.01) 302 | 303 | if isinstance(self.visual, ModifiedResNet): 304 | if self.visual.attnpool is not None: 305 | std = self.visual.attnpool.c_proj.in_features ** -0.5 306 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 307 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 308 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 309 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 310 | 311 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 312 | for name, param in resnet_block.named_parameters(): 313 | if name.endswith("bn3.weight"): 314 | nn.init.zeros_(param) 315 | 316 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 317 | attn_std = self.transformer.width ** -0.5 318 | fc_std = (2 * self.transformer.width) ** -0.5 319 | for block in self.transformer.resblocks: 320 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 321 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 322 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 323 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 324 | 325 | if self.text_projection is not None: 326 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 327 | 328 | def build_attention_mask(self): 329 | # lazily create causal attention mask, with full attention between the vision tokens 330 | # pytorch uses additive attention mask; fill with -inf 331 | mask = torch.empty(self.context_length, self.context_length) 332 | mask.fill_(float("-inf")) 333 | mask.triu_(1) # zero out the lower diagonal 334 | return mask 335 | 336 | @property 337 | def dtype(self): 338 | return self.visual.conv1.weight.dtype 339 | 340 | def encode_image(self, image): 341 | return self.visual(image.type(self.dtype)) 342 | 343 | def encode_text(self, text): 344 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 345 | 346 | x = x + self.positional_embedding.type(self.dtype) 347 | x = x.permute(1, 0, 2) # NLD -> LND 348 | x = self.transformer(x) 349 | x = x.permute(1, 0, 2) # LND -> NLD 350 | x = self.ln_final(x).type(self.dtype) 351 | 352 | # x.shape = [batch_size, n_ctx, transformer.width] 353 | # take features from the eot embedding (eot_token is the highest number in each sequence) 354 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 355 | 356 | return x 357 | 358 | def forward(self, image, text): 359 | image_features = self.encode_image(image) 360 | text_features = self.encode_text(text) 361 | 362 | # normalized features 363 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 364 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 365 | 366 | # cosine similarity as logits 367 | logit_scale = self.logit_scale.exp() 368 | logits_per_image = logit_scale * image_features @ text_features.t() 369 | logits_per_text = logits_per_image.t() 370 | 371 | # shape = [global_batch_size, global_batch_size] 372 | return logits_per_image, logits_per_text 373 | 374 | 375 | def convert_weights(model: nn.Module): 376 | """Convert applicable model parameters to 
fp16""" 377 | 378 | def _convert_weights_to_fp16(l): 379 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 380 | l.weight.data = l.weight.data.half() 381 | if l.bias is not None: 382 | l.bias.data = l.bias.data.half() 383 | 384 | if isinstance(l, nn.MultiheadAttention): 385 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 386 | tensor = getattr(l, attr) 387 | if tensor is not None: 388 | tensor.data = tensor.data.half() 389 | 390 | for name in ["text_projection", "proj"]: 391 | if hasattr(l, name): 392 | attr = getattr(l, name) 393 | if attr is not None: 394 | attr.data = attr.data.half() 395 | 396 | model.apply(_convert_weights_to_fp16) 397 | 398 | 399 | def build_model(state_dict: dict): 400 | vit = "visual.proj" in state_dict 401 | 402 | if vit: 403 | vision_width = state_dict["visual.conv1.weight"].shape[0] 404 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 405 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 406 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 407 | image_resolution = vision_patch_size * grid_size 408 | else: 409 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 410 | vision_layers = tuple(counts) 411 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 412 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 413 | vision_patch_size = None 414 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 415 | image_resolution = output_width * 32 416 | 417 | embed_dim = state_dict["text_projection"].shape[1] 418 | context_length = state_dict["positional_embedding"].shape[0] 419 | vocab_size = state_dict["token_embedding.weight"].shape[0] 420 | transformer_width = state_dict["ln_final.weight"].shape[0] 421 | transformer_heads = transformer_width // 64 422 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 423 | 424 | model = CLIP( 425 | embed_dim, 426 | image_resolution, vision_layers, vision_width, vision_patch_size, 427 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 428 | ) 429 | 430 | for key in ["input_resolution", "context_length", "vocab_size"]: 431 | if key in state_dict: 432 | del state_dict[key] 433 | 434 | convert_weights(model) 435 | model.load_state_dict(state_dict) 436 | return model.eval() 437 | -------------------------------------------------------------------------------- /models/mlicpp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import time 5 | from compressai.models import CompressionModel 6 | from compressai.ops import ste_round 7 | from compressai.ans import BufferedRansEncoder, RansDecoder 8 | from utils.func import update_registered_buffers, get_scale_table 9 | from utils.ckbd import * 10 | from modules.transform import * 11 | 12 | 13 | class MLICPlusPlus(CompressionModel): 14 | def __init__(self, config, **kwargs): 15 | super().__init__(config.N, **kwargs) 16 | N = config.N 17 | M = config.M 18 | context_window = config.context_window 19 | slice_num = config.slice_num 20 | slice_ch = M // slice_num 21 | assert slice_ch * slice_num == M 22 | 23 | self.N = N 24 | self.M = M 25 | 
self.context_window = context_window 26 | self.slice_num = slice_num 27 | self.slice_ch = slice_ch 28 | 29 | self.g_a = AnalysisTransform(N=N, M=M) 30 | self.g_s = SynthesisTransform(N=N, M=M) 31 | 32 | self.h_a = HyperAnalysis(M=M, N=N) 33 | self.h_s = HyperSynthesis(M=M, N=N) 34 | 35 | # Gussian Conditional 36 | self.gaussian_conditional = GaussianConditional(None) 37 | 38 | self.local_context = nn.ModuleList( 39 | LocalContext(dim=slice_ch) 40 | for _ in range(slice_num) 41 | ) 42 | 43 | self.channel_context = nn.ModuleList( 44 | ChannelContext(in_dim=slice_ch * i, out_dim=slice_ch) if i else None 45 | for i in range(slice_num) 46 | ) 47 | 48 | # Global Reference for non-anchors 49 | self.global_inter_context = nn.ModuleList( 50 | LinearGlobalInterContext(dim=slice_ch * i, out_dim=slice_ch * 2, num_heads=slice_ch * i // 32) if i else None 51 | for i in range(slice_num) 52 | ) 53 | self.global_intra_context = nn.ModuleList( 54 | LinearGlobalIntraContext(dim=slice_ch) if i else None 55 | for i in range(slice_num) 56 | ) 57 | self.entropy_parameters_anchor = nn.ModuleList( 58 | EntropyParameters(in_dim=M * 2 + slice_ch * 6, out_dim=slice_ch * 2) 59 | if i else EntropyParameters(in_dim=M * 2, out_dim=slice_ch * 2) 60 | for i in range(slice_num) 61 | ) 62 | self.entropy_parameters_nonanchor = nn.ModuleList( 63 | EntropyParameters(in_dim=M * 2 + slice_ch * 10, out_dim=slice_ch * 2) 64 | if i else EntropyParameters(in_dim=M * 2 + slice_ch * 2, out_dim=slice_ch * 2) 65 | for i in range(slice_num) 66 | ) 67 | 68 | # Latent Residual Prediction 69 | self.lrp_anchor = nn.ModuleList( 70 | LatentResidualPrediction(in_dim=M + (i + 1) * slice_ch, out_dim=slice_ch) 71 | for i in range(slice_num) 72 | ) 73 | self.lrp_nonanchor = nn.ModuleList( 74 | LatentResidualPrediction(in_dim=M + (i + 1) * slice_ch, out_dim=slice_ch) 75 | for i in range(slice_num) 76 | ) 77 | 78 | 79 | def forward(self, x): 80 | """ 81 | Using checkerboard context model with mask attention 82 | which divides y into anchor and non-anchor parts 83 | non-anchor use anchor as spatial context 84 | In addition, a channel-wise entropy model is used, too. 
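        Slices of y are coded sequentially: previously coded slices provide the
        channel-wise (and global inter-slice) context, and within each slice the
        anchor half is coded first and then serves as the spatial context for the
        non-anchor half.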
85 | Args: 86 | x: [B, 3, H, W] 87 | return: 88 | x_hat: [B, 3, H, W] 89 | y_likelihoods: [B, M, H // 16, W // 16] 90 | z_likelihoods: [B, N, H // 64, W // 64] 91 | likelihoods: y_likelihoods, z_likelihoods 92 | """ 93 | self.update_resolutions(x.size(2) // 16, x.size(3) // 16) 94 | y = self.g_a(x) 95 | z = self.h_a(y) 96 | _, z_likelihoods = self.entropy_bottleneck(z) 97 | z_offset = self.entropy_bottleneck._get_medians() 98 | z_hat = ste_round(z - z_offset) + z_offset 99 | 100 | # Hyper-parameters 101 | hyper_params = self.h_s(z_hat) 102 | hyper_scales, hyper_means = hyper_params.chunk(2, 1) 103 | 104 | y_slices = y.chunk(self.slice_num, dim=1) 105 | y_hat_slices = [] 106 | y_likelihoods = [] 107 | for idx, y_slice in enumerate(y_slices): 108 | slice_anchor, slice_nonanchor = ckbd_split(y_slice) 109 | if idx == 0: 110 | # Anchor 111 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params) 112 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 113 | # split means and scales of anchor 114 | scales_anchor = ckbd_anchor(scales_anchor) 115 | means_anchor = ckbd_anchor(means_anchor) 116 | # round anchor 117 | slice_anchor = ste_round(slice_anchor - means_anchor) + means_anchor 118 | # predict residuals cause by round 119 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 120 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor) 121 | # Non-anchor 122 | # local_ctx: [B, H, W, 2 * C] 123 | local_ctx = self.local_context[idx](slice_anchor) 124 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1)) 125 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 126 | # split means and scales of nonanchor 127 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 128 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 129 | # merge means and scales of anchor and nonanchor 130 | scales_slice = ckbd_merge(scales_anchor, scales_nonanchor) 131 | means_slice = ckbd_merge(means_anchor, means_nonanchor) 132 | _, y_slice_likelihoods = self.gaussian_conditional(y_slice, scales_slice, means_slice) 133 | # round slice_nonanchor 134 | slice_nonanchor = ste_round(slice_nonanchor - means_nonanchor) + means_nonanchor 135 | y_hat_slice = slice_anchor + slice_nonanchor 136 | # predict residuals cause by round 137 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [y_hat_slice]), dim=1)) 138 | y_hat_slice = y_hat_slice + ckbd_nonanchor(lrp_nonanchor) 139 | y_hat_slices.append(y_hat_slice) 140 | y_likelihoods.append(y_slice_likelihoods) 141 | 142 | else: 143 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1)) 144 | channel_ctx = self.channel_context[idx](torch.cat(y_hat_slices, dim=1)) 145 | # Anchor(Use channel context and hyper params) 146 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1)) 147 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 148 | # split means and scales of anchor 149 | scales_anchor = ckbd_anchor(scales_anchor) 150 | means_anchor = ckbd_anchor(means_anchor) 151 | # round anchor 152 | slice_anchor = ste_round(slice_anchor - means_anchor) + means_anchor 153 | # predict residuals cause by round 154 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 155 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor) 156 | # Non-anchor(Use spatial context, channel context and hyper params) 157 
| global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor) 158 | # ctx_params: [B, H, W, 2 * C] 159 | local_ctx = self.local_context[idx](slice_anchor) 160 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1)) 161 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 162 | # split means and scales of nonanchor 163 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 164 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 165 | # merge means and scales of anchor and nonanchor 166 | scales_slice = ckbd_merge(scales_anchor, scales_nonanchor) 167 | means_slice = ckbd_merge(means_anchor, means_nonanchor) 168 | _, y_slice_likelihoods = self.gaussian_conditional(y_slice, scales_slice, means_slice) 169 | # round slice_nonanchor 170 | slice_nonanchor = ste_round(slice_nonanchor - means_nonanchor) + means_nonanchor 171 | y_hat_slice = slice_anchor + slice_nonanchor 172 | # predict residuals cause by round 173 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [y_hat_slice]), dim=1)) 174 | y_hat_slice = y_hat_slice + ckbd_nonanchor(lrp_nonanchor) 175 | y_hat_slices.append(y_hat_slice) 176 | y_likelihoods.append(y_slice_likelihoods) 177 | 178 | y_hat = torch.cat(y_hat_slices, dim=1) 179 | y_likelihoods = torch.cat(y_likelihoods, dim=1) 180 | x_hat = self.g_s(y_hat) 181 | 182 | return { 183 | "x_hat": x_hat, 184 | "likelihoods": {"y_likelihoods": y_likelihoods, "z_likelihoods": z_likelihoods} 185 | } 186 | 187 | def update_resolutions(self, H, W): 188 | for i in range(len(self.global_intra_context)): 189 | if i == 0: 190 | self.local_context[i].update_resolution(H, W, next(self.parameters()).device, mask=None) 191 | else: 192 | self.local_context[i].update_resolution(H, W, next(self.parameters()).device, mask=self.local_context[0].attn_mask) 193 | 194 | def compress(self, x): 195 | torch.cuda.synchronize() 196 | start_time = time.time() 197 | self.update_resolutions(x.size(2) // 16, x.size(3) // 16) 198 | y = self.g_a(x) 199 | z = self.h_a(y) 200 | z_strings = self.entropy_bottleneck.compress(z) 201 | z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) 202 | hyper_params = self.h_s(z_hat) 203 | hyper_scales, hyper_means = hyper_params.chunk(2, 1) 204 | y_slices = y.chunk(self.slice_num, dim=1) 205 | y_hat_slices = [] 206 | 207 | cdf = self.gaussian_conditional.quantized_cdf.tolist() 208 | cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() 209 | offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() 210 | encoder = BufferedRansEncoder() 211 | symbols_list = [] 212 | indexes_list = [] 213 | y_strings = [] 214 | 215 | for idx, y_slice in enumerate(y_slices): 216 | slice_anchor, slice_nonanchor = ckbd_split(y_slice) 217 | if idx == 0: 218 | # Anchor 219 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params) 220 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 221 | # split means and scales of anchor 222 | scales_anchor = ckbd_anchor(scales_anchor) 223 | means_anchor = ckbd_anchor(means_anchor) 224 | # round and compress anchor 225 | slice_anchor = compress_anchor(self.gaussian_conditional, slice_anchor, scales_anchor, means_anchor, symbols_list, indexes_list) 226 | # predict residuals caused by round 227 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 228 | slice_anchor = slice_anchor + 
ckbd_anchor(lrp_anchor) 229 | # Non-anchor 230 | # local_ctx: [B,2 * C, H, W] 231 | local_ctx = self.local_context[idx](slice_anchor) 232 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1)) 233 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 234 | # split means and scales of nonanchor 235 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 236 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 237 | # round and compress nonanchor 238 | slice_nonanchor = compress_nonanchor(self.gaussian_conditional, slice_nonanchor, scales_nonanchor, means_nonanchor, symbols_list, indexes_list) 239 | # predict residuals caused by round 240 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1)) 241 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor) 242 | y_hat_slices.append(slice_nonanchor + slice_anchor) 243 | 244 | else: 245 | # Anchor 246 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1)) 247 | channel_ctx = self.channel_context[idx](torch.cat(y_hat_slices, dim=1)) 248 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1)) 249 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 250 | # split means and scales of anchor 251 | scales_anchor = ckbd_anchor(scales_anchor) 252 | means_anchor = ckbd_anchor(means_anchor) 253 | # round and compress anchor 254 | slice_anchor = compress_anchor(self.gaussian_conditional, slice_anchor, scales_anchor, means_anchor, symbols_list, indexes_list) 255 | # predict residuals caused by round 256 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 257 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor) 258 | # Non-anchor 259 | global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor) 260 | # local_ctx: [B,2 * C, H, W] 261 | local_ctx = self.local_context[idx](slice_anchor) 262 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1)) 263 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 264 | # split means and scales of nonanchor 265 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 266 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 267 | # round and compress nonanchor 268 | slice_nonanchor = compress_nonanchor(self.gaussian_conditional, slice_nonanchor, scales_nonanchor, means_nonanchor, symbols_list, indexes_list) 269 | # predict residuals caused by round 270 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1)) 271 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor) 272 | y_hat_slices.append(slice_nonanchor + slice_anchor) 273 | 274 | encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets) 275 | y_string = encoder.flush() 276 | y_strings.append(y_string) 277 | torch.cuda.synchronize() 278 | end_time = time.time() 279 | 280 | cost_time = end_time - start_time 281 | return { 282 | "strings": [y_strings, z_strings], 283 | "shape": z.size()[-2:], 284 | "cost_time": cost_time 285 | } 286 | 287 | def decompress(self, strings, shape): 288 | torch.cuda.synchronize() 289 | start_time = time.time() 290 | y_strings = strings[0][0] 291 | z_strings = strings[1] 292 | z_hat = 
self.entropy_bottleneck.decompress(z_strings, shape) 293 | self.update_resolutions(z_hat.size(2) * 4, z_hat.size(3) * 4) 294 | hyper_params = self.h_s(z_hat) 295 | hyper_scales, hyper_means = hyper_params.chunk(2, 1) 296 | y_hat_slices = [] 297 | 298 | cdf = self.gaussian_conditional.quantized_cdf.tolist() 299 | cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() 300 | offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() 301 | decoder = RansDecoder() 302 | decoder.set_stream(y_strings) 303 | 304 | for idx in range(self.slice_num): 305 | if idx == 0: 306 | # Anchor 307 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params) 308 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 309 | # split means and scales of anchor 310 | scales_anchor = ckbd_anchor(scales_anchor) 311 | means_anchor = ckbd_anchor(means_anchor) 312 | # decompress anchor 313 | slice_anchor = decompress_anchor(self.gaussian_conditional, scales_anchor, means_anchor, decoder, cdf, cdf_lengths, offsets) 314 | # predict residuals caused by round 315 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 316 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor) 317 | # Non-anchor 318 | # local_ctx: [B,2 * C, H, W] 319 | local_ctx = self.local_context[idx](slice_anchor) 320 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1)) 321 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 322 | # split means and scales of nonanchor 323 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 324 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 325 | # decompress non-anchor 326 | slice_nonanchor = decompress_nonanchor(self.gaussian_conditional, scales_nonanchor, means_nonanchor, decoder, cdf, cdf_lengths, offsets) 327 | # predict residuals caused by round 328 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1)) 329 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor) 330 | y_hat_slices.append(slice_nonanchor + slice_anchor) 331 | 332 | else: 333 | # Anchor 334 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1)) 335 | channel_ctx = self.channel_context[idx](torch.cat(y_hat_slices, dim=1)) 336 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1)) 337 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 338 | # split means and scales of anchor 339 | scales_anchor = ckbd_anchor(scales_anchor) 340 | means_anchor = ckbd_anchor(means_anchor) 341 | # decompress anchor 342 | slice_anchor = decompress_anchor(self.gaussian_conditional, scales_anchor, means_anchor, decoder, cdf, cdf_lengths, offsets) 343 | # predict residuals caused by round 344 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 345 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor) 346 | # Non-anchor 347 | # Non-anchor 348 | global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor) 349 | # local_ctx: [B,2 * C, H, W] 350 | local_ctx = self.local_context[idx](slice_anchor) 351 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1)) 352 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 353 | # split means and scales of nonanchor 
354 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 355 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 356 | # decompress non-anchor 357 | slice_nonanchor = decompress_nonanchor(self.gaussian_conditional, scales_nonanchor, means_nonanchor, decoder, cdf, cdf_lengths, offsets) 358 | # predict residuals caused by round 359 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1)) 360 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor) 361 | y_hat_slices.append(slice_nonanchor + slice_anchor) 362 | 363 | y_hat = torch.cat(y_hat_slices, dim=1) 364 | x_hat = self.g_s(y_hat) 365 | torch.cuda.synchronize() 366 | end_time = time.time() 367 | 368 | cost_time = end_time - start_time 369 | 370 | return { 371 | "x_hat": x_hat, 372 | "cost_time": cost_time 373 | } 374 | 375 | def load_state_dict(self, state_dict): 376 | update_registered_buffers( 377 | self.gaussian_conditional, 378 | "gaussian_conditional", 379 | ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"], 380 | state_dict, 381 | ) 382 | super().load_state_dict(state_dict) 383 | 384 | def update(self, scale_table=None, force=False): 385 | if scale_table is None: 386 | scale_table = get_scale_table() 387 | updated = self.gaussian_conditional.update_scale_table(scale_table, force=force) 388 | updated |= super().update(force=force) 389 | return updated 390 | -------------------------------------------------------------------------------- /models/ug_mlicpp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import time 5 | from compressai.models import CompressionModel 6 | from compressai.ops import ste_round 7 | from compressai.ans import BufferedRansEncoder, RansDecoder 8 | 9 | from models.fourier_cond import GlobalConditioning 10 | from utils.func import update_registered_buffers, get_scale_table 11 | from utils.ckbd import * 12 | from modules.transform import * 13 | 14 | 15 | class UG_MLICPlusPlus(CompressionModel): 16 | def __init__(self, config, **kwargs): 17 | super().__init__(config.N, **kwargs) 18 | 19 | self.beta = 0 20 | N = config.N 21 | M = config.M 22 | context_window = config.context_window 23 | slice_num = config.slice_num 24 | slice_ch = M // slice_num 25 | assert slice_ch * slice_num == M 26 | 27 | self.N = N 28 | self.M = M 29 | self.context_window = context_window 30 | self.slice_num = slice_num 31 | self.slice_ch = slice_ch 32 | 33 | self.g_a = AnalysisTransform(N=N, M=M) 34 | 35 | self.h_a = HyperAnalysis(M=M, N=N) 36 | self.h_s = HyperSynthesis(M=M, N=N) 37 | 38 | # Conditional synthesis: g_s is conditioned on beta via GlobalConditioning 39 | self.global_cond = GlobalConditioning() 40 | self.g_s = ConditionalSynthesisTransform(N=N, M=M) 41 | 42 | # Gaussian Conditional 43 | self.gaussian_conditional = GaussianConditional(None) 44 | 45 | self.local_context = nn.ModuleList( 46 | LocalContext(dim=slice_ch) 47 | for _ in range(slice_num) 48 | ) 49 | 50 | self.channel_context = nn.ModuleList( 51 | ChannelContext(in_dim=slice_ch * i, out_dim=slice_ch) if i else None 52 | for i in range(slice_num) 53 | ) 54 | 55 | # Global Reference for non-anchors 56 | self.global_inter_context = nn.ModuleList( 57 | LinearGlobalInterContext(dim=slice_ch * i, out_dim=slice_ch * 2, num_heads=slice_ch * i // 32) if i else None 58 | for i in range(slice_num) 59 | ) 60 | self.global_intra_context = nn.ModuleList( 61 | LinearGlobalIntraContext(dim=slice_ch) if i else None 62 | for i in
range(slice_num) 63 | ) 64 | self.entropy_parameters_anchor = nn.ModuleList( 65 | EntropyParameters(in_dim=M * 2 + slice_ch * 6, out_dim=slice_ch * 2) 66 | if i else EntropyParameters(in_dim=M * 2, out_dim=slice_ch * 2) 67 | for i in range(slice_num) 68 | ) 69 | self.entropy_parameters_nonanchor = nn.ModuleList( 70 | EntropyParameters(in_dim=M * 2 + slice_ch * 10, out_dim=slice_ch * 2) 71 | if i else EntropyParameters(in_dim=M * 2 + slice_ch * 2, out_dim=slice_ch * 2) 72 | for i in range(slice_num) 73 | ) 74 | 75 | # Latent Residual Prediction 76 | self.lrp_anchor = nn.ModuleList( 77 | LatentResidualPrediction(in_dim=M + (i + 1) * slice_ch, out_dim=slice_ch) 78 | for i in range(slice_num) 79 | ) 80 | self.lrp_nonanchor = nn.ModuleList( 81 | LatentResidualPrediction(in_dim=M + (i + 1) * slice_ch, out_dim=slice_ch) 82 | for i in range(slice_num) 83 | ) 84 | 85 | 86 | def forward(self, x): 87 | x, betas = x 88 | fourier_features_mlp = self.global_cond(betas) 89 | """ 90 | Using checkerboard context model with mask attention 91 | which divides y into anchor and non-anchor parts 92 | non-anchor use anchor as spatial context 93 | In addition, a channel-wise entropy model is used, too. 94 | Args: 95 | x: [B, 3, H, W] 96 | return: 97 | x_hat: [B, 3, H, W] 98 | y_likelihoods: [B, M, H // 16, W // 16] 99 | z_likelihoods: [B, N, H // 64, W // 64] 100 | likelihoods: y_likelihoods, z_likelihoods 101 | """ 102 | self.update_resolutions(x.size(2) // 16, x.size(3) // 16) 103 | y = self.g_a(x) 104 | z = self.h_a(y) 105 | _, z_likelihoods = self.entropy_bottleneck(z) 106 | z_offset = self.entropy_bottleneck._get_medians() 107 | z_hat = ste_round(z - z_offset) + z_offset 108 | 109 | # Hyper-parameters 110 | hyper_params = self.h_s(z_hat) 111 | hyper_scales, hyper_means = hyper_params.chunk(2, 1) 112 | 113 | y_slices = y.chunk(self.slice_num, dim=1) 114 | y_hat_slices = [] 115 | y_likelihoods = [] 116 | for idx, y_slice in enumerate(y_slices): 117 | slice_anchor, slice_nonanchor = ckbd_split(y_slice) 118 | if idx == 0: 119 | # Anchor 120 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params) 121 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 122 | # split means and scales of anchor 123 | scales_anchor = ckbd_anchor(scales_anchor) 124 | means_anchor = ckbd_anchor(means_anchor) 125 | # round anchor 126 | slice_anchor = ste_round(slice_anchor - means_anchor) + means_anchor 127 | # predict residuals cause by round 128 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 129 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor) 130 | # Non-anchor 131 | # local_ctx: [B, H, W, 2 * C] 132 | local_ctx = self.local_context[idx](slice_anchor) 133 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1)) 134 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 135 | # split means and scales of nonanchor 136 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 137 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 138 | # merge means and scales of anchor and nonanchor 139 | scales_slice = ckbd_merge(scales_anchor, scales_nonanchor) 140 | means_slice = ckbd_merge(means_anchor, means_nonanchor) 141 | _, y_slice_likelihoods = self.gaussian_conditional(y_slice, scales_slice, means_slice) 142 | # round slice_nonanchor 143 | slice_nonanchor = ste_round(slice_nonanchor - means_nonanchor) + means_nonanchor 144 | y_hat_slice = slice_anchor + slice_nonanchor 145 | # predict residuals cause by 
round 146 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [y_hat_slice]), dim=1)) 147 | y_hat_slice = y_hat_slice + ckbd_nonanchor(lrp_nonanchor) 148 | y_hat_slices.append(y_hat_slice) 149 | y_likelihoods.append(y_slice_likelihoods) 150 | 151 | else: 152 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1)) 153 | channel_ctx = self.channel_context[idx](torch.cat(y_hat_slices, dim=1)) 154 | # Anchor(Use channel context and hyper params) 155 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1)) 156 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 157 | # split means and scales of anchor 158 | scales_anchor = ckbd_anchor(scales_anchor) 159 | means_anchor = ckbd_anchor(means_anchor) 160 | # round anchor 161 | slice_anchor = ste_round(slice_anchor - means_anchor) + means_anchor 162 | # predict residuals cause by round 163 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 164 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor) 165 | # Non-anchor(Use spatial context, channel context and hyper params) 166 | global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor) 167 | # ctx_params: [B, H, W, 2 * C] 168 | local_ctx = self.local_context[idx](slice_anchor) 169 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1)) 170 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 171 | # split means and scales of nonanchor 172 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 173 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 174 | # merge means and scales of anchor and nonanchor 175 | scales_slice = ckbd_merge(scales_anchor, scales_nonanchor) 176 | means_slice = ckbd_merge(means_anchor, means_nonanchor) 177 | _, y_slice_likelihoods = self.gaussian_conditional(y_slice, scales_slice, means_slice) 178 | # round slice_nonanchor 179 | slice_nonanchor = ste_round(slice_nonanchor - means_nonanchor) + means_nonanchor 180 | y_hat_slice = slice_anchor + slice_nonanchor 181 | # predict residuals cause by round 182 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [y_hat_slice]), dim=1)) 183 | y_hat_slice = y_hat_slice + ckbd_nonanchor(lrp_nonanchor) 184 | y_hat_slices.append(y_hat_slice) 185 | y_likelihoods.append(y_slice_likelihoods) 186 | 187 | y_hat = torch.cat(y_hat_slices, dim=1) 188 | y_likelihoods = torch.cat(y_likelihoods, dim=1) 189 | x_hat = self.g_s((y_hat, fourier_features_mlp)) 190 | 191 | return { 192 | "x_hat": x_hat, 193 | "likelihoods": {"y_likelihoods": y_likelihoods, "z_likelihoods": z_likelihoods} 194 | } 195 | 196 | def update_resolutions(self, H, W): 197 | for i in range(len(self.global_intra_context)): 198 | if i == 0: 199 | self.local_context[i].update_resolution(H, W, next(self.parameters()).device, mask=None) 200 | else: 201 | self.local_context[i].update_resolution(H, W, next(self.parameters()).device, mask=self.local_context[0].attn_mask) 202 | 203 | def compress(self, x): 204 | torch.cuda.synchronize() 205 | start_time = time.time() 206 | self.update_resolutions(x.size(2) // 16, x.size(3) // 16) 207 | y = self.g_a(x) 208 | z = self.h_a(y) 209 | z_strings = self.entropy_bottleneck.compress(z) 210 | z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:]) 211 | hyper_params = self.h_s(z_hat) 212 | 
hyper_scales, hyper_means = hyper_params.chunk(2, 1) 213 | y_slices = y.chunk(self.slice_num, dim=1) 214 | y_hat_slices = [] 215 | 216 | cdf = self.gaussian_conditional.quantized_cdf.tolist() 217 | cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() 218 | offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() 219 | encoder = BufferedRansEncoder() 220 | symbols_list = [] 221 | indexes_list = [] 222 | y_strings = [] 223 | 224 | for idx, y_slice in enumerate(y_slices): 225 | slice_anchor, slice_nonanchor = ckbd_split(y_slice) 226 | if idx == 0: 227 | # Anchor 228 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params) 229 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 230 | # split means and scales of anchor 231 | scales_anchor = ckbd_anchor(scales_anchor) 232 | means_anchor = ckbd_anchor(means_anchor) 233 | # round and compress anchor 234 | slice_anchor = compress_anchor(self.gaussian_conditional, slice_anchor, scales_anchor, means_anchor, symbols_list, indexes_list) 235 | # predict residuals caused by round 236 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 237 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor) 238 | # Non-anchor 239 | # local_ctx: [B,2 * C, H, W] 240 | local_ctx = self.local_context[idx](slice_anchor) 241 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1)) 242 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 243 | # split means and scales of nonanchor 244 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 245 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 246 | # round and compress nonanchor 247 | slice_nonanchor = compress_nonanchor(self.gaussian_conditional, slice_nonanchor, scales_nonanchor, means_nonanchor, symbols_list, indexes_list) 248 | # predict residuals caused by round 249 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1)) 250 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor) 251 | y_hat_slices.append(slice_nonanchor + slice_anchor) 252 | 253 | else: 254 | # Anchor 255 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1)) 256 | channel_ctx = self.channel_context[idx](torch.cat(y_hat_slices, dim=1)) 257 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1)) 258 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 259 | # split means and scales of anchor 260 | scales_anchor = ckbd_anchor(scales_anchor) 261 | means_anchor = ckbd_anchor(means_anchor) 262 | # round and compress anchor 263 | slice_anchor = compress_anchor(self.gaussian_conditional, slice_anchor, scales_anchor, means_anchor, symbols_list, indexes_list) 264 | # predict residuals caused by round 265 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 266 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor) 267 | # Non-anchor 268 | global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor) 269 | # local_ctx: [B,2 * C, H, W] 270 | local_ctx = self.local_context[idx](slice_anchor) 271 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1)) 272 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 273 | # split means and scales of 
nonanchor 274 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 275 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 276 | # round and compress nonanchor 277 | slice_nonanchor = compress_nonanchor(self.gaussian_conditional, slice_nonanchor, scales_nonanchor, means_nonanchor, symbols_list, indexes_list) 278 | # predict residuals caused by round 279 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1)) 280 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor) 281 | y_hat_slices.append(slice_nonanchor + slice_anchor) 282 | 283 | encoder.encode_with_indexes(symbols_list, indexes_list, cdf, cdf_lengths, offsets) 284 | y_string = encoder.flush() 285 | y_strings.append(y_string) 286 | torch.cuda.synchronize() 287 | end_time = time.time() 288 | 289 | cost_time = end_time - start_time 290 | return { 291 | "strings": [y_strings, z_strings], 292 | "shape": z.size()[-2:], 293 | "cost_time": cost_time 294 | } 295 | 296 | def decompress(self, strings, shape, beta): 297 | torch.cuda.synchronize() 298 | start_time = time.time() 299 | y_strings = strings[0][0] 300 | z_strings = strings[1] 301 | z_hat = self.entropy_bottleneck.decompress(z_strings, shape) 302 | self.update_resolutions(z_hat.size(2) * 4, z_hat.size(3) * 4) 303 | hyper_params = self.h_s(z_hat) 304 | hyper_scales, hyper_means = hyper_params.chunk(2, 1) 305 | y_hat_slices = [] 306 | 307 | cdf = self.gaussian_conditional.quantized_cdf.tolist() 308 | cdf_lengths = self.gaussian_conditional.cdf_length.reshape(-1).int().tolist() 309 | offsets = self.gaussian_conditional.offset.reshape(-1).int().tolist() 310 | decoder = RansDecoder() 311 | decoder.set_stream(y_strings) 312 | 313 | for idx in range(self.slice_num): 314 | if idx == 0: 315 | # Anchor 316 | params_anchor = self.entropy_parameters_anchor[idx](hyper_params) 317 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 318 | # split means and scales of anchor 319 | scales_anchor = ckbd_anchor(scales_anchor) 320 | means_anchor = ckbd_anchor(means_anchor) 321 | # decompress anchor 322 | slice_anchor = decompress_anchor(self.gaussian_conditional, scales_anchor, means_anchor, decoder, cdf, cdf_lengths, offsets) 323 | # predict residuals caused by round 324 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 325 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor) 326 | # Non-anchor 327 | # local_ctx: [B,2 * C, H, W] 328 | local_ctx = self.local_context[idx](slice_anchor) 329 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, hyper_params], dim=1)) 330 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 331 | # split means and scales of nonanchor 332 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 333 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 334 | # decompress non-anchor 335 | slice_nonanchor = decompress_nonanchor(self.gaussian_conditional, scales_nonanchor, means_nonanchor, decoder, cdf, cdf_lengths, offsets) 336 | # predict residuals caused by round 337 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1)) 338 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor) 339 | y_hat_slices.append(slice_nonanchor + slice_anchor) 340 | 341 | else: 342 | # Anchor 343 | global_inter_ctx = self.global_inter_context[idx](torch.cat(y_hat_slices, dim=1)) 344 | channel_ctx = 
self.channel_context[idx](torch.cat(y_hat_slices, dim=1)) 345 | params_anchor = self.entropy_parameters_anchor[idx](torch.cat([global_inter_ctx, channel_ctx, hyper_params], dim=1)) 346 | scales_anchor, means_anchor = params_anchor.chunk(2, 1) 347 | # split means and scales of anchor 348 | scales_anchor = ckbd_anchor(scales_anchor) 349 | means_anchor = ckbd_anchor(means_anchor) 350 | # decompress anchor 351 | slice_anchor = decompress_anchor(self.gaussian_conditional, scales_anchor, means_anchor, decoder, cdf, cdf_lengths, offsets) 352 | # predict residuals caused by round 353 | lrp_anchor = self.lrp_anchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_anchor]), dim=1)) 354 | slice_anchor = slice_anchor + ckbd_anchor(lrp_anchor) 355 | # Non-anchor 356 | # Non-anchor 357 | global_intra_ctx = self.global_intra_context[idx](y_hat_slices[-1], slice_anchor) 358 | # local_ctx: [B,2 * C, H, W] 359 | local_ctx = self.local_context[idx](slice_anchor) 360 | params_nonanchor = self.entropy_parameters_nonanchor[idx](torch.cat([local_ctx, global_intra_ctx, global_inter_ctx, channel_ctx, hyper_params], dim=1)) 361 | scales_nonanchor, means_nonanchor = params_nonanchor.chunk(2, 1) 362 | # split means and scales of nonanchor 363 | scales_nonanchor = ckbd_nonanchor(scales_nonanchor) 364 | means_nonanchor = ckbd_nonanchor(means_nonanchor) 365 | # decompress non-anchor 366 | slice_nonanchor = decompress_nonanchor(self.gaussian_conditional, scales_nonanchor, means_nonanchor, decoder, cdf, cdf_lengths, offsets) 367 | # predict residuals caused by round 368 | lrp_nonanchor = self.lrp_nonanchor[idx](torch.cat(([hyper_means] + y_hat_slices + [slice_nonanchor + slice_anchor]), dim=1)) 369 | slice_nonanchor = slice_nonanchor + ckbd_nonanchor(lrp_nonanchor) 370 | y_hat_slices.append(slice_nonanchor + slice_anchor) 371 | 372 | y_hat = torch.cat(y_hat_slices, dim=1) 373 | 374 | bs = y_hat.size(0) 375 | betas = torch.full((bs,), beta).to(y_hat.device) 376 | fourier_features_mlp = self.global_cond(betas) 377 | x_hat = self.g_s((y_hat, fourier_features_mlp)) 378 | torch.cuda.synchronize() 379 | end_time = time.time() 380 | 381 | cost_time = end_time - start_time 382 | 383 | return { 384 | "x_hat": x_hat, 385 | "cost_time": cost_time 386 | } 387 | 388 | def load_state_dict(self, state_dict): 389 | update_registered_buffers( 390 | self.gaussian_conditional, 391 | "gaussian_conditional", 392 | ["_quantized_cdf", "_offset", "_cdf_length", "scale_table"], 393 | state_dict, 394 | ) 395 | super().load_state_dict(state_dict) 396 | 397 | def update(self, scale_table=None, force=False): 398 | if scale_table is None: 399 | scale_table = get_scale_table() 400 | updated = self.gaussian_conditional.update_scale_table(scale_table, force=force) 401 | updated |= super().update(force=force) 402 | return updated 403 | --------------------------------------------------------------------------------
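The anchor/non-anchor handling in both models relies on the checkerboard helpers imported from utils/ckbd.py (ckbd_split, ckbd_anchor, ckbd_nonanchor, ckbd_merge). The snippet below is only a minimal, self-contained sketch of the masking behaviour those names suggest; the parity convention and the exact signatures are assumptions, not the repository implementation.

# Minimal sketch only (assumed parity convention); the real helpers live in utils/ckbd.py.
import torch

def ckbd_mask(y: torch.Tensor) -> torch.Tensor:
    # True on the positions treated as "anchor" here: (h + w) even.
    _, _, H, W = y.shape
    hh = torch.arange(H, device=y.device).view(H, 1)
    ww = torch.arange(W, device=y.device).view(1, W)
    return ((hh + ww) % 2 == 0).view(1, 1, H, W)

def ckbd_anchor(y):
    return y * ckbd_mask(y)            # keep anchor positions, zero the rest

def ckbd_nonanchor(y):
    return y * (~ckbd_mask(y))         # keep non-anchor positions, zero the rest

def ckbd_split(y):
    return ckbd_anchor(y), ckbd_nonanchor(y)

def ckbd_merge(anchor, nonanchor):
    return anchor + nonanchor          # supports are disjoint, so addition recombines them

if __name__ == "__main__":
    y = torch.randn(1, 8, 4, 4)
    anchor, nonanchor = ckbd_split(y)
    assert torch.equal(ckbd_merge(anchor, nonanchor), y)  # split + merge is lossless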
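Both forward passes quantize each latent slice as ste_round(slice - mean) + mean. The following tiny check illustrates what that straight-through rounding does, using the same compressai.ops.ste_round helper the files above import; the tensors are made up for illustration.

import torch
from compressai.ops import ste_round  # same helper the models import

y = torch.randn(2, 4, requires_grad=True)   # toy latent slice
mean = torch.full_like(y, 0.3)              # stand-in for the predicted mean
y_hat = ste_round(y - mean) + mean          # values are rounded around the predicted mean
y_hat.sum().backward()
# Rounding has zero gradient almost everywhere; the straight-through estimator
# instead passes the gradient through unchanged, so training still works.
print(torch.allclose(y.grad, torch.ones_like(y)))  # True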
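A hedged end-to-end usage sketch for UG_MLICPlusPlus. The config fields mirror what __init__ reads (N, M, context_window, slice_num), but the concrete values, the commented-out checkpoint path, and beta = 0.5 are placeholders; a CUDA device is assumed because compress() and decompress() call torch.cuda.synchronize().

import torch
from types import SimpleNamespace
from models.ug_mlicpp import UG_MLICPlusPlus

# All concrete values below are assumptions for illustration.
config = SimpleNamespace(N=192, M=320, context_window=5, slice_num=10)
model = UG_MLICPlusPlus(config).cuda().eval()
# state_dict = torch.load("path/to/checkpoint.pth")["state_dict"]  # hypothetical checkpoint layout
# model.load_state_dict(state_dict)
model.update(force=True)  # build the entropy coders' CDF tables before compress/decompress

x = torch.rand(1, 3, 256, 256).cuda()  # spatial size divisible by 64
with torch.no_grad():
    enc = model.compress(x)
    # beta is only consumed at decode time: it conditions the synthesis transform.
    dec = model.decompress(enc["strings"], enc["shape"], beta=0.5)
print(dec["x_hat"].shape, dec["cost_time"])

Because compress() never uses beta, the same bitstream can be decoded under different beta values: decompress() embeds beta with GlobalConditioning and feeds the result to the conditional synthesis transform g_s.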