├── checkpoint └── .gitkeep ├── image ├── clamp │ ├── cover │ │ └── .gitkeep │ ├── secret │ │ └── .gitkeep │ ├── stego │ │ └── .gitkeep │ └── secret_rev │ │ └── .gitkeep └── wo_clamp │ ├── cover │ └── .gitkeep │ ├── secret │ └── .gitkeep │ ├── stego │ └── .gitkeep │ └── secret_rev │ └── .gitkeep ├── tensorboard_log └── .gitkeep ├── .gitignore ├── config.py ├── LICENSE ├── requirements.txt ├── README.md ├── critic.py ├── datasets.py ├── test_multiple_image_hiding.py ├── train_StegFormer_single_image.py ├── test_save_single_image_hiding.py ├── train_StegFormer_multiple_image.py ├── test_StegFormer.py └── model.py /checkpoint/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /image/clamp/cover/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /image/clamp/secret/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /image/clamp/stego/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tensorboard_log/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /image/clamp/secret_rev/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /image/wo_clamp/cover/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- 
/image/wo_clamp/secret/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /image/wo_clamp/stego/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /image/wo_clamp/secret_rev/.gitkeep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | nohup** 3 | image/clamp/cover/* 4 | image/clamp/stego/* 5 | image/clamp/secret/* 6 | image/clamp/secret_rev/* 7 | !image/clamp/cover/.gitkeep 8 | !image/clamp/stego/.gitkeep 9 | !image/clamp/secret/.gitkeep 10 | !image/clamp/secret_rev/.gitkeep 11 | image/wo_clamp/cover/* 12 | image/wo_clamp/stego/* 13 | image/wo_clamp/secret/* 14 | image/wo_clamp/secret_rev/* 15 | !image/wo_clamp/cover/.gitkeep 16 | !image/wo_clamp/stego/.gitkeep 17 | !image/wo_clamp/secret/.gitkeep 18 | !image/wo_clamp/secret_rev/.gitkeep 19 | checkpoint/* 20 | !checkpoint/.gitkeep 21 | tensorboard_log/* 22 | !tensorboard_log/.gitkeep -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import torch 3 | 4 | # 以类的方式定义参数 5 | @dataclass 6 | class Args: 7 | # training config 8 | 9 | # model config 10 | image_size_train = 256 11 | image_size_test_single = 256 12 | image_size_test_multiple = 256 13 | num_secret = 4 14 | 15 | # optimer config 16 | lr = 2e-4 17 | warm_up_epoch = 20 18 | warm_up_lr_init = 5e-6 19 | 20 | # dataset 21 | DIV2K_path = '/home/whq135/dataset' # /home/whq135/dataset/DIV2K_train_HR 22 | single_batch_size = 12 23 | 
multi_batch_szie = 8 24 | multi_batch_iteration = (num_secret+1)*8 25 | test_multi_batch_size = num_secret+1 26 | 27 | epochs = 6000 28 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 29 | val_freq = 10 30 | save_freq = 200 31 | train_next = 0 32 | use_model = 'StegFormer-B' 33 | input_dim = 3 34 | 35 | norm_train = 'clamp' 36 | output_act = None 37 | path='/home/whq135/code/StegFormer' 38 | model_name='StegFormer-B_4baseline' -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 aoli-gei 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.3.0 2 | certifi==2023.5.7 3 | charset-normalizer==3.1.0 4 | cmake==3.26.3 5 | einops==0.6.1 6 | filelock==3.12.0 7 | fsspec==2023.5.0 8 | huggingface-hub==0.14.1 9 | idna==3.4 10 | imageio==2.29.0 11 | Jinja2==3.1.2 12 | joblib==1.2.0 13 | lazy_loader==0.2 14 | lit==16.0.5 15 | MarkupSafe==2.1.2 16 | mpmath==1.3.0 17 | natsort==8.3.1 18 | networkx==3.1 19 | numpy==1.24.3 20 | nvidia-cublas-cu11==11.10.3.66 21 | nvidia-cuda-cupti-cu11==11.7.101 22 | nvidia-cuda-nvrtc-cu11==11.7.99 23 | nvidia-cuda-runtime-cu11==11.7.99 24 | nvidia-cudnn-cu11==8.5.0.96 25 | nvidia-cufft-cu11==10.9.0.58 26 | nvidia-curand-cu11==10.2.10.91 27 | nvidia-cusolver-cu11==11.4.0.1 28 | nvidia-cusparse-cu11==11.7.4.91 29 | nvidia-nccl-cu11==2.14.3 30 | nvidia-nvtx-cu11==11.7.91 31 | opencv-python==4.7.0.72 32 | opencv-python-headless==4.7.0.72 33 | packaging==23.1 34 | Pillow==9.5.0 35 | protobuf==3.20.3 36 | PyWavelets==1.4.1 37 | PyYAML==6.0 38 | qudida==0.0.4 39 | requests==2.31.0 40 | safetensors==0.3.1 41 | scikit-image==0.20.0 42 | scikit-learn==1.2.2 43 | scipy==1.10.1 44 | sympy==1.12 45 | tensorboardX==2.6 46 | thop==0.1.1.post2209072238 47 | threadpoolctl==3.1.0 48 | tifffile==2023.4.12 49 | timm==0.9.2 50 | torch==2.0.1 51 | torchaudio==2.0.2 52 | torchvision==0.15.2 53 | tqdm==4.65.0 54 | triton==2.0.0 55 | typing_extensions==4.6.2 56 | urllib3==2.0.2 57 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StegFormer: Rebuilding the Glory of the Autoencoder-Based Steganography (AAAI-2024) 2 | **Xiao Ke, Huanqi Wu, Wenzhong Guo** 3 | 4 | The official pytorch implementation of the paper [StegFormer: Rebuilding the Glory of the Autoencoder-Based 
Steganography](https://github.com/aoli-gei/StegFormer). 5 | 6 | [[Project Page](https://aoli-gei.github.io/StegFormer.github.io/)] [[Paper](https://ojs.aaai.org/index.php/AAAI/article/view/28051)] [[Pretrain_model](https://drive.google.com/drive/folders/1L__astCCgm2GlQU-YaxZYTIzEnABdKEc?usp=sharing)] 7 | 8 | ## Abstract 9 | Image hiding aims to conceal one or more secret images within a cover image of the same resolution. Due to strict capacity requirements, image hiding is commonly called large-capacity steganography. In this paper, we propose StegFormer, a novel autoencoder-based image-hiding model. StegFormer can conceal one or multiple secret images within a cover image of the same resolution while preserving the high visual quality of the stego image. In addition, to mitigate the limitations of current steganographic models in real-world scenarios, we propose a normalizing training strategy and a restrict loss to improve the reliability of the steganographic models under realistic conditions. Furthermore, we propose an efficient steganographic capacity expansion method to increase the capacity of steganography and enhance the efficiency of secret communication. Through this approach, we can increase the relative payload of StegFormer to 96 bits per pixel without any training strategy modifications. Experiments demonstrate that our StegFormer outperforms existing state-of-the-art (SOTA) models. In the case of single-image steganography, there is an improvement of more than 3 dB and 5 dB in PSNR for secret/recovery image pairs and cover/stego image pairs. 
10 | 11 | ## News 12 | - 2024.2.29: update README 13 | - 2024.4.23: update pretrain model 14 | 15 | ## How to train StegFormer 16 | - Please download the training dataset: [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) 17 | - Modify `DIV2K_path`, `path` and `model_name` in the file `config.py` 18 | - Training the model by `train_StegFormer_single_image.py` or `train_StegFormer_multiple_image` 19 | > Note that: please modify `num_secret` in `config.py` to define the number of secret images. 20 | 21 | ## How to test 22 | - Run `test_save_single_image_hiding.py` to test StegFormer using DIV2K valid dataset and save the images in folder `image` 23 | - Run `test_multiple_image_hiding.py` to test StegFormer in multi-image hiding 24 | > Note that: please modify `num_secret` in `config.py` to define the number of secret images. 25 | - Run `test_StegFormer.py` to calculate PSNR, SSIM, MAE and RMSE 26 | 27 | ## Contact 28 | If you have any questions, please contact [wuhuanqi135@gmail.com](wuhuanqi135@gmail.com). 
29 | 30 | ## Citation 31 | If you find this work helps you, please cite: 32 | ```bibtex 33 | @inproceedings{ke2024stegformer, 34 | title={StegFormer: Rebuilding the Glory of Autoencoder-Based Steganography}, 35 | author={Ke, Xiao and Wu, Huanqi and Guo, Wenzhong}, 36 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 37 | volume={38}, 38 | number={3}, 39 | pages={2723--2731}, 40 | year={2024} 41 | ``` 42 | -------------------------------------------------------------------------------- /critic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.metrics import peak_signal_noise_ratio,structural_similarity 3 | import math 4 | import cv2 5 | 6 | 7 | def calculate_ssim_skimage(img1, img2): 8 | """ 9 | 计算 ssim 10 | img1:Tensor 11 | img2:Tensor 12 | """ 13 | img_1 = np.array(img1).astype(np.float64) 14 | img_2 = np.array(img2).astype(np.float64) 15 | ssim_score=[] 16 | for (i,j) in zip(img_1,img_2): 17 | ssim_score.append(structural_similarity(i, j, channel_axis=0,data_range=1)) 18 | return np.mean(ssim_score) 19 | 20 | # 使用 skimage 计算 PSNR 21 | def calculate_psnr_skimage(img1, img2): 22 | """ 23 | calculate psnr in Y channel. 
24 | img1: Tensor 25 | img2: Tensor 26 | """ 27 | img_1 = (np.array(img1).astype(np.float64)*255).astype(np.float64) 28 | img_2 = (np.array(img2).astype(np.float64)*255).astype(np.float64) 29 | img1_y = rgb2ycbcr(img_1.transpose(0,2,3,1)) 30 | img2_y = rgb2ycbcr(img_2.transpose(0,2,3,1)) 31 | return peak_signal_noise_ratio(img1_y, img2_y,data_range=255) 32 | 33 | 34 | def calculate_mse(img1, img2): 35 | test1 = np.array(img1).astype(np.float64)*255 36 | test2 = np.array(img2).astype(np.float64)*255 37 | mse = np.mean((test1-test2)**2) 38 | return mse 39 | 40 | 41 | def calculate_rmse(img1, img2): 42 | test1 = np.array(img1).astype(np.float64)*255 43 | test2 = np.array(img2).astype(np.float64)*255 44 | rmse = np.sqrt(np.mean((test1-test2)**2)) 45 | return rmse 46 | 47 | 48 | def calculate_mae(img1, img2): 49 | test1 = np.array(img1).astype(np.float64)*255 50 | test2 = np.array(img2).astype(np.float64)*255 51 | mae = np.mean(np.abs(test1-test2)) 52 | return mae 53 | 54 | 55 | def calculate_psnr(img1, img2): 56 | img_1 = np.array(img1).astype(np.float64)*255 57 | img_2 = np.array(img2).astype(np.float64)*255 58 | mse = np.mean((img_1 - img_2)**2) 59 | if mse == 0: 60 | return float('inf') 61 | return 20 * math.log10(255.0 / math.sqrt(mse)) 62 | 63 | # HiNet 指标计算函数 64 | 65 | 66 | def calculate_psnr_y(img1, img2): 67 | img_1 = (np.array(img1).astype(np.float64)*255).astype(np.float64) 68 | img_2 = (np.array(img2).astype(np.float64)*255).astype(np.float64) 69 | img1_y = rgb2ycbcr(img_1.transpose(1,2,0)) 70 | img2_y = rgb2ycbcr(img_2.transpose(1,2,0)) 71 | mse = np.mean((img1_y - img2_y)**2) 72 | if mse == 0: 73 | return float('inf') 74 | return 20 * math.log10(255.0 / math.sqrt(mse)) 75 | 76 | 77 | def ssim(img1, img2): 78 | C1 = (0.01 * 255)**2 79 | C2 = (0.03 * 255)**2 80 | img_1 = np.array(img1).astype(np.float64) 81 | img_2 = np.array(img2).astype(np.float64) 82 | kernel = cv2.getGaussianKernel(11, 1.5) 83 | window = np.outer(kernel, kernel.transpose()) 84 | 85 | mu1 
= cv2.filter2D(img_1, -1, window)[5:-5, 5:-5] # valid 86 | mu2 = cv2.filter2D(img_2, -1, window)[5:-5, 5:-5] 87 | mu1_sq = mu1**2 88 | mu2_sq = mu2**2 89 | mu1_mu2 = mu1 * mu2 90 | sigma1_sq = cv2.filter2D(img_1**2, -1, window)[5:-5, 5:-5] - mu1_sq 91 | sigma2_sq = cv2.filter2D(img_2**2, -1, window)[5:-5, 5:-5] - mu2_sq 92 | sigma12 = cv2.filter2D(img_1 * img_2, -1, window)[5:-5, 5:-5] - mu1_mu2 93 | 94 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 95 | (sigma1_sq + sigma2_sq + C2)) 96 | return ssim_map.mean() 97 | 98 | 99 | def calculate_ssim(img1, img2): 100 | '''calculate SSIM 101 | the same outputs as MATLAB's 102 | img1, img2: [0, 255] 103 | ''' 104 | img1_in=np.transpose(img1, (1, 2, 0)) 105 | img2_in=np.transpose(img2, (1, 2, 0)) 106 | if not img1_in.shape == img2_in.shape: 107 | raise ValueError('Input images must have the same dimensions.') 108 | if img1_in.ndim == 2: 109 | return ssim(img1_in, img2_in) 110 | elif img1_in.ndim == 3: 111 | # 多通道 112 | if img1_in.shape[2] == 3: 113 | ssims = [] 114 | for i in range(3): 115 | ssims.append(ssim(img1_in, img2_in)) 116 | return np.array(ssims).mean() 117 | # 单通道 118 | elif img1_in.shape[2] == 1: 119 | return ssim(np.squeeze(img1_in), np.squeeze(img2_in)) 120 | else: 121 | raise ValueError('Wrong input image dimensions.') 122 | 123 | 124 | def rgb2ycbcr(img, only_y=True): 125 | '''same as matlab rgb2ycbcr 126 | only_y: only return Y channel 127 | Input: 128 | uint8, [0, 255] 129 | float, [0, 1] 130 | ''' 131 | in_img_type = img.dtype 132 | img.astype(np.float32) 133 | if in_img_type != np.uint8: 134 | img *= 255. 135 | # convert 136 | if only_y: 137 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 138 | else: 139 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 140 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] 141 | if in_img_type == np.uint8: 142 | rlt = rlt.round() 143 | else: 144 | rlt /= 255. 
145 | return rlt.astype(in_img_type) 146 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import torch 3 | from torch.utils.data import Dataset, DataLoader 4 | import torchvision.transforms as T 5 | from natsort import natsorted 6 | from PIL import Image 7 | import albumentations as A 8 | import cv2 9 | from albumentations.pytorch import ToTensorV2 10 | import config 11 | args = config.Args() 12 | 13 | # 对数据集图像进行处理 14 | transform = T.Compose([ 15 | T.RandomCrop(128), 16 | T.RandomHorizontalFlip(), 17 | T.ToTensor() 18 | ]) 19 | 20 | # 使用 albumentations 库对图像进行处理 21 | transform_A = A.Compose([ 22 | A.RandomCrop(width=256, height=256), 23 | A.RandomRotate90(), 24 | A.HorizontalFlip(), 25 | A.augmentations.transforms.ChannelShuffle(0.3), 26 | ToTensorV2() 27 | ]) 28 | 29 | transform_A_valid = A.Compose([ 30 | A.CenterCrop(width=256, height=256), 31 | ToTensorV2() 32 | ]) 33 | 34 | transform_A_test = A.Compose([ 35 | A.CenterCrop(width=1024, height=1024), 36 | ToTensorV2() 37 | ]) 38 | 39 | transform_A_test_256 = A.Compose([ 40 | A.PadIfNeeded(min_width=256,min_height=256), 41 | A.CenterCrop(width=256, height=256), 42 | ToTensorV2() 43 | ]) 44 | 45 | DIV2K_path = "/home/whq135/dataset" 46 | Flickr2K_path = "/home/whq135/dataset/Flickr2K" 47 | 48 | batchsize = 12 49 | 50 | # dataset 51 | 52 | 53 | class DIV2K_Dataset(Dataset): 54 | def __init__(self, transforms_=None, mode='train'): 55 | self.transform = transforms_ 56 | self.mode = mode 57 | if mode == 'train': 58 | self.files = natsorted( 59 | sorted(glob.glob(DIV2K_path+"/DIV2K_train_HR"+"/*."+"png"))) 60 | else: 61 | self.files = natsorted( 62 | sorted(glob.glob(DIV2K_path+"/DIV2K_valid_HR"+"/*."+"png"))) 63 | 64 | def __getitem__(self, index): 65 | img = cv2.imread(self.files[index]) 66 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 转换为 RGB 67 | trans_img = 
self.transform(image=img) 68 | item = trans_img['image'] 69 | item = item/255.0 70 | return item 71 | 72 | def __len__(self): 73 | return len(self.files) 74 | 75 | class Flickr2K_Dataset(Dataset): 76 | def __init__(self, transforms_=None, mode='train'): 77 | self.transform = transforms_ 78 | self.mode = mode 79 | self.files = natsorted( 80 | sorted(glob.glob(Flickr2K_path+"/Flickr2K_HR"+"/*."+"png"))) 81 | 82 | def __getitem__(self, index): 83 | img = cv2.imread(self.files[index]) 84 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 转换为 RGB 85 | trans_img = self.transform(image=img) 86 | item= trans_img['image'] 87 | item=item/255.0 88 | return item 89 | 90 | def __len__(self): 91 | return len(self.files) 92 | 93 | class COCO_Test_Dataset(Dataset): 94 | def __init__(self, transforms_=None): 95 | self.transform = transforms_ 96 | self.files = natsorted( 97 | sorted(glob.glob("/home/whq135/dataset/COCO2017/test2017"+"/*."+"jpg"))) 98 | 99 | def __getitem__(self, index): 100 | img = cv2.imread(self.files[index]) 101 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 转换为 RGB 102 | trans_img = self.transform(image=img) 103 | item= trans_img['image'] 104 | item=item/255.0 105 | return item 106 | 107 | def __len__(self): 108 | return len(self.files) 109 | 110 | # dataloader 111 | DIV2K_train_cover_loader = DataLoader( 112 | DIV2K_Dataset(transforms_=transform_A, mode="train"), 113 | batch_size=args.single_batch_size, 114 | shuffle=True, 115 | pin_memory=True, 116 | num_workers=8, 117 | drop_last=True 118 | ) 119 | 120 | DIV2K_train_secret_loader = DataLoader( 121 | DIV2K_Dataset(transforms_=transform_A, mode="train"), 122 | batch_size=args.single_batch_size, 123 | shuffle=True, 124 | pin_memory=True, 125 | num_workers=8, 126 | drop_last=True 127 | ) 128 | 129 | DIV2K_val_cover_loader = DataLoader( 130 | DIV2K_Dataset(transforms_=transform_A_valid, mode="val"), 131 | batch_size=args.single_batch_size, 132 | shuffle=True, 133 | pin_memory=True, 134 | num_workers=2, 135 | 
drop_last=True 136 | ) 137 | 138 | DIV2K_val_secret_loader = DataLoader( 139 | DIV2K_Dataset(transforms_=transform_A_valid, mode="val"), 140 | batch_size=args.single_batch_size, 141 | shuffle=False, 142 | pin_memory=True, 143 | num_workers=2, 144 | drop_last=True 145 | ) 146 | 147 | DIV2K_test_cover_loader = DataLoader( 148 | DIV2K_Dataset(transforms_=transform_A_test, mode="val"), 149 | batch_size=1, 150 | shuffle=True, 151 | pin_memory=True, 152 | num_workers=1, 153 | drop_last=True 154 | ) 155 | 156 | DIV2K_test_secret_loader = DataLoader( 157 | DIV2K_Dataset(transforms_=transform_A_test, mode="val"), 158 | batch_size=1, 159 | shuffle=False, 160 | pin_memory=True, 161 | num_workers=1, 162 | drop_last=True 163 | ) 164 | 165 | DIV2K_multi_train_loader = DataLoader( 166 | DIV2K_Dataset(transforms_=transform_A, mode="train"), 167 | batch_size=args.multi_batch_iteration, 168 | shuffle=True, 169 | pin_memory=True, 170 | num_workers=16, 171 | drop_last=True 172 | ) 173 | 174 | DIV2K_multi_val_loader = DataLoader( 175 | DIV2K_Dataset(transforms_=transform_A_valid, mode="val"), 176 | batch_size=args.multi_batch_iteration, 177 | shuffle=True, 178 | pin_memory=True, 179 | num_workers=16, 180 | drop_last=True 181 | ) 182 | 183 | DIV2K_multi_test_loader = DataLoader( 184 | DIV2K_Dataset(transforms_=transform_A_test, mode="val"), 185 | batch_size=args.test_multi_batch_size, 186 | shuffle=True, 187 | pin_memory=True, 188 | num_workers=1, 189 | drop_last=True 190 | ) 191 | 192 | COCO_test_multi_loader = DataLoader( 193 | COCO_Test_Dataset(transforms_=transform_A_test_256), 194 | batch_size=args.test_multi_batch_size, 195 | shuffle=True, 196 | pin_memory=True, 197 | num_workers=1, 198 | drop_last=True 199 | ) 200 | 201 | COCO_test_cover_loader = DataLoader( 202 | COCO_Test_Dataset(transforms_=transform_A_test_256), 203 | batch_size=1, 204 | shuffle=True, 205 | pin_memory=True, 206 | num_workers=1, 207 | drop_last=True 208 | ) 209 | 210 | COCO_test_secret_loader = DataLoader( 211 
| COCO_Test_Dataset(transforms_=transform_A_test_256), 212 | batch_size=1, 213 | shuffle=True, 214 | pin_memory=False, 215 | num_workers=1, 216 | drop_last=True 217 | ) 218 | 219 | Flickr2K_multi_train_loader = DataLoader( 220 | Flickr2K_Dataset(transforms_=transform_A), 221 | batch_size=40, 222 | shuffle=True, 223 | pin_memory=True, 224 | num_workers=16, 225 | drop_last=True 226 | ) 227 | 228 | Flickr2K_train_cover_loader = DataLoader( 229 | Flickr2K_Dataset(transforms_=transform_A), 230 | batch_size=batchsize, 231 | shuffle=True, 232 | pin_memory=True, 233 | num_workers=2, 234 | drop_last=True 235 | ) 236 | 237 | Flickr2K_train_secret_loader = DataLoader( 238 | Flickr2K_Dataset(transforms_=transform_A), 239 | batch_size=batchsize, 240 | shuffle=True, 241 | pin_memory=True, 242 | num_workers=2, 243 | drop_last=True 244 | ) 245 | -------------------------------------------------------------------------------- /test_multiple_image_hiding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim 4 | import torchvision 5 | import math 6 | import numpy as np 7 | from critic import * 8 | from tensorboardX import SummaryWriter 9 | from tqdm import tqdm 10 | from thop import profile 11 | import torch.nn.functional as F 12 | import os 13 | import timm 14 | import timm.scheduler 15 | from model import StegFormer 16 | from datasets import * 17 | from einops import rearrange 18 | import config 19 | args = config.Args() 20 | 21 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 22 | 23 | # 模型初始化 24 | if args.use_model == 'StegFormer-S': 25 | encoder = StegFormer(img_resolution=args.image_size_train, input_dim=(args.num_secret+1)*3, cnn_emb_dim=8, output_dim=3, 26 | drop_key=False, patch_size=2, window_size=8, output_act=args.output_act, depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2]) 27 | decoder = 
StegFormer(img_resolution=args.image_size_train, input_dim=3, cnn_emb_dim=8, output_dim=args.num_secret*3, 28 | drop_key=False, patch_size=2, window_size=8, output_act=args.output_act, depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2]) 29 | if args.use_model == 'StegFormer-B': 30 | encoder = StegFormer(img_resolution=args.image_size_train, input_dim=(args.num_secret+1)*3, cnn_emb_dim=16, output_dim=3, 31 | drop_key=False, patch_size=2, window_size=8, output_act=args.output_act, depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2]) 32 | decoder = StegFormer(img_resolution=args.image_size_train, input_dim=3, cnn_emb_dim=16, output_dim=args.num_secret*3, 33 | drop_key=False, patch_size=2, window_size=8, output_act=args.output_act, depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2]) 34 | if args.use_model == 'StegFormer-L': 35 | encoder = StegFormer(img_resolution=args.image_size_train, input_dim=(args.num_secret+1)*3, cnn_emb_dim=32, output_dim=3, depth=[2, 2, 2, 2, 2, 2, 2, 2, 2]) 36 | decoder = StegFormer(img_resolution=args.image_size_train, input_dim=3, cnn_emb_dim=32, output_dim=args.num_secret*3, depth=[2, 2, 2, 2, 2, 2, 2, 2, 2]) 37 | 38 | # 加载模型 39 | save_path = args.path+'/checkpoint/'+args.model_name 40 | model_path = f'{save_path}/{args.model_name}.pt' 41 | state_dicts = torch.load(model_path) 42 | encoder.load_state_dict(state_dicts['encoder'], strict=False) 43 | decoder.load_state_dict(state_dicts['decoder'], strict=False) 44 | 45 | encoder.to(args.device) 46 | decoder.to(args.device) 47 | 48 | 49 | # 计算模型参数量 50 | with torch.no_grad(): 51 | test_encoder_input = torch.randn(1, (args.num_secret+1)*3, 256, 256).to(args.device) 52 | test_decoder_input = torch.randn(1, 3, 256, 256).to(args.device) 53 | encoder_mac, encoder_params = profile(encoder, inputs=(test_encoder_input,)) 54 | 55 | decoder_mac, decoder_params = profile(decoder, inputs=(test_decoder_input,)) 56 | print("thop result:encoder 
FLOPs="+str(encoder_mac*2)+",encoder params="+str(encoder_params)) 57 | print("thop result:decoder FLOPs="+str(decoder_mac*2)+",decoder params="+str(decoder_params)) 58 | i = 0 # 为每一张图编号 59 | # 评价指标 60 | psnr_secret = [] 61 | psnr_cover = [] 62 | psnr_secret_y = [] 63 | psnr_seret1=[] 64 | psnr_seret2=[] 65 | psnr_seret3=[] 66 | psnr_seret4=[] 67 | 68 | ssim_seret1=[] 69 | ssim_seret2=[] 70 | ssim_seret3=[] 71 | ssim_seret4=[] 72 | 73 | psnr_cover_y = [] 74 | ssim_secret = [] 75 | ssim_cover = [] 76 | mse_cover = [] 77 | mse_secret = [] 78 | rmse_cover = [] 79 | rmse_secret = [] 80 | mae_cover = [] 81 | mae_secret = [] 82 | 83 | # without clamp 84 | for j in range(1): 85 | # test 1,000 images 86 | with torch.no_grad(): 87 | # val 88 | encoder.eval() 89 | decoder.eval() 90 | 91 | # 在验证集上测试 92 | for i_batch, img in enumerate(COCO_test_multi_loader): 93 | img = img.to(args.device) 94 | cover = img[0:1, :, :, :] 95 | secret = img[1:, :, :, :] 96 | secret_cat = rearrange(secret, '(b n) c h w -> b (n c) h w', n=args.num_secret) # 通道级联 97 | 98 | # encode 99 | msg = torch.cat([cover, secret_cat], 1) 100 | encode_img = encoder(msg) # 添加残差连接 101 | encode_img_c = torch.clamp(encode_img, 0, 1) 102 | 103 | # decode 104 | decode_img = decoder(encode_img_c) 105 | decode_img = rearrange(decode_img, 'b (n c) h w -> (b n) c h w', n=args.num_secret) 106 | 107 | # 限制为图像表示 108 | decode_img = decode_img.clamp(0, 1) 109 | encode_img = encode_img.clamp(0, 1) 110 | 111 | cover_dif=(cover-encode_img)*10 112 | secret_dif=(secret-decode_img)*10 113 | 114 | # 计算各种指标 115 | # 拷贝进内存以方便计算 116 | cover = cover.cpu() 117 | secret = secret.cpu() 118 | encode_img = encode_img.cpu() 119 | decode_img = decode_img.cpu() 120 | 121 | # 移除第一个通道 122 | decode_img1 = decode_img[:1, :, :, :] 123 | decode_img2 = decode_img[:2, :, :, :] 124 | decode_img3 = decode_img[:3, :, :, :] 125 | decode_img4 = decode_img[:4, :, :, :] 126 | 127 | # 计算 Y 通道 PSNR 128 | psnry_encode_temp = calculate_psnr_skimage(cover, 
encode_img) 129 | psnry_decode_temp1 = calculate_psnr_skimage(secret[:1, :, :, :], decode_img1) 130 | psnry_decode_temp2 = calculate_psnr_skimage(secret[:2, :, :, :], decode_img2) 131 | psnry_decode_temp3 = calculate_psnr_skimage(secret[:3,:,:,:], decode_img3) 132 | psnry_decode_temp4 = calculate_psnr_skimage(secret[:4,:,:,:], decode_img4) 133 | psnry_decode_temp = (psnry_decode_temp1+psnry_decode_temp2+psnry_decode_temp3+psnry_decode_temp4)/4.0 134 | psnr_seret1.append(psnry_decode_temp1) 135 | psnr_seret2.append(psnry_decode_temp2) 136 | psnr_seret3.append(psnry_decode_temp3) 137 | psnr_seret4.append(psnry_decode_temp4) 138 | psnr_cover_y.append(psnry_encode_temp) 139 | psnr_secret_y.append(psnry_decode_temp) 140 | 141 | # 计算 ssim 142 | ssim_encode_temp = calculate_ssim_skimage(cover, encode_img) 143 | ssim_decode_temp1 = calculate_ssim_skimage(secret[:1, :, :, :], decode_img1) 144 | ssim_decode_temp2 = calculate_ssim_skimage(secret[:2, :, :, :], decode_img2) 145 | ssim_decode_temp3 = calculate_ssim_skimage(secret[:3,:,:,:], decode_img3) 146 | ssim_decode_temp4 = calculate_ssim_skimage(secret[:4,:,:,:], decode_img4) 147 | ssim_decode_temp = (ssim_decode_temp1+ssim_decode_temp2+ssim_decode_temp3+ssim_decode_temp4)/4.0 148 | ssim_seret1.append(ssim_decode_temp1) 149 | ssim_seret2.append(ssim_decode_temp2) 150 | ssim_seret3.append(ssim_decode_temp3) 151 | ssim_seret4.append(ssim_decode_temp4) 152 | ssim_cover.append(ssim_encode_temp) 153 | ssim_secret.append(ssim_decode_temp) 154 | 155 | i += 1 # 下一张图像 156 | print("img "+str(i)+" :") 157 | print("PSNR_Y_cover:" + str(np.mean(psnr_cover_y)) + " PSNR_Y_secret:" + str(np.mean(psnr_secret_y))) 158 | print("SSIM_cover:" + str(np.mean(ssim_cover)) + " SSIM_secret:" + str(np.mean(ssim_secret))) 159 | 160 | print("clamp total result:") 161 | print("PSNR_Y_cover:" + str(np.mean(psnr_cover_y)) + " PSNR_Y_secret:" + str(np.mean(psnr_secret_y))) 162 | print("SSIM_cover:" + str(np.mean(ssim_cover)) + " SSIM_secret:" + 
str(np.mean(ssim_secret))) 163 | print(f'secret1: {np.mean(psnr_seret1)} {np.mean(ssim_seret1)} secret2: {np.mean(psnr_seret2)} {np.mean(ssim_seret2)} secret3: {np.mean(psnr_seret3)} {np.mean(ssim_seret3)} secret4: {np.mean(psnr_seret4)} {np.mean(ssim_seret4)}') 164 | -------------------------------------------------------------------------------- /train_StegFormer_single_image.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim 4 | import math 5 | import numpy as np 6 | from critic import * 7 | from tensorboardX import SummaryWriter 8 | from thop import profile 9 | from datasets import * 10 | from model import StegFormer 11 | import os 12 | import timm 13 | import timm.scheduler 14 | import config 15 | args = config.Args() 16 | 17 | # 设置随机种子 18 | seed = 42 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | 22 | # loss function 23 | class L1_Charbonnier_loss(torch.nn.Module): 24 | """L1 Charbonnierloss.""" 25 | 26 | def __init__(self): 27 | super(L1_Charbonnier_loss, self).__init__() 28 | self.eps = 1e-6 29 | 30 | def forward(self, X, Y): 31 | diff = torch.add(X, -Y) 32 | error = torch.sqrt(diff * diff + self.eps) 33 | loss = torch.mean(error) 34 | return loss 35 | 36 | class Restrict_Loss(nn.Module): 37 | """Restrict loss using L2 loss function""" 38 | 39 | def __init__(self): 40 | super().__init__() 41 | self.eps = 1e-6 42 | 43 | def forward(self, X): 44 | count1 = torch.sum(X > 1) 45 | count0 = torch.sum(X < 0) 46 | if count1 == 0: 47 | count1 = 1 48 | if count0 == 0: 49 | count0 = 1 50 | one = torch.ones_like(X) 51 | zero = torch.zeros_like(X) 52 | X_one = torch.where(X <= 1, 1, X) # 对超过 1 的值施加惩罚 53 | X_zero = torch.where(X >= 0, 0, X) # 对小于 0 的值施加惩罚 54 | diff_one = X_one-one 55 | diff_zero = zero-X_zero 56 | loss = torch.sum(0.5*(diff_one**2))/count1 + torch.sum(0.5*(diff_zero**2))/count0 57 | return loss 58 | 59 | # 新建文件夹 60 | 
model_version_name = args.model_name
save_path = args.path + '/checkpoint/' + model_version_name  # folder named after the model version
if not os.path.exists(save_path):
    os.makedirs(save_path)

# tensorboard
writer = SummaryWriter(f'{args.path}/tensorboard_log/{args.model_name}/')

# StegFormer initiate
if args.use_model == 'StegFormer-S':
    encoder = StegFormer(img_resolution=args.image_size_train, input_dim=(args.num_secret+1)*3, cnn_emb_dim=8, output_dim=3,
                         drop_key=False, patch_size=2, window_size=8, output_act=args.output_act, depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2])
    # FIX: the decoder must recover every secret image, so its output width is
    # num_secret*3 — this matches the -B and -L branches below and is identical
    # to the old hard-coded 3 when num_secret == 1.
    decoder = StegFormer(img_resolution=args.image_size_train, input_dim=3, cnn_emb_dim=8, output_dim=args.num_secret*3,
                         drop_key=False, patch_size=2, window_size=8, output_act=args.output_act, depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2])
if args.use_model == 'StegFormer-B':
    encoder = StegFormer(img_resolution=args.image_size_train, input_dim=(args.num_secret+1)*3, cnn_emb_dim=16, output_dim=3)
    decoder = StegFormer(img_resolution=args.image_size_train, input_dim=3, cnn_emb_dim=16, output_dim=args.num_secret*3)
if args.use_model == 'StegFormer-L':
    encoder = StegFormer(img_resolution=args.image_size_train, input_dim=(args.num_secret+1)*3, cnn_emb_dim=32, output_dim=3, depth=[2, 2, 2, 2, 2, 2, 2, 2, 2])
    decoder = StegFormer(img_resolution=args.image_size_train, input_dim=3, cnn_emb_dim=32, output_dim=args.num_secret*3, depth=[2, 2, 2, 2, 2, 2, 2, 2, 2])
encoder.cuda()
decoder.cuda()

# resume from a checkpoint when train_next is set
if args.train_next != 0:
    model_path = save_path + '/model_checkpoint_%.5i' % args.train_next + '.pt'
    state_dicts = torch.load(model_path)
    encoder.load_state_dict(state_dicts['encoder'], strict=False)
    decoder.load_state_dict(state_dicts['decoder'], strict=False)


# optimizer and the learning-rate scheduler (cosine with linear warm-up)
optim = torch.optim.AdamW([{'params': encoder.parameters()}, {'params': decoder.parameters()}], lr=args.lr)
if args.train_next != 0:
    optim.load_state_dict(state_dicts['opt'])
scheduler = timm.scheduler.CosineLRScheduler(optimizer=optim,
                                             t_initial=args.epochs,
                                             lr_min=0,
                                             warmup_t=args.warm_up_epoch,
                                             warmup_lr_init=args.warm_up_lr_init)

# parameter and FLOP count (thop reports MACs; x2 gives FLOPs)
with torch.no_grad():
    test_encoder_input = torch.randn(1, (args.num_secret+1)*3, args.image_size_train, args.image_size_train).to(args.device)
    test_decoder_input = torch.randn(1, 3, args.image_size_train, args.image_size_train).to(args.device)
    encoder_mac, encoder_params = profile(encoder, inputs=(test_encoder_input,))
    decoder_mac, decoder_params = profile(decoder, inputs=(test_decoder_input,))
    print("thop result:encoder FLOPs="+str(encoder_mac*2)+",encoder params="+str(encoder_params))
    print("thop result:decoder FLOPs="+str(decoder_mac*2)+",decoder params="+str(decoder_params))

# loss functions
conceal_loss_function = L1_Charbonnier_loss().to(args.device)
reveal_loss_function = L1_Charbonnier_loss().to(args.device)
restrict_loss_funtion = Restrict_Loss().to(args.device)

# train
for i_epoch in range(args.epochs):
    sum_loss = []
    scheduler.step(i_epoch + args.train_next)
    for i_batch, (cover, secret) in enumerate(zip(DIV2K_train_cover_loader, DIV2K_train_secret_loader)):
        cover = cover.to(args.device)
        secret = secret.to(args.device)

        # encode: hide the secret inside the cover
        msg = torch.cat([cover, secret], 1)
        encode_img = encoder(msg)

        # normalizing
        if args.norm_train == 'clamp':
            encode_img_c = torch.clamp(encode_img, 0, 1)
        else:
            encode_img_c = encode_img

        # decode
        decode_img = decoder(encode_img_c)

        # loss
        conceal_loss = conceal_loss_function(cover.cuda(), encode_img.cuda())
        reveal_loss = reveal_loss_function(secret.cuda(), decode_img.cuda())

        if args.norm_train:
            # penalise un-clamped stego values outside [0, 1]
            restrict_loss = restrict_loss_funtion(encode_img.cuda())
            total_loss = conceal_loss + reveal_loss + restrict_loss
        else:
            total_loss = conceal_loss + reveal_loss
        sum_loss.append(total_loss.item())

        # backward
        total_loss.backward()
        optim.step()
        optim.zero_grad()

    # validation
    if i_epoch % args.val_freq == 0:
        print("validation begin:")
        with torch.no_grad():
            encoder.eval()
            decoder.eval()

            # psnr and ssim accumulators
            psnr_secret = []
            psnr_cover = []
            ssim_secret = []
            ssim_cover = []

            # evaluate on the validation set
            for (cover, secret) in zip(DIV2K_val_cover_loader, DIV2K_val_secret_loader):
                cover = cover.to(args.device)
                secret = secret.to(args.device)

                # encode
                msg = torch.cat([cover, secret], 1)
                encode_img = encoder(msg)

                if args.norm_train:
                    encode_img = torch.clamp(encode_img, 0, 1)

                # decode
                decode_img = decoder(encode_img)
                decode_img = torch.clamp(decode_img, 0, 1)
                encode_img = torch.clamp(encode_img, 0, 1)

                cover = cover.cpu()
                secret = secret.cpu()
                encode_img = encode_img.cpu()
                decode_img = decode_img.cpu()

                psnr_encode_temp = calculate_psnr(cover, encode_img)
                psnr_decode_temp = calculate_psnr(secret, decode_img)
                psnr_cover.append(psnr_encode_temp)
                psnr_secret.append(psnr_decode_temp)

                ssim_encode_temp = calculate_ssim_skimage(cover, encode_img)
                ssim_decode_temp = calculate_ssim_skimage(secret, decode_img)
                ssim_cover.append(ssim_encode_temp)
                ssim_secret.append(ssim_decode_temp)
            # log the last validation batch's images
            writer.add_images('image/encode_img', encode_img, dataformats='NCHW', global_step=i_epoch+args.train_next)
            writer.add_images('image/decode_img', decode_img, dataformats='NCHW', global_step=i_epoch+args.train_next)
writer.add_scalar("PSNR/PSNR_cover", np.mean(psnr_cover), i_epoch+args.train_next) 201 | writer.add_scalar("PSNR/PSNR_secret", np.mean(psnr_secret), i_epoch+args.train_next) 202 | writer.add_scalar("SSIM/SSIM_cover", np.mean(ssim_cover), i_epoch+args.train_next) 203 | writer.add_scalar("SSIM/SSIM_secret", np.mean(ssim_secret), i_epoch+args.train_next) 204 | print("PSNR_cover:" + str(np.mean(psnr_cover)) + " PSNR_secret:" + str(np.mean(psnr_secret))) 205 | print("SSIM_cover:" + str(np.mean(ssim_cover)) + " SSIM_secret:" + str(np.mean(ssim_secret))) 206 | 207 | print("epoch:"+str(i_epoch+args.train_next) + ":" + str(np.mean(sum_loss))) 208 | if i_epoch % 2 == 0: 209 | writer.add_scalar("loss", np.mean(sum_loss), i_epoch+args.train_next) 210 | 211 | # 保存当前模型以及优化器参数 212 | if (i_epoch % args.save_freq) == 0: 213 | torch.save({'opt': optim.state_dict(), 214 | 'encoder': encoder.state_dict(), 215 | 'decoder': decoder.state_dict()}, save_path + '/model_checkpoint_%.5i' % (i_epoch+args.train_next)+'.pt') 216 | 217 | 218 | torch.save({'opt': optim.state_dict(), 219 | 'encoder': encoder.state_dict(), 220 | 'decoder': decoder.state_dict()}, f'{save_path}/{model_version_name}.pt') 221 | -------------------------------------------------------------------------------- /test_save_single_image_hiding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim 4 | import torchvision 5 | import numpy as np 6 | from critic import * 7 | from thop import profile 8 | import os 9 | from model import StegFormer 10 | from datasets import * 11 | import config 12 | args = config.Args() 13 | 14 | # initialization 15 | if args.use_model == 'StegFormer-S': 16 | encoder = StegFormer(1024, input_dim=(args.num_secret+1)*3, cnn_emb_dim=8, output_dim=3) 17 | decoder = StegFormer(1024, input_dim=3, cnn_emb_dim=8, output_dim=args.num_secret*3) 18 | if args.use_model == 'StegFormer-B': 19 | encoder = 
StegFormer(1024, input_dim=(args.num_secret+1)*3, cnn_emb_dim=16, output_dim=3, 20 | drop_key=False, patch_size=2, window_size=8, output_act=None,depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2]) 21 | decoder = StegFormer(1024, input_dim=3, cnn_emb_dim=16, output_dim=args.num_secret*3, 22 | drop_key=False, patch_size=2, window_size=8, output_act=None,depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2]) 23 | if args.use_model == 'StegFormer-L': 24 | encoder = StegFormer(1024, input_dim=(args.num_secret+1)*3, cnn_emb_dim=32, output_dim=3, depth=[2, 2, 2, 2, 2, 2, 2, 2, 2]) 25 | decoder = StegFormer(1024, input_dim=3, cnn_emb_dim=32, output_dim=args.num_secret*3, depth=[2, 2, 2, 2, 2, 2, 2, 2, 2]) 26 | 27 | # 加载模型 28 | save_path = args.path+'/checkpoint/'+args.model_name 29 | model_path = f'{save_path}/{args.model_name}.pt' 30 | state_dicts = torch.load(model_path) 31 | encoder.load_state_dict(state_dicts['encoder'], strict=False) 32 | decoder.load_state_dict(state_dicts['decoder'], strict=False) 33 | 34 | encoder.to(args.device) 35 | decoder.to(args.device) 36 | 37 | # 计算模型参数量 38 | with torch.no_grad(): 39 | test_encoder_input = torch.randn(1, 6, 1024, 1024).to(args.device) 40 | test_decoder_input = torch.randn(1, 3, 1024, 1024).to(args.device) 41 | encoder_mac, encoder_params = profile(encoder, inputs=(test_encoder_input,)) 42 | 43 | decoder_mac, decoder_params = profile(decoder, inputs=(test_decoder_input,)) 44 | print("thop result:encoder FLOPs="+str(encoder_mac*2)+",encoder params="+str(encoder_params)) 45 | print("thop result:decoder FLOPs="+str(decoder_mac*2)+",decoder params="+str(decoder_params)) 46 | 47 | i = 0 # 为每一张图编号 48 | # 评价指标 49 | psnr_secret = [] 50 | psnr_cover = [] 51 | psnr_secret_y = [] 52 | psnr_cover_y = [] 53 | ssim_secret = [] 54 | ssim_cover = [] 55 | mse_cover = [] 56 | mse_secret = [] 57 | rmse_cover = [] 58 | rmse_secret = [] 59 | mae_cover = [] 60 | mae_secret = [] 61 | 62 | # without clamp 63 | for j 
in range(1): 64 | with torch.no_grad(): 65 | # val 66 | encoder.eval() 67 | decoder.eval() 68 | 69 | # 在验证集上测试 70 | for (cover, secret) in zip(DIV2K_test_cover_loader, DIV2K_test_secret_loader): 71 | cover = cover.to(args.device) 72 | secret = secret.to(args.device) 73 | 74 | # encode 75 | msg = torch.cat([cover, secret], 1) 76 | encode_img = encoder(msg) # 添加残差连接 77 | 78 | # decode 79 | decode_img = decoder(encode_img) 80 | 81 | # 限制为图像表示 82 | decode_img = decode_img.clamp(0, 1) 83 | encode_img = encode_img.clamp(0, 1) 84 | 85 | # 计算各种指标 86 | # 拷贝进内存以方便计算 87 | cover = cover.cpu() 88 | secret = secret.cpu() 89 | encode_img = encode_img.cpu() 90 | decode_img = decode_img.cpu() 91 | 92 | # 计算 Y 通道 PSNR 93 | psnry_encode_temp = calculate_psnr_skimage(cover, encode_img) 94 | psnry_decode_temp = calculate_psnr_skimage(secret, decode_img) 95 | psnr_cover_y.append(psnry_encode_temp) 96 | psnr_secret_y.append(psnry_decode_temp) 97 | 98 | # 计算 SSIM 99 | ssim_encode=calculate_ssim_skimage(cover,encode_img) 100 | ssim_decode=calculate_ssim_skimage(secret,decode_img) 101 | ssim_cover.append(ssim_encode) 102 | ssim_secret.append(ssim_decode) 103 | 104 | # 计算 RMSE 105 | rmse_cover_temp = calculate_rmse(cover, encode_img) 106 | rmse_secret_temp = calculate_rmse(secret, decode_img) 107 | rmse_cover.append(rmse_cover_temp) 108 | rmse_secret.append(rmse_secret_temp) 109 | 110 | # 计算 MAE 111 | mae_cover_temp = calculate_mae(cover, encode_img) 112 | mae_secret_temp = calculate_mae(secret, decode_img) 113 | mae_cover.append(mae_cover_temp) 114 | mae_secret.append(mae_secret_temp) 115 | 116 | # 保存图像 117 | torchvision.utils.save_image(cover, args.path + '/image/wo_clamp/cover/' + '%.5d.png' % i) 118 | torchvision.utils.save_image(secret, args.path + '/image/wo_clamp/secret/' + '%.5d.png' % i) 119 | torchvision.utils.save_image(encode_img, args.path + '/image/wo_clamp/stego/' + '%.5d.png' % i) 120 | torchvision.utils.save_image(decode_img, args.path + '/image/wo_clamp/secret_rev/' + 
'%.5d.png' % i) 121 | i += 1 # 下一张图像 122 | print("img "+str(i)+" :") 123 | print("PSNR_Y_cover:" + str(np.mean(psnry_encode_temp)) + " PSNR_Y_secret:" + str(np.mean(psnry_decode_temp))) 124 | print("SSIM_cover:" + str(np.mean(ssim_cover)) + " SSIM_secret:" + str(np.mean(ssim_secret))) 125 | print("RMSE_cover:" + str(np.mean(rmse_cover_temp)) + " RMSE_secret:" + str(np.mean(rmse_secret_temp))) 126 | print("MAE_cover:" + str(np.mean(mae_cover_temp)) + " MAE_secret:" + str(np.mean(mae_secret_temp))) 127 | 128 | print("wo_clamp total result:") 129 | print("PSNR_Y_cover:" + str(np.mean(psnr_cover_y)) + " PSNR_Y_secret:" + str(np.mean(psnr_secret_y))) 130 | print("SSIM_cover:" + str(np.mean(ssim_cover)) + " SSIM_secret:" + str(np.mean(ssim_secret))) 131 | print("MAE_cover:" + str(np.mean(mae_cover)) + " MSE_secret:" + str(np.mean(mae_secret))) 132 | print("RMSE_cover:" + str(np.mean(rmse_cover)) + " RMSE_secret:" + str(np.mean(rmse_secret))) 133 | 134 | # 计算 clamp 的指标 135 | i = 0 # 为每一张图编号 136 | # 评价指标 137 | psnr_secret = [] 138 | psnr_cover = [] 139 | psnr_secret_y = [] 140 | psnr_cover_y = [] 141 | ssim_secret = [] 142 | ssim_cover = [] 143 | mse_cover = [] 144 | mse_secret = [] 145 | rmse_cover = [] 146 | rmse_secret = [] 147 | mae_cover = [] 148 | mae_secret = [] 149 | 150 | # with clamp 151 | for j in range(1): 152 | # test 1,000 images 153 | with torch.no_grad(): 154 | # val 155 | encoder.eval() 156 | decoder.eval() 157 | 158 | # 在验证集上测试 159 | for (cover, secret) in zip(DIV2K_test_cover_loader, DIV2K_test_secret_loader): 160 | cover = cover.to(args.device) 161 | secret = secret.to(args.device) 162 | 163 | # encode 164 | msg = torch.cat([cover, secret], 1) 165 | encode_img = encoder(msg) # 添加残差连接 166 | encode_img=torch.clamp(encode_img,0,1) 167 | 168 | # decode 169 | decode_img = decoder(encode_img) 170 | 171 | # 限制为图像表示 172 | decode_img = decode_img.clamp(0, 1) 173 | encode_img = encode_img.clamp(0, 1) 174 | 175 | # 计算各种指标 176 | # 拷贝进内存以方便计算 177 | cover = 
cover.cpu() 178 | secret = secret.cpu() 179 | encode_img = encode_img.cpu() 180 | decode_img = decode_img.cpu() 181 | 182 | # 计算 Y 通道 PSNR 183 | psnry_encode_temp = calculate_psnr_skimage(cover, encode_img) 184 | psnry_decode_temp = calculate_psnr_skimage(secret, decode_img) 185 | psnr_cover_y.append(psnry_encode_temp) 186 | psnr_secret_y.append(psnry_decode_temp) 187 | 188 | # 计算 SSIM 189 | ssim_encode=calculate_ssim_skimage(cover,encode_img) 190 | ssim_decode=calculate_ssim_skimage(secret,decode_img) 191 | ssim_cover.append(ssim_encode) 192 | ssim_secret.append(ssim_decode) 193 | 194 | # 计算 RMSE 195 | rmse_cover_temp = calculate_rmse(cover, encode_img) 196 | rmse_secret_temp = calculate_rmse(secret, decode_img) 197 | rmse_cover.append(rmse_cover_temp) 198 | rmse_secret.append(rmse_secret_temp) 199 | 200 | # 计算 MAE 201 | mae_cover_temp = calculate_mae(cover, encode_img) 202 | mae_secret_temp = calculate_mae(secret, decode_img) 203 | mae_cover.append(mae_cover_temp) 204 | mae_secret.append(mae_secret_temp) 205 | 206 | # 保存图像 207 | torchvision.utils.save_image(cover, args.path + '/image/clamp/cover/' + '%.5d.png' % i) 208 | torchvision.utils.save_image(secret, args.path + '/image/clamp/secret/' + '%.5d.png' % i) 209 | torchvision.utils.save_image(encode_img, args.path + '/image/clamp/stego/' + '%.5d.png' % i) 210 | torchvision.utils.save_image(decode_img, args.path + '/image/clamp/secret_rev/' + '%.5d.png' % i) 211 | i += 1 # 下一张图像 212 | # print("img "+str(i)+" :") 213 | # print("PSNR_Y_cover:" + str(np.mean(psnry_encode_temp)) + " PSNR_Y_secret:" + str(np.mean(psnry_decode_temp))) 214 | # print("SSIM_cover:" + str(np.mean(ssim_cover)) + " SSIM_secret:" + str(np.mean(ssim_secret))) 215 | # print("RMSE_cover:" + str(np.mean(rmse_cover_temp)) + " RMSE_secret:" + str(np.mean(rmse_secret_temp))) 216 | # print("MAE_cover:" + str(np.mean(mae_cover_temp)) + " MAE_secret:" + str(np.mean(mae_secret_temp))) 217 | 218 | print("clamp total result:") 219 | print("PSNR_Y_cover:" 
print("SSIM_cover:" + str(np.mean(ssim_cover)) + " SSIM_secret:" + str(np.mean(ssim_secret)))
print("MAE_cover:" + str(np.mean(mae_cover)) + " MAE_secret:" + str(np.mean(mae_secret)))
print("RMSE_cover:" + str(np.mean(rmse_cover)) + " RMSE_secret:" + str(np.mean(rmse_secret)))

# ---------------- train_StegFormer_multiple_image.py ----------------
import torch
import torch.nn as nn
import torch.optim
import math
import numpy as np
from critic import *
from tensorboardX import SummaryWriter
from thop import profile
from datasets import *
from model import StegFormer
import os
import timm
import timm.scheduler
import config
from einops import rearrange

args = config.Args()

# Fix the random seeds so training runs are reproducible.
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)


# loss functions
class L1_Charbonnier_loss(torch.nn.Module):
    """L1 Charbonnier loss: mean of sqrt((x - y)^2 + eps), a smooth L1."""

    def __init__(self):
        super(L1_Charbonnier_loss, self).__init__()
        self.eps = 1e-6  # keeps the gradient finite at x == y

    def forward(self, X, Y):
        diff = torch.add(X, -Y)
        error = torch.sqrt(diff * diff + self.eps)
        loss = torch.mean(error)
        return loss


class Restrict_Loss(nn.Module):
    """Quadratic penalty on stego values outside the valid range [0, 1]."""

    def forward(self, X):
        count1 = torch.sum(X > 1)
        count0 = torch.sum(X < 0)
        # Guard against division by zero when nothing is out of range.
        if count1 == 0:
            count1 = 1
        if count0 == 0:
            count0 = 1
        one = torch.ones_like(X)
        zero = torch.zeros_like(X)
        X_one = torch.where(X <= 1, 1, X)   # penalise values above 1
        X_zero = torch.where(X >= 0, 0, X)  # penalise values below 0
        diff_one = X_one - one
        diff_zero = zero - X_zero
        loss = torch.sum(0.5 * (diff_one ** 2)) / count1 + torch.sum(0.5 * (diff_zero ** 2)) / count0
        return loss


# output folder named after the model version
model_version_name = args.model_name
save_path = args.path+'/checkpoint/'+model_version_name
if not os.path.exists(save_path):
    os.makedirs(save_path)

# tensorboard
writer = SummaryWriter(f'{args.path}/tensorboard_log/{args.model_name}/')

# StegFormer initiate — the decoder emits num_secret*3 channels
if args.use_model == 'StegFormer-S':
    encoder = StegFormer(img_resolution=args.image_size_train, input_dim=(args.num_secret+1)*3, cnn_emb_dim=8, output_dim=3,
                         drop_key=False, patch_size=2, window_size=8, output_act=args.output_act, depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2])
    decoder = StegFormer(img_resolution=args.image_size_train, input_dim=3, cnn_emb_dim=8, output_dim=args.num_secret*3,
                         drop_key=False, patch_size=2, window_size=8, output_act=args.output_act, depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2])
if args.use_model == 'StegFormer-B':
    encoder = StegFormer(img_resolution=args.image_size_train, input_dim=(args.num_secret+1)*3, cnn_emb_dim=16, output_dim=3,
                         drop_key=False, patch_size=2, window_size=8, output_act=args.output_act, depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2])
    decoder = StegFormer(img_resolution=args.image_size_train, input_dim=3, cnn_emb_dim=16, output_dim=args.num_secret*3,
                         drop_key=False, patch_size=2, window_size=8, output_act=args.output_act, depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2])
if args.use_model == 'StegFormer-L':
    encoder = StegFormer(img_resolution=args.image_size_train, input_dim=(args.num_secret+1)*3, cnn_emb_dim=32, output_dim=3, depth=[2, 2, 2, 2, 2, 2, 2, 2, 2])
    decoder = StegFormer(img_resolution=args.image_size_train, input_dim=3, cnn_emb_dim=32, output_dim=args.num_secret*3, depth=[2, 2, 2, 2, 2, 2, 2, 2, 2])
encoder.cuda()
decoder.cuda()

# resume from a checkpoint when train_next is set
model_path = ''
if args.train_next != 0:
    model_path = save_path + '/model_checkpoint_%.5i' % args.train_next + '.pt'
    state_dicts = torch.load(model_path)
    encoder.load_state_dict(state_dicts['encoder'], strict=False)
    decoder.load_state_dict(state_dicts['decoder'], strict=False)


# optimizer and the learning-rate scheduler (cosine with linear warm-up)
optim = torch.optim.AdamW([{'params': encoder.parameters()}, {'params': decoder.parameters()}], lr=args.lr)
if args.train_next != 0:
    optim.load_state_dict(state_dicts['opt'])
scheduler = timm.scheduler.CosineLRScheduler(optimizer=optim,
                                             t_initial=args.epochs,
                                             lr_min=0,
                                             warmup_t=args.warm_up_epoch,
                                             warmup_lr_init=args.warm_up_lr_init)

# parameter and FLOP count
with torch.no_grad():
    test_encoder_input = torch.randn(1, (args.num_secret+1)*3, args.image_size_train, args.image_size_train).to(args.device)
    test_decoder_input = torch.randn(1, 3, args.image_size_train, args.image_size_train).to(args.device)
    encoder_mac, encoder_params = profile(encoder, inputs=(test_encoder_input,))
    decoder_mac, decoder_params = profile(decoder, inputs=(test_decoder_input,))
    print("thop result:encoder FLOPs="+str(encoder_mac*2)+",encoder params="+str(encoder_params))
    print("thop result:decoder FLOPs="+str(decoder_mac*2)+",decoder params="+str(decoder_params))

# loss functions
conceal_loss_function = L1_Charbonnier_loss().to(args.device)
reveal_loss_function = L1_Charbonnier_loss().to(args.device)
restrict_loss_funtion = Restrict_Loss().to(args.device)

# train
for i_epoch in range(args.epochs):
    sum_loss = []
    scheduler.step(i_epoch+args.train_next)
    for i_batch, img in enumerate(DIV2K_multi_train_loader):
        img = img.to(args.device)
        # the first bs images are covers; the rest are bs*num_secret secrets
        bs = args.multi_batch_szie  # NOTE(review): attribute name as (mis)spelled in config.py
        cover = img[0:bs, :, :, :]
        secret = img[bs:, :, :, :]
        # stack every group of num_secret secrets along the channel axis
        secret = rearrange(secret, '(b n) c h w -> b (n c) h w', n=args.num_secret)

        # encode
        msg = torch.cat([cover, secret], 1)
        encode_img = encoder(msg)
        if args.norm_train == 'clamp':
            encode_img_c = torch.clamp(encode_img, 0, 1)
        else:
            encode_img_c = encode_img

        # decode
        decode_img = decoder(encode_img_c)

        # loss: the reveal term is weighted 2x
        conceal_loss = conceal_loss_function(cover.cuda(), encode_img.cuda())
        reveal_loss = 2*reveal_loss_function(secret.cuda(), decode_img.cuda())
        if args.norm_train:
            restrict_loss = restrict_loss_funtion(encode_img.cuda())
            total_loss = conceal_loss + reveal_loss + restrict_loss
        else:
            total_loss = conceal_loss + reveal_loss
        sum_loss.append(total_loss.item())

        # backward
        total_loss.backward()
        optim.step()
        optim.zero_grad()

    # validation and metric logging
    if i_epoch % args.val_freq == 0:
        print("validation begin:")
        with torch.no_grad():
            encoder.eval()
            decoder.eval()

            # metric accumulators
            psnr_secret = []
            psnr_cover = []
            ssim_secret = []
            ssim_cover = []

            for img in DIV2K_multi_val_loader:
                img = img.to(args.device)
                bs = args.multi_batch_szie
                cover = img[0:bs, :, :, :]
                secret = img[bs:, :, :, :]
                secret_cat = rearrange(secret, '(b n) c h w -> b (n c) h w', n=args.num_secret)

                # encode
                msg = torch.cat([cover, secret_cat], 1)
                encode_img = encoder(msg)
                if args.norm_train == 'clamp':
                    encode_img = torch.clamp(encode_img, 0, 1)

                # decode, then split the channel stack back into images
                decode_img = decoder(encode_img)
                decode_img = rearrange(decode_img, 'b (n c) h w -> (b n) c h w', n=args.num_secret)

                cover = cover.cpu()
                secret = secret.cpu()
                encode_img = encode_img.cpu()
                decode_img = decode_img.cpu()

                psnr_encode_temp = calculate_psnr(cover, encode_img)
                psnr_decode_temp = calculate_psnr(secret, decode_img)
                psnr_cover.append(psnr_encode_temp)
                psnr_secret.append(psnr_decode_temp)

                ssim_encode_temp = calculate_ssim_skimage(cover, encode_img)
                ssim_decode_temp = calculate_ssim_skimage(secret, decode_img)
                ssim_cover.append(ssim_encode_temp)
                ssim_secret.append(ssim_decode_temp)
            # log the last validation batch's images
            writer.add_images('image/encode_img', encode_img, dataformats='NCHW', global_step=i_epoch+args.train_next)
            writer.add_images('image/decode_img', decode_img, dataformats='NCHW', global_step=i_epoch+args.train_next)

            # write metrics to tensorboard
            writer.add_scalar("PSNR/PSNR_cover", np.mean(psnr_cover), i_epoch+args.train_next)
            writer.add_scalar("PSNR/PSNR_secret", np.mean(psnr_secret), i_epoch+args.train_next)
            writer.add_scalar("SSIM/SSIM_cover", np.mean(ssim_cover), i_epoch+args.train_next)
            writer.add_scalar("SSIM/SSIM_secret", np.mean(ssim_secret), i_epoch+args.train_next)
            print("PSNR_cover:" + str(np.mean(psnr_cover)) + " PSNR_secret:" + str(np.mean(psnr_secret)))
        # FIX: return to training mode — the eval() calls above would otherwise
        # stay in effect for every epoch after the first validation.
        encoder.train()
        decoder.train()

    # loss curve
    print("epoch:"+str(i_epoch+args.train_next) + ":" + str(np.mean(sum_loss)))
    if i_epoch % 2 == 0:
        writer.add_scalar("loss", np.mean(sum_loss), i_epoch+args.train_next)

    # periodically checkpoint model and optimizer state
    if (i_epoch % args.save_freq) == 0:
        torch.save({'opt': optim.state_dict(),
                    'encoder': encoder.state_dict(),
                    'decoder': decoder.state_dict()}, save_path + '/model_checkpoint_%.5i' % (i_epoch+args.train_next)+'.pt')


# final checkpoint under the model version name
torch.save({'opt': optim.state_dict(),
            'encoder': encoder.state_dict(),
            'decoder': decoder.state_dict()}, f'{save_path}/{model_version_name}.pt')

# ---------------- test_StegFormer.py ----------------
import torch
import torch.nn as nn
import torch.optim
import numpy as np
from critic import *
from thop import profile
import os
from model import StegFormer
from datasets import DIV2K_test_cover_loader, DIV2K_test_secret_loader
import time


class Args:
    """Inline configuration for this standalone test script (config.py unused)."""

    def __init__(self) -> None:
        self.batch_size = 8
        self.image_size = 256
        self.patch_size = 16
        self.lr = 1e-3
        self.epochs = 2000
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.val_freq = 50
        self.save_freq = 200
        self.train_next = 0
        self.use_model = 'StegFormer-S'
        self.input_dim = 3
        self.num_secret = 1
        self.norm_train = True
        self.output_act = None


args = Args()
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# multi-GPU switch
USE_MULTI_GPU = False

# detect whether several GPUs are available
if USE_MULTI_GPU and torch.cuda.device_count() > 1:
    MULTI_GPU = True
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
    device_ids = [0, 1]

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model initialization
# NOTE(review): this script passes `img_dim=` while the training scripts pass
# `input_dim=` — confirm against the current StegFormer signature in model.py.
if args.use_model == 'StegFormer-S':
    encoder = StegFormer(1024, img_dim=(args.num_secret+1)*3, cnn_emb_dim=8, output_dim=3,
                         drop_key=False, patch_size=2, window_size=8, output_act=args.output_act, depth_tr=[2, 2, 6, 2, 2, 6, 2, 2])
    decoder = StegFormer(1024, img_dim=3, cnn_emb_dim=6, output_dim=args.num_secret*3, drop_key=False, patch_size=1, window_size=8, output_act=None, depth_tr=[2, 2, 6, 2, 2, 6, 2, 2])
if args.use_model == 'StegFormer-B':
    encoder = StegFormer(1024, img_dim=(args.num_secret+1)*3, cnn_emb_dim=16, output_dim=3)
    decoder = StegFormer(1024, img_dim=3, cnn_emb_dim=16, output_dim=args.num_secret*3)
if args.use_model == 'StegFormer-L':
    encoder = StegFormer(1024, img_dim=(args.num_secret+1)*3, cnn_emb_dim=32, output_dim=3, depth=[2, 2, 2, 2, 2, 2, 2, 2, 2])
    decoder = StegFormer(1024, img_dim=3, cnn_emb_dim=32, output_dim=args.num_secret*3, depth=[2, 2, 2, 2, 2, 2, 2, 2, 2])
encoder.cuda()
decoder.cuda()


# load trained weights (hard-coded experiment path)
model_path = '/home/whq135/code/stegv1/model_uformer/StegFormer_HELD3.pt'
state_dicts = torch.load(model_path)
encoder.load_state_dict(state_dicts['encoder'], strict=False)
decoder.load_state_dict(state_dicts['decoder'], strict=False)

# data parallelism
if USE_MULTI_GPU:
    encoder = torch.nn.DataParallel(encoder, device_ids=device_ids)
    decoder = torch.nn.DataParallel(decoder, device_ids=device_ids)
encoder.to(device)
decoder.to(device)

# parameter and FLOP count
with torch.no_grad():
    test_encoder_input = torch.randn(1, 6, 1024, 1024).to(device)
    test_decoder_input = torch.randn(1, 3, 1024, 1024).to(device)
    encoder_mac, encoder_params = profile(encoder, inputs=(test_encoder_input,))

    decoder_mac, decoder_params = profile(decoder, inputs=(test_decoder_input,))
    print("thop result:encoder FLOPs="+str(encoder_mac*2)+",encoder params="+str(encoder_params))
    print("thop result:decoder FLOPs="+str(decoder_mac*2)+",decoder params="+str(decoder_params))

i = 0  # running image index

# evaluation without clamping the stego image before decoding
with torch.no_grad():
    encoder.eval()
    decoder.eval()

    # metric accumulators
    psnr_secret = []
    psnr_cover = []
    psnr_secret_y = []
    psnr_cover_y = []
    ssim_secret = []
    ssim_cover = []
    mse_cover = []
    mse_secret = []
    rmse_cover = []
    rmse_secret = []
    mae_cover = []
    mae_secret = []

    # DIV2K test set
    i = 0
    for j in range(1):  # number of passes
        # NOTE(review): this loop's body is a bare `break`, i.e. a no-op —
        # looks like a leftover experimentation toggle; confirm before removing.
        break
    for (cover, secret) in zip(DIV2K_test_cover_loader, DIV2K_test_secret_loader):
        cover = cover.to(device)
        secret = secret.to(device)

        # encode
        msg = torch.cat([cover, secret], 1)
        encode_img = encoder(msg)

        # decode
        decode_img = decoder(encode_img)

        # clip to the valid image range
        decode_img = decode_img.clamp(0, 1)
        encode_img = encode_img.clamp(0, 1)

        # move to CPU for metric computation
        cover = cover.cpu()
        secret = secret.cpu()
        encode_img = encode_img.cpu()
        decode_img = decode_img.cpu()

        # Y-channel PSNR
        psnry_encode_temp = calculate_psnr_skimage(cover, encode_img)
        psnry_decode_temp = calculate_psnr_skimage(secret, decode_img)
        psnr_cover_y.append(psnry_encode_temp)
        psnr_secret_y.append(psnry_decode_temp)

        # SSIM
        ssim_cover_temp = calculate_ssim_skimage(cover, encode_img)
        ssim_secret_temp = calculate_ssim_skimage(secret, decode_img)
        ssim_cover.append(ssim_cover_temp)
        ssim_secret.append(ssim_secret_temp)

        # RMSE
        rmse_cover_temp = calculate_rmse(cover, encode_img)
        rmse_secret_temp = calculate_rmse(secret, decode_img)
        rmse_cover.append(rmse_cover_temp)
        rmse_secret.append(rmse_secret_temp)

        # MAE
        mae_cover_temp = calculate_mae(cover, encode_img)
        mae_secret_temp = calculate_mae(secret, decode_img)
        mae_cover.append(mae_cover_temp)
        mae_secret.append(mae_secret_temp)

        i += 1  # next image
        print(f'item {i} : cover_psnr: {psnry_encode_temp} cover_ssim: {ssim_cover_temp} cover_mae: {mae_cover_temp} cover_rmse: {rmse_cover_temp}\n \
secret_psnr: {psnry_decode_temp} secret_ssim: {ssim_secret_temp} secret_mae: {mae_secret_temp} secret_rmse: {rmse_secret_temp}')
    cover_div2k_psnr = np.mean(psnr_cover_y)
    cover_div2k_ssim = np.mean(ssim_cover)
    cover_div2k_mae = np.mean(mae_cover)
    cover_div2k_rmse = np.mean(rmse_cover)
    secret_div2k_psnr = np.mean(psnr_secret_y)
    secret_div2k_ssim = np.mean(ssim_secret)
    secret_div2k_mae = np.mean(mae_secret)
    secret_div2k_rmse = np.mean(rmse_secret)
    print('DIV2K:')
    print(f'cover:\n \
psnr: {cover_div2k_psnr}; ssim: {cover_div2k_ssim}; mae: {cover_div2k_mae}; rmse: {cover_div2k_rmse};\n \
secret:\n \
psnr: {secret_div2k_psnr}; ssim: {secret_div2k_ssim}; mae: {secret_div2k_mae}; rmse: {secret_div2k_rmse};\n')

# (a large commented-out experiment — a 256-resolution model reload plus a COCO
# evaluation pass — continues below; it is dead code kept for reference)
mae_secret = [] 215 | 216 | # # 在 COCO 217 | # i = 0 218 | # for j in range(1): # 需要的轮次 219 | # for (cover, secret) in zip(COCO_test_cover_loader, COCO_test_secret_loader): 220 | # if i == 1000: 221 | # break 222 | # else: 223 | # cover = cover.to(device) 224 | # secret = secret.to(device) 225 | 226 | # # encode 227 | # msg = torch.cat([cover, secret], 1) 228 | # encode_img = encoder(msg) # 添加残差连接 229 | 230 | # # decode 231 | # decode_img = decoder(encode_img) 232 | 233 | # # 限制为图像表示 234 | # decode_img = decode_img.clamp(0, 1) 235 | # encode_img = encode_img.clamp(0, 1) 236 | 237 | # # 计算各种指标 238 | # # 拷贝进内存以方便计算 239 | # cover = cover.cpu() 240 | # secret = secret.cpu() 241 | # encode_img = encode_img.cpu() 242 | # decode_img = decode_img.cpu() 243 | 244 | # # # 计算 RGB PSNR 245 | # # psnr_encode_temp = calculate_psnr(cover, encode_img) 246 | # # psnr_decode_temp = calculate_psnr(secret, decode_img) 247 | # # psnr_cover.append(psnr_encode_temp) 248 | # # psnr_secret.append(psnr_decode_temp) 249 | 250 | # # 计算 Y 通道 PSNR 251 | # psnry_encode_temp = calculate_psnr_skimage(cover, encode_img) 252 | # psnry_decode_temp = calculate_psnr_skimage(secret, decode_img) 253 | # psnr_cover_y.append(psnry_encode_temp) 254 | # psnr_secret_y.append(psnry_decode_temp) 255 | 256 | # # 计算 SSIM 257 | # ssim_cover_temp=calculate_ssim_skimage(cover,encode_img) 258 | # ssim_secret_temp=calculate_ssim_skimage(secret,decode_img) 259 | # ssim_cover.append(ssim_cover_temp) 260 | # ssim_secret.append(ssim_secret_temp) 261 | 262 | # # # 计算 MSE 263 | # # mse_cover_temp = calculate_mse(cover, encode_img) 264 | # # mse_secret_temp = calculate_mse(secret, decode_img) 265 | # # mse_cover.append(mse_cover_temp) 266 | # # mse_secret.append(mse_secret_temp) 267 | 268 | # # 计算 RMSE 269 | # rmse_cover_temp = calculate_rmse(cover, encode_img) 270 | # rmse_secret_temp = calculate_rmse(secret, decode_img) 271 | # rmse_cover.append(rmse_cover_temp) 272 | # rmse_secret.append(rmse_secret_temp) 273 | 274 | # # 
计算 MAE 275 | # mae_cover_temp = calculate_mae(cover, encode_img) 276 | # mae_secret_temp = calculate_mae(secret, decode_img) 277 | # mae_cover.append(mae_cover_temp) 278 | # mae_secret.append(mae_secret_temp) 279 | 280 | # # # 保存图像 281 | # # torchvision.utils.save_image(cover, '/home/whq135/code/stegv1/image/wo_clamp/cover/' + '%.5d.png' % i) 282 | # # torchvision.utils.save_image(secret, '/home/whq135/code/stegv1/image/wo_clamp/secret/' + '%.5d.png' % i) 283 | # # torchvision.utils.save_image(encode_img, '/home/whq135/code/stegv1/image/wo_clamp/stego/' + '%.5d.png' % i) 284 | # # torchvision.utils.save_image(decode_img, '/home/whq135/code/stegv1/image/wo_clamp/secret-rev/' + '%.5d.png' % i) 285 | # i += 1 # 下一张图像 286 | # # print(f'item {i} : cover_psnr: {psnry_encode_temp} cover_ssim: {ssim_cover_temp} cover_mae: {mae_cover_temp} cover_rmse: {rmse_cover_temp}\n \ 287 | # # secret_psnr: {psnry_decode_temp} secret_ssim: {ssim_cover_temp} secret_mae: {mae_secret_temp} secret_rmse: {rmse_secret_temp}') 288 | # cover_coco_psnr = np.mean(psnr_cover_y) 289 | # cover_coco_ssim = np.mean(ssim_cover) 290 | # cover_coco_mae = np.mean(mae_cover) 291 | # cover_coco_rmse = np.mean(rmse_cover) 292 | # secret_coco_psnr = np.mean(psnr_secret_y) 293 | # secret_coco_ssim = np.mean(ssim_secret) 294 | # secret_coco_mae = np.mean(mae_secret) 295 | # secret_coco_rmse = np.mean(rmse_secret) 296 | # print('COCO') 297 | # print(f'cover:\n \ 298 | # psnr: {cover_coco_psnr}; ssim: {cover_coco_ssim}; mae: {cover_coco_mae}; rmse: {cover_coco_rmse};\n \ 299 | # secret:\n \ 300 | # psnr: {secret_coco_psnr}; ssim: {secret_coco_ssim}; mae: {secret_coco_mae}; rmse: {secret_coco_rmse};\n') 301 | 302 | # psnr_secret = [] 303 | # psnr_cover = [] 304 | # psnr_secret_y = [] 305 | # psnr_cover_y = [] 306 | # ssim_secret = [] 307 | # ssim_cover = [] 308 | # mse_cover = [] 309 | # mse_secret = [] 310 | # rmse_cover = [] 311 | # rmse_secret = [] 312 | # mae_cover = [] 313 | # mae_secret = [] 314 | 315 | # # 
在 ImageNet 316 | # i = 0 317 | # for j in range(1): # 需要的轮次 318 | # for (cover, secret) in zip(ImageNet_test_cover_loader, ImageNet_test_secret_loader): 319 | # if i == 1000: 320 | # break 321 | # else: 322 | # cover = cover.to(device) 323 | # secret = secret.to(device) 324 | 325 | # # encode 326 | # msg = torch.cat([cover, secret], 1) 327 | # encode_img = encoder(msg) # 添加残差连接 328 | 329 | # # decode 330 | # decode_img = decoder(encode_img) 331 | 332 | # # 限制为图像表示 333 | # decode_img = decode_img.clamp(0, 1) 334 | # encode_img = encode_img.clamp(0, 1) 335 | 336 | # # 计算各种指标 337 | # # 拷贝进内存以方便计算 338 | # cover = cover.cpu() 339 | # secret = secret.cpu() 340 | # encode_img = encode_img.cpu() 341 | # decode_img = decode_img.cpu() 342 | 343 | # # # 计算 RGB PSNR 344 | # # psnr_encode_temp = calculate_psnr(cover, encode_img) 345 | # # psnr_decode_temp = calculate_psnr(secret, decode_img) 346 | # # psnr_cover.append(psnr_encode_temp) 347 | # # psnr_secret.append(psnr_decode_temp) 348 | 349 | # # 计算 Y 通道 PSNR 350 | # psnry_encode_temp = calculate_psnr_skimage(cover, encode_img) 351 | # psnry_decode_temp = calculate_psnr_skimage(secret, decode_img) 352 | # psnr_cover_y.append(psnry_encode_temp) 353 | # psnr_secret_y.append(psnry_decode_temp) 354 | 355 | # # 计算 SSIM 356 | # ssim_cover_temp=calculate_ssim_skimage(cover,encode_img) 357 | # ssim_secret_temp=calculate_ssim_skimage(secret,decode_img) 358 | # ssim_cover.append(ssim_cover_temp) 359 | # ssim_secret.append(ssim_secret_temp) 360 | 361 | # # # 计算 MSE 362 | # # mse_cover_temp = calculate_mse(cover, encode_img) 363 | # # mse_secret_temp = calculate_mse(secret, decode_img) 364 | # # mse_cover.append(mse_cover_temp) 365 | # # mse_secret.append(mse_secret_temp) 366 | 367 | # # 计算 RMSE 368 | # rmse_cover_temp = calculate_rmse(cover, encode_img) 369 | # rmse_secret_temp = calculate_rmse(secret, decode_img) 370 | # rmse_cover.append(rmse_cover_temp) 371 | # rmse_secret.append(rmse_secret_temp) 372 | 373 | # # 计算 MAE 374 | # 
mae_cover_temp = calculate_mae(cover, encode_img) 375 | # mae_secret_temp = calculate_mae(secret, decode_img) 376 | # mae_cover.append(mae_cover_temp) 377 | # mae_secret.append(mae_secret_temp) 378 | 379 | # # # 保存图像 380 | # # torchvision.utils.save_image(cover, '/home/whq135/code/stegv1/image/wo_clamp/cover/' + '%.5d.png' % i) 381 | # # torchvision.utils.save_image(secret, '/home/whq135/code/stegv1/image/wo_clamp/secret/' + '%.5d.png' % i) 382 | # # torchvision.utils.save_image(encode_img, '/home/whq135/code/stegv1/image/wo_clamp/stego/' + '%.5d.png' % i) 383 | # # torchvision.utils.save_image(decode_img, '/home/whq135/code/stegv1/image/wo_clamp/secret-rev/' + '%.5d.png' % i) 384 | # i += 1 # 下一张图像 385 | # # print(f'item {i} : cover_psnr: {psnry_encode_temp} cover_ssim: {ssim_cover_temp} cover_mae: {mae_cover_temp} cover_rmse: {rmse_cover_temp}\n \ 386 | # # secret_psnr: {psnry_decode_temp} secret_ssim: {ssim_cover_temp} secret_mae: {mae_secret_temp} secret_rmse: {rmse_secret_temp}') 387 | # cover_imagenet_psnr = np.mean(psnr_cover_y) 388 | # cover_imagenet_ssim = np.mean(ssim_cover) 389 | # cover_imagenet_mae = np.mean(mae_cover) 390 | # cover_imagenet_rmse = np.mean(rmse_cover) 391 | # secret_imagenet_psnr = np.mean(psnr_secret_y) 392 | # secret_imagenet_ssim = np.mean(ssim_secret) 393 | # secret_imagenet_mae = np.mean(mae_secret) 394 | # secret_imagenet_rmse = np.mean(rmse_secret) 395 | # print('imagenet') 396 | # print(f'cover:\n \ 397 | # psnr: {cover_imagenet_psnr}; ssim: {cover_imagenet_ssim}; mae: {cover_imagenet_mae}; rmse: {cover_imagenet_rmse};\n \ 398 | # secret:\n \ 399 | # psnr: {secret_imagenet_psnr}; ssim: {secret_imagenet_ssim}; mae: {secret_imagenet_mae}; rmse: {secret_imagenet_rmse};\n') 400 | 401 | 402 | # clamp 403 | print('clamp') 404 | with torch.no_grad(): 405 | # val 406 | encoder = StegFormer(1024, img_dim=(args.num_secret+1)*3, cnn_emb_dim=12, output_dim=3, 407 | drop_key=False, patch_size=2, window_size=8, 
output_act=args.output_act,depth_tr=[2,2,2,2,2,2,2,2]) 408 | decoder = StegFormer(1024, img_dim=3, cnn_emb_dim=6, output_dim=args.num_secret*3, drop_key=False, patch_size=1, window_size=8, output_act=None, depth_tr=[2,2,6,2,2,6,2,2]) 409 | # 加载模型 410 | model_path = '/home/whq135/code/stegv1/model_uformer/StegFormer_HELD13.pt' 411 | state_dicts = torch.load(model_path) 412 | encoder.load_state_dict(state_dicts['encoder'], strict=False) 413 | decoder.load_state_dict(state_dicts['decoder'], strict=False) 414 | encoder.to(device) 415 | decoder.to(device) 416 | encoder.eval() 417 | decoder.eval() 418 | 419 | psnr_secret = [] 420 | psnr_cover = [] 421 | psnr_secret_y = [] 422 | psnr_cover_y = [] 423 | ssim_secret = [] 424 | ssim_cover = [] 425 | mse_cover = [] 426 | mse_secret = [] 427 | rmse_cover = [] 428 | rmse_secret = [] 429 | mae_cover = [] 430 | mae_secret = [] 431 | encode_times=[] 432 | decode_times=[] 433 | 434 | # 在 DIV2K 435 | i = 0 436 | for j in range(1): # 需要的轮次 437 | for (cover, secret) in zip(DIV2K_test_cover_loader, DIV2K_test_secret_loader): 438 | cover = cover.to(device) 439 | secret = secret.to(device) 440 | 441 | # encode 442 | msg = torch.cat([cover, secret], 1) 443 | start_time=time.perf_counter() 444 | encode_img = encoder(msg) # 添加残差连接 445 | end_time=time.perf_counter() 446 | encode_time=end_time-start_time 447 | encode_times.append(encode_time) 448 | encode_img = torch.clamp(encode_img,0,1) 449 | 450 | # decode 451 | start_time=time.perf_counter() 452 | decode_img = decoder(encode_img) 453 | end_time=time.perf_counter() 454 | decode_time=end_time-start_time 455 | decode_times.append(decode_time) 456 | 457 | # 限制为图像表示 458 | decode_img = decode_img.clamp(0, 1) 459 | encode_img = encode_img.clamp(0, 1) 460 | print(f'{i} encode time:{encode_time}; decode time: {decode_time}') 461 | 462 | # 计算各种指标 463 | # 拷贝进内存以方便计算 464 | cover = cover.cpu() 465 | secret = secret.cpu() 466 | encode_img = encode_img.cpu() 467 | decode_img = decode_img.cpu() 468 | 469 
| # # 计算 RGB PSNR 470 | # psnr_encode_temp = calculate_psnr(cover, encode_img) 471 | # psnr_decode_temp = calculate_psnr(secret, decode_img) 472 | # psnr_cover.append(psnr_encode_temp) 473 | # psnr_secret.append(psnr_decode_temp) 474 | 475 | # 计算 Y 通道 PSNR 476 | psnry_encode_temp = calculate_psnr_skimage(cover, encode_img) 477 | psnry_decode_temp = calculate_psnr_skimage(secret, decode_img) 478 | psnr_cover_y.append(psnry_encode_temp) 479 | psnr_secret_y.append(psnry_decode_temp) 480 | 481 | # 计算 SSIM 482 | ssim_cover_temp=calculate_ssim_skimage(cover,encode_img) 483 | ssim_secret_temp=calculate_ssim_skimage(secret,decode_img) 484 | ssim_cover.append(ssim_cover_temp) 485 | ssim_secret.append(ssim_secret_temp) 486 | 487 | # # 计算 MSE 488 | # mse_cover_temp = calculate_mse(cover, encode_img) 489 | # mse_secret_temp = calculate_mse(secret, decode_img) 490 | # mse_cover.append(mse_cover_temp) 491 | # mse_secret.append(mse_secret_temp) 492 | 493 | # 计算 RMSE 494 | rmse_cover_temp = calculate_rmse(cover, encode_img) 495 | rmse_secret_temp = calculate_rmse(secret, decode_img) 496 | rmse_cover.append(rmse_cover_temp) 497 | rmse_secret.append(rmse_secret_temp) 498 | 499 | # 计算 MAE 500 | mae_cover_temp = calculate_mae(cover, encode_img) 501 | mae_secret_temp = calculate_mae(secret, decode_img) 502 | mae_cover.append(mae_cover_temp) 503 | mae_secret.append(mae_secret_temp) 504 | 505 | # # 保存图像 506 | # torchvision.utils.save_image(cover, '/home/whq135/code/stegv1/image/wo_clamp/cover/' + '%.5d.png' % i) 507 | # torchvision.utils.save_image(secret, '/home/whq135/code/stegv1/image/wo_clamp/secret/' + '%.5d.png' % i) 508 | # torchvision.utils.save_image(encode_img, '/home/whq135/code/stegv1/image/wo_clamp/stego/' + '%.5d.png' % i) 509 | # torchvision.utils.save_image(decode_img, '/home/whq135/code/stegv1/image/wo_clamp/secret-rev/' + '%.5d.png' % i) 510 | i += 1 # 下一张图像 511 | print(f'item {i} : cover_psnr: {psnry_encode_temp} cover_ssim: {ssim_cover_temp} cover_mae: 
{mae_cover_temp} cover_rmse: {rmse_cover_temp}\n \ 512 | secret_psnr: {psnry_decode_temp} secret_ssim: {ssim_secret_temp} secret_mae: {mae_secret_temp} secret_rmse: {rmse_secret_temp}') 513 | cover_div2k_psnr = np.mean(psnr_cover_y) 514 | cover_div2k_ssim = np.mean(ssim_cover) 515 | cover_div2k_mae = np.mean(mae_cover) 516 | cover_div2k_rmse = np.mean(rmse_cover) 517 | secret_div2k_psnr = np.mean(psnr_secret_y) 518 | secret_div2k_ssim = np.mean(ssim_secret) 519 | secret_div2k_mae = np.mean(mae_secret) 520 | secret_div2k_rmse = np.mean(rmse_secret) 521 | print('DIV2K:') 522 | print(f'cover:\n \ 523 | psnr: {cover_div2k_psnr}; ssim: {cover_div2k_ssim}; mae: {cover_div2k_mae}; rmse: {cover_div2k_rmse};\n \ 524 | secret:\n \ 525 | psnr: {secret_div2k_psnr}; ssim: {secret_div2k_ssim}; mae: {secret_div2k_mae}; rmse: {secret_div2k_rmse};\n') 526 | print(f'encode_time:{np.mean(encode_times)}') 527 | print(f'decode_time:{np.mean(decode_times)}') 528 | 529 | # # 重载分辨率大小为 256 530 | # encoder = StegFormer(256, img_dim=(args.num_secret+1)*3, cnn_emb_dim=8, output_dim=3) 531 | # decoder = StegFormer(256, img_dim=3, cnn_emb_dim=8, output_dim=args.num_secret*3) 532 | 533 | # # 加载模型 534 | # model_path = '/home/whq135/code/stegv1/model_uformer/StegFormer_Clamp_nuew.pt' 535 | # state_dicts = torch.load(model_path) 536 | # encoder.load_state_dict(state_dicts['encoder'], strict=False) 537 | # decoder.load_state_dict(state_dicts['decoder'], strict=False) 538 | # encoder.to(device) 539 | # decoder.to(device) 540 | 541 | # psnr_secret = [] 542 | # psnr_cover = [] 543 | # psnr_secret_y = [] 544 | # psnr_cover_y = [] 545 | # ssim_secret = [] 546 | # ssim_cover = [] 547 | # mse_cover = [] 548 | # mse_secret = [] 549 | # rmse_cover = [] 550 | # rmse_secret = [] 551 | # mae_cover = [] 552 | # mae_secret = [] 553 | 554 | # # 在 COCO 555 | # i = 0 556 | # for j in range(1): # 需要的轮次 557 | # for (cover, secret) in zip(COCO_test_cover_loader, COCO_test_secret_loader): 558 | # if i == 1000: 559 | # 
break 560 | # else: 561 | # cover = cover.to(device) 562 | # secret = secret.to(device) 563 | 564 | # # encode 565 | # msg = torch.cat([cover, secret], 1) 566 | # encode_img = encoder(msg) # 添加残差连接 567 | # encode_img = torch.clamp(encode_img,0,1) 568 | 569 | # # decode 570 | # decode_img = decoder(encode_img) 571 | 572 | # # 限制为图像表示 573 | # decode_img = decode_img.clamp(0, 1) 574 | # encode_img = encode_img.clamp(0, 1) 575 | 576 | # # 计算各种指标 577 | # # 拷贝进内存以方便计算 578 | # cover = cover.cpu() 579 | # secret = secret.cpu() 580 | # encode_img = encode_img.cpu() 581 | # decode_img = decode_img.cpu() 582 | 583 | # # # 计算 RGB PSNR 584 | # # psnr_encode_temp = calculate_psnr(cover, encode_img) 585 | # # psnr_decode_temp = calculate_psnr(secret, decode_img) 586 | # # psnr_cover.append(psnr_encode_temp) 587 | # # psnr_secret.append(psnr_decode_temp) 588 | 589 | # # 计算 Y 通道 PSNR 590 | # psnry_encode_temp = calculate_psnr_skimage(cover, encode_img) 591 | # psnry_decode_temp = calculate_psnr_skimage(secret, decode_img) 592 | # psnr_cover_y.append(psnry_encode_temp) 593 | # psnr_secret_y.append(psnry_decode_temp) 594 | 595 | # # 计算 SSIM 596 | # ssim_cover_temp=calculate_ssim_skimage(cover,encode_img) 597 | # ssim_secret_temp=calculate_ssim_skimage(secret,decode_img) 598 | # ssim_cover.append(ssim_cover_temp) 599 | # ssim_secret.append(ssim_secret_temp) 600 | 601 | # # # 计算 MSE 602 | # # mse_cover_temp = calculate_mse(cover, encode_img) 603 | # # mse_secret_temp = calculate_mse(secret, decode_img) 604 | # # mse_cover.append(mse_cover_temp) 605 | # # mse_secret.append(mse_secret_temp) 606 | 607 | # # 计算 RMSE 608 | # rmse_cover_temp = calculate_rmse(cover, encode_img) 609 | # rmse_secret_temp = calculate_rmse(secret, decode_img) 610 | # rmse_cover.append(rmse_cover_temp) 611 | # rmse_secret.append(rmse_secret_temp) 612 | 613 | # # 计算 MAE 614 | # mae_cover_temp = calculate_mae(cover, encode_img) 615 | # mae_secret_temp = calculate_mae(secret, decode_img) 616 | # 
mae_cover.append(mae_cover_temp) 617 | # mae_secret.append(mae_secret_temp) 618 | 619 | # # # 保存图像 620 | # # torchvision.utils.save_image(cover, '/home/whq135/code/stegv1/image/wo_clamp/cover/' + '%.5d.png' % i) 621 | # # torchvision.utils.save_image(secret, '/home/whq135/code/stegv1/image/wo_clamp/secret/' + '%.5d.png' % i) 622 | # # torchvision.utils.save_image(encode_img, '/home/whq135/code/stegv1/image/wo_clamp/stego/' + '%.5d.png' % i) 623 | # # torchvision.utils.save_image(decode_img, '/home/whq135/code/stegv1/image/wo_clamp/secret-rev/' + '%.5d.png' % i) 624 | # i += 1 # 下一张图像 625 | # # print(f'item {i} : cover_psnr: {psnry_encode_temp} cover_ssim: {ssim_cover_temp} cover_mae: {mae_cover_temp} cover_rmse: {rmse_cover_temp}\n \ 626 | # # secret_psnr: {psnry_decode_temp} secret_ssim: {ssim_cover_temp} secret_mae: {mae_secret_temp} secret_rmse: {rmse_secret_temp}') 627 | # cover_coco_psnr = np.mean(psnr_cover_y) 628 | # cover_coco_ssim = np.mean(ssim_cover) 629 | # cover_coco_mae = np.mean(mae_cover) 630 | # cover_coco_rmse = np.mean(rmse_cover) 631 | # secret_coco_psnr = np.mean(psnr_secret_y) 632 | # secret_coco_ssim = np.mean(ssim_secret) 633 | # secret_coco_mae = np.mean(mae_secret) 634 | # secret_coco_rmse = np.mean(rmse_secret) 635 | # print('COCO') 636 | # print(f'cover:\n \ 637 | # psnr: {cover_coco_psnr}; ssim: {cover_coco_ssim}; mae: {cover_coco_mae}; rmse: {cover_coco_rmse};\n \ 638 | # secret:\n \ 639 | # psnr: {secret_coco_psnr}; ssim: {secret_coco_ssim}; mae: {secret_coco_mae}; rmse: {secret_coco_rmse};\n') 640 | 641 | # psnr_secret = [] 642 | # psnr_cover = [] 643 | # psnr_secret_y = [] 644 | # psnr_cover_y = [] 645 | # ssim_secret = [] 646 | # ssim_cover = [] 647 | # mse_cover = [] 648 | # mse_secret = [] 649 | # rmse_cover = [] 650 | # rmse_secret = [] 651 | # mae_cover = [] 652 | # mae_secret = [] 653 | 654 | # # 在 ImageNet 655 | # i = 0 656 | # for j in range(1): # 需要的轮次 657 | # for (cover, secret) in zip(ImageNet_test_cover_loader, 
ImageNet_test_secret_loader): 658 | # if i == 1000: 659 | # break 660 | # else: 661 | # cover = cover.to(device) 662 | # secret = secret.to(device) 663 | 664 | # # encode 665 | # msg = torch.cat([cover, secret], 1) 666 | # encode_img = encoder(msg) # 添加残差连接 667 | # encode_img = torch.clamp(encode_img,0,1) 668 | 669 | # # decode 670 | # decode_img = decoder(encode_img) 671 | 672 | # # 限制为图像表示 673 | # decode_img = decode_img.clamp(0, 1) 674 | # encode_img = encode_img.clamp(0, 1) 675 | 676 | # # 计算各种指标 677 | # # 拷贝进内存以方便计算 678 | # cover = cover.cpu() 679 | # secret = secret.cpu() 680 | # encode_img = encode_img.cpu() 681 | # decode_img = decode_img.cpu() 682 | 683 | # # # 计算 RGB PSNR 684 | # # psnr_encode_temp = calculate_psnr(cover, encode_img) 685 | # # psnr_decode_temp = calculate_psnr(secret, decode_img) 686 | # # psnr_cover.append(psnr_encode_temp) 687 | # # psnr_secret.append(psnr_decode_temp) 688 | 689 | # # 计算 Y 通道 PSNR 690 | # psnry_encode_temp = calculate_psnr_skimage(cover, encode_img) 691 | # psnry_decode_temp = calculate_psnr_skimage(secret, decode_img) 692 | # psnr_cover_y.append(psnry_encode_temp) 693 | # psnr_secret_y.append(psnry_decode_temp) 694 | 695 | # # 计算 SSIM 696 | # ssim_cover_temp=calculate_ssim_skimage(cover,encode_img) 697 | # ssim_secret_temp=calculate_ssim_skimage(secret,decode_img) 698 | # ssim_cover.append(ssim_cover_temp) 699 | # ssim_secret.append(ssim_secret_temp) 700 | 701 | # # # 计算 MSE 702 | # # mse_cover_temp = calculate_mse(cover, encode_img) 703 | # # mse_secret_temp = calculate_mse(secret, decode_img) 704 | # # mse_cover.append(mse_cover_temp) 705 | # # mse_secret.append(mse_secret_temp) 706 | 707 | # # 计算 RMSE 708 | # rmse_cover_temp = calculate_rmse(cover, encode_img) 709 | # rmse_secret_temp = calculate_rmse(secret, decode_img) 710 | # rmse_cover.append(rmse_cover_temp) 711 | # rmse_secret.append(rmse_secret_temp) 712 | 713 | # # 计算 MAE 714 | # mae_cover_temp = calculate_mae(cover, encode_img) 715 | # mae_secret_temp = 
calculate_mae(secret, decode_img) 716 | # mae_cover.append(mae_cover_temp) 717 | # mae_secret.append(mae_secret_temp) 718 | 719 | # # # 保存图像 720 | # # torchvision.utils.save_image(cover, '/home/whq135/code/stegv1/image/wo_clamp/cover/' + '%.5d.png' % i) 721 | # # torchvision.utils.save_image(secret, '/home/whq135/code/stegv1/image/wo_clamp/secret/' + '%.5d.png' % i) 722 | # # torchvision.utils.save_image(encode_img, '/home/whq135/code/stegv1/image/wo_clamp/stego/' + '%.5d.png' % i) 723 | # # torchvision.utils.save_image(decode_img, '/home/whq135/code/stegv1/image/wo_clamp/secret-rev/' + '%.5d.png' % i) 724 | # i += 1 # 下一张图像 725 | # # print(f'item {i} : cover_psnr: {psnry_encode_temp} cover_ssim: {ssim_cover_temp} cover_mae: {mae_cover_temp} cover_rmse: {rmse_cover_temp}\n \ 726 | # # secret_psnr: {psnry_decode_temp} secret_ssim: {ssim_cover_temp} secret_mae: {mae_secret_temp} secret_rmse: {rmse_secret_temp}') 727 | # cover_imagenet_psnr = np.mean(psnr_cover_y) 728 | # cover_imagenet_ssim = np.mean(ssim_cover) 729 | # cover_imagenet_mae = np.mean(mae_cover) 730 | # cover_imagenet_rmse = np.mean(rmse_cover) 731 | # secret_imagenet_psnr = np.mean(psnr_secret_y) 732 | # secret_imagenet_ssim = np.mean(ssim_secret) 733 | # secret_imagenet_mae = np.mean(mae_secret) 734 | # secret_imagenet_rmse = np.mean(rmse_secret) 735 | # print('imagenet') 736 | # print(f'cover:\n \ 737 | # psnr: {cover_imagenet_psnr}; ssim: {cover_imagenet_ssim}; mae: {cover_imagenet_mae}; rmse: {cover_imagenet_rmse};\n \ 738 | # secret:\n \ 739 | # psnr: {secret_imagenet_psnr}; ssim: {secret_imagenet_ssim}; mae: {secret_imagenet_mae}; rmse: {secret_imagenet_rmse};\n') 740 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | from timm.models.layers import DropPath, to_2tuple, 
trunc_normal_
from einops import rearrange


class Mlp(nn.Module):
    """MLP of the transformer layer.

    Input: x of shape (B, N, C).
    Args:
        in_features: dim of the input features
        hidden_features: hidden dim of the MLP (defaults to in_features)
        out_features: dim of the output features (defaults to in_features)
        act_layer: activation module class, default nn.GELU
        drop: dropout ratio
    Returns: tensor of shape (B, N, C)
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class ConvFFN(nn.Module):
    """CNN-based FFN module for the transformer block.

    Input: x of shape (B, C, H, W).
    Args:
        in_features: dim of the input features
        hidden_features: hidden dim (defaults to in_features)
        out_features: dim of the output features (defaults to in_features)
        act_layer: activation module class, default nn.GELU
        drop: dropout ratio
    Returns: tensor of shape (B, out_features, H, W)
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # BUGFIX: padding was 3, which makes a 3x3 depth-wise conv grow the
        # spatial size by 4 in each dimension and breaks the residual add in
        # the transformer block; padding=1 preserves H and W.
        self.conv = nn.Conv2d(in_features, in_features, 3, 1, 1, groups=in_features)
        self.norm = nn.LayerNorm(in_features)
        self.fc1 = nn.Conv2d(in_features, hidden_features, 1, 1, 0)
        self.act = act_layer()
        # BUGFIX: fc2 projected back to in_features and silently ignored
        # out_features; now consistent with Mlp. Backward compatible because
        # out_features defaults to in_features.
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1, 1, 0)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1).contiguous()  # B H W C (LayerNorm acts on the last dim)
        # BUGFIX: the LayerNorm result was assigned to a dead capital-`X`
        # variable and discarded, so no normalization was actually applied.
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2).contiguous()  # back to B C H W
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """Partition a feature map into non-overlapping windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """Merge windows back into a feature map (inverse of window_partition).

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): window size
        H (int): height of image
        W (int): width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        drop_key (bool, optional): If True, apply DropKey regularization to the attention logits. Default: False
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., drop_key=False):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.drop_key = drop_key

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2,
            0, 3, 1, 4).contiguous()  # (3, B*num_windows, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # q:(B*windows, heads, N, C), attn:(B*windows, heads, N, N)

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)  # relative positional embedding

        # DropKey: randomly push ~10% of the attention logits to -inf before softmax
        if self.drop_key:
            m_r = torch.ones_like(attn)*0.1
            attn = attn+torch.bernoulli(m_r)*-1e12

        if mask is not None:
            nW = mask.shape[0]  # mask: (nW, N, N)
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)  # attn:(bs,nW,heads,N,N); mask:(1,nW,1,N,N)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # x:(B*windows, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Channel_Adapter(nn.Module):
    """The Channel Adapter module: squeeze-and-excitation style channel gating
    over a 1x1-conv projection, added residually to the input.

    Args:
        num_channels: the channel number of the feature map
        resolution: the spatial resolution of the feature map (used as the pooling kernel)
    """

    def __init__(self, num_channels, resolution):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(num_channels, 4*num_channels, 1),
            nn.GELU(),
            nn.Conv2d(4*num_channels, num_channels, 1)
        )
        # Pools the full (resolution x resolution) map down to 1x1.
        self.pool = nn.AvgPool2d(resolution)
        # NOTE(review): there is no nonlinearity between the two Linear layers;
        # SE-style gates usually have one — confirm this is intentional.
        self.mlp = nn.Sequential(
            nn.Linear(num_channels, 4*num_channels),
            nn.Linear(4*num_channels, num_channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x_proj = self.proj(x)
        avg_x = self.pool(x_proj).reshape(B, C)  # (B, C)
        attn = self.mlp(avg_x)
        attn = attn.reshape(B, C, 1, 1)*x_proj  # per-channel gate applied to the projection
        x = x+attn  # residual
        return x


class PEG_Conv(nn.Module):
    """Conditional position encoding generator, using depth-wise convolution.

    Args:
        input: (B, C, H, W), based on patch
        dim: the dim of the heads' token
    """

    def __init__(self, dim):
        super().__init__()
        self.conv = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim)

    def forward(self, x):
        pe = self.conv(x)
        return pe


class ConditinoalAttention(nn.Module):
    r"""Global attention with conditional positional encoding.

    NOTE(review): the class name misspells "Conditional"; kept as-is because
    renaming would break external callers.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        drop_key (bool, optional): Using dropkey or not. Default: False
    """

    def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., drop_key=False):

        super().__init__()
        self.dim = dim
        self.head_dim = dim//num_heads
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.drop_key = drop_key

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)  # produces Q, K, V
        self.attn_drop = nn.Dropout(attn_drop)  # dropout on the attention map
        self.proj = nn.Linear(dim, dim)  # output projection
        self.proj_drop = nn.Dropout(proj_drop)  # dropout on the output

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (B, N, C)
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2,
            0, 3, 1, 4).contiguous()  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # q:(B, heads, N, C), attn:(B, heads, N, N)

        # DropKey: randomly push ~10% of the attention logits to -inf before softmax
        if self.drop_key:
            m_r = torch.ones_like(attn)*0.1
            attn = attn+torch.bernoulli(m_r)*-1e12

        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)  # x:(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def PatchReverse(x, patch_size):
    """Transform the tokens back to an image.

    input:
        x: (B, N, C), N = (p1 p2): token representation; assumes N is a
        perfect square (square grid of patches).
    returns:
        x: (B, C, H, W)
    """
    _, N, _ = x.shape
    p1 = p2 = int(N**0.5)  # number of patches per side
    x = rearrange(x, 'b (p1 p2) (ph pw c) -> b c (p1 ph) (p2 pw)', p1=p1, p2=p2, ph=patch_size, pw=patch_size).contiguous()
    return x


def PatchDivide(x, patch_size):
    """Transform the image to tokens (inverse of PatchReverse).

    input:
        x: (B, C, H, W)
    returns:
        x: (B, N, C)
    """
    x = rearrange(x, 'b c (p1 h) (p2 w) -> b (p1 p2) (h w c)', h=patch_size, w=patch_size).contiguous()
    return x


class DownSampler(nn.Module):
    """Down-sample the feature map by 2x with a strided 4x4 convolution."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = self.conv(x)
        return x


class UpSampler(nn.Module):
    """Up-sample the feature map by 2x with a transposed convolution."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_channels=in_channel, out_channels=out_channel, kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        return x


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution, based on patch.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.LeakyReLU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        ffn_type (string): FFN module, using Convolution or MLP. Default: Mlp
        fused_window_process (bool, optional): If True, use one kernel to fuse window shift & window partition for acceleration, similar for the reversed part. Default: False
        drop_key (bool, optional): Using dropkey or not. Default: False

    The input x is in (B, N, C) form; window partitioning happens inside forward.
    """

    # NOTE (from original author, translated): the default shift_size=0 was
    # flagged as a mistake — it disables window shifting unless the caller
    # passes a non-zero shift_size explicitly.
    def __init__(self, dim, input_resolution, num_heads, window_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.LeakyReLU, norm_layer=nn.LayerNorm, ffn_type='Mlp',  # ablation: changed to LeakyReLU
                 fused_window_process=False, drop_key=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.ffn_type = ffn_type
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, drop_key=drop_key)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if self.ffn_type == 'Mlp':
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        else:
            self.mlp = ConvFFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask, persistent=False)  # registered as a (non-persistent) buffer
        self.fused_window_process = fused_window_process  # whether to use the fused shift+partition kernels

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)  # turn to 2-D representation

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            # partition windows
            x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
else: 452 | shifted_x = x 453 | # partition windows 454 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 455 | 456 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 457 | 458 | # W-MSA/SW-MSA 459 | attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C 460 | 461 | # merge windows 462 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) # (B*windows,Wh,Ww,C) 463 | 464 | # reverse cyclic shift 恢复循环移位操作 465 | if self.shift_size > 0: 466 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C 467 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 468 | else: 469 | shifted_x = window_reverse(attn_windows, self.window_size, H, W) 470 | x = shifted_x 471 | x = x.view(B, H * W, C) 472 | 473 | x = shortcut + self.drop_path(x) 474 | 475 | # FFN 476 | if self.ffn_type == 'Mlp': 477 | x = x + self.drop_path(self.mlp(self.norm2(x))) 478 | else: 479 | shortcut2 = x 480 | x = rearrange(x, 'bs (h w) c -> bs c h w', bs=B, h=H, w=W).contiguous() # B C H W 481 | x = self.mlp(x) 482 | x = rearrange(x, 'bs c h w -> bs (h w) c').contiguous() # B N C 483 | x = self.drop_path(x) 484 | x = shortcut2+x 485 | return x 486 | 487 | 488 | class Global_Enhanced_BottleNeck_Block(nn.Module): 489 | r""" Global Enhanced BottleNeck Block, using global attention to model global information. 490 | 491 | Args: 492 | dim (int): Number of input channels. 493 | input_resolution (tuple[int]): Input resulotion. 这个 resolution 是 patch 的分辨率 494 | num_heads (int): Number of attention heads. 495 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 496 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 497 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 498 | drop (float, optional): Dropout rate. 
Default: 0.0 499 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 500 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 501 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 502 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 503 | ffn_type (string): FFN module,using Convolution or MLP, Default: Mlp 504 | drop_key (bool, optional): Using dropkey or not. Default: False 505 | 506 | 输入的 x 为 (B,N,C) 表示,不再划分窗口 507 | """ 508 | 509 | def __init__(self, dim, input_resolution, num_heads, 510 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 511 | act_layer=nn.GELU, norm_layer=nn.LayerNorm, ffn_type='Mlp', drop_key=False): 512 | super().__init__() 513 | self.dim = dim 514 | self.input_resolution = input_resolution 515 | self.num_heads = num_heads 516 | self.mlp_ratio = mlp_ratio 517 | self.ffn_type = ffn_type 518 | 519 | self.norm1 = norm_layer(dim) 520 | self.attn = ConditinoalAttention( 521 | dim, num_heads=num_heads, 522 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, drop_key=drop_key) 523 | 524 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 525 | self.norm2 = norm_layer(dim) 526 | mlp_hidden_dim = int(dim * mlp_ratio) 527 | if self.ffn_type == 'Mlp': 528 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 529 | else: 530 | self.mlp = ConvFFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 531 | 532 | def forward(self, x): 533 | H, W = self.input_resolution 534 | B, L, C = x.shape 535 | 536 | shortcut = x 537 | x = self.norm1(x) 538 | 539 | # W-MSA/SW-MSA 540 | x = self.attn(x) # B, N, C 541 | 542 | x = shortcut + self.drop_path(x) 543 | 544 | # FFN 545 | if self.ffn_type == 'Mlp': 546 | x = x + self.drop_path(self.mlp(self.norm2(x))) 547 | else: 548 | # 使用 ConvNext block 替换 549 | shortcut2 = x 550 | x = rearrange(x, 'bs (h w) c -> bs c h w', bs=B, h=H, w=W).contiguous() # B C H W 551 | x = self.mlp(x) 552 | x = rearrange(x, 'bs c h w -> bs (h w) c').contiguous() # B N C 553 | x = self.drop_path(x) 554 | x = shortcut2+x 555 | return x 556 | 557 | 558 | class GEB(nn.Module): 559 | r""" Global Enhanced BottleNeck. 560 | 561 | Args: 562 | dim (int): Number of input channels. 563 | input_resolution (tuple[int]): Input resulotion. 这个 resolution 是 patch 的分辨率 564 | num_heads (int): Number of attention heads. 565 | patch_size (int): Patch size o 566 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 567 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 568 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 569 | drop (float, optional): Dropout rate. Default: 0.0 570 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 571 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 572 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 573 | norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm 574 | ffn_type (string): FFN module,using Convolution or MLP, Default: Mlp 575 | drop_key (bool, optional): Using dropkey or not. Default: False 576 | 577 | 输入的 x 为 (B,N,C) 表示,不再划分窗口 578 | """ 579 | 580 | def __init__(self, dim, input_resolution, num_heads, patch_size, depth, 581 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 582 | norm_layer=nn.LayerNorm, use_checkpoint=False, drop_key=False): 583 | super().__init__() 584 | self.use_checkpoint = use_checkpoint 585 | self.patch_size = patch_size 586 | self.blocks = nn.ModuleList([ 587 | Global_Enhanced_BottleNeck_Block(dim=dim, input_resolution=to_2tuple(input_resolution), 588 | num_heads=num_heads, 589 | mlp_ratio=mlp_ratio, 590 | qkv_bias=qkv_bias, qk_scale=qk_scale, 591 | drop=drop, attn_drop=attn_drop, 592 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 593 | norm_layer=norm_layer, drop_key=drop_key) 594 | for i in range(depth)]) 595 | self.pe_generator = PEG_Conv(dim) 596 | 597 | def forward(self, x): 598 | x = PatchDivide(x, self.patch_size) 599 | B_, N, C = x.shape 600 | x = rearrange(x, 'b (h w) c -> b c h w ', h=int(N**0.5), w=int(N**0.5)).contiguous() # 转换为图像表示 601 | pe = self.pe_generator(x) # 条件位置编码 602 | x = x+pe # 消融: PEG 603 | x = rearrange(x, ' b c h w -> b (h w) c').contiguous() # 转换成序列表示 604 | for blk in self.blocks: 605 | if self.use_checkpoint: 606 | x = checkpoint.checkpoint(blk, x) 607 | else: 608 | x = blk(x) 609 | x = PatchReverse(x, self.patch_size) 610 | return x 611 | 612 | 613 | class Swin_Transformer(nn.Module): 614 | """ A basic Swin Transformer layer for one stage. 615 | 616 | input: 617 | x: (B, N, C) 618 | 619 | Args: 620 | dim (int): Number of input channels. 621 | input_resolution (tuple[int]): Input resolution. 622 | depth (int): Number of blocks. 623 | num_heads (int): Number of attention heads. 624 | window_size (int): Local window size. 
625 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 626 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 627 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 628 | drop (float, optional): Dropout rate. Default: 0.0 629 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 630 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 631 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 632 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 633 | fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False 634 | drop_key (bool, optional): Using dropkey or not. Default: False 635 | """ 636 | 637 | def __init__(self, dim, input_resolution, depth, num_heads, window_size, 638 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 639 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False, 640 | fused_window_process=False, drop_key=False): 641 | 642 | super().__init__() 643 | self.dim = dim 644 | self.input_resolution = input_resolution 645 | self.depth = depth 646 | self.use_checkpoint = use_checkpoint 647 | 648 | # build blocks 649 | # 构建一个 Swin Transformer Block 650 | self.blocks = nn.ModuleList([ 651 | SwinTransformerBlock(dim=dim, input_resolution=input_resolution, 652 | num_heads=num_heads, window_size=window_size, 653 | shift_size=0 if (i % 2 == 0) else window_size // 2, # 消融:滑动窗口 654 | mlp_ratio=mlp_ratio, 655 | qkv_bias=qkv_bias, qk_scale=qk_scale, 656 | drop=drop, attn_drop=attn_drop, 657 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 658 | norm_layer=norm_layer, 659 | fused_window_process=fused_window_process, drop_key=drop_key) 660 | for i in range(depth)]) 661 | 662 | def forward(self, x): 663 | for blk in 
self.blocks: 664 | if self.use_checkpoint: 665 | x = checkpoint.checkpoint(blk, x) 666 | else: 667 | x = blk(x) 668 | return x 669 | 670 | 671 | class Channel_Adaptive_Transformer_Block(nn.Module): 672 | """Channel Adaptive Transformer Block (CATB). 673 | 674 | input: 675 | x: (B, C, H, W), feature map 676 | 677 | Args: 678 | dim (int): Number of input channels. 679 | input_resolution (int): Input resolution. based on patch 680 | num_heads (int): Number of attention heads. 681 | window_size (int): Local window size. 682 | patch_size (int): the size of the patch 683 | depth (int): Number of Swin Transformer blocks, default 2. 684 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 685 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 686 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 687 | drop (float, optional): Dropout rate. Default: 0.0 688 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 689 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 690 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 691 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 692 | fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False 693 | drop_key (bool, optional): Using dropkey or not. 
Default: False 694 | """ 695 | 696 | def __init__(self, dim, input_resolution, num_heads, window_size, patch_size, 697 | depth=2, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 698 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False, 699 | fused_window_process=False, drop_key=False): 700 | 701 | super().__init__() 702 | self.transformer_block = Swin_Transformer(dim, to_2tuple(input_resolution), depth, num_heads, window_size, mlp_ratio, qkv_bias, 703 | qk_scale, drop, attn_drop, drop_path, norm_layer, use_checkpoint, fused_window_process, drop_key=drop_key) 704 | self.CA = Channel_Adapter(dim//(patch_size**2), input_resolution*patch_size) 705 | self.patch_size = patch_size 706 | self.window_size = window_size 707 | self.input_resolution = input_resolution 708 | 709 | def forward(self, x): 710 | x = self.CA(x) # 消融: CA 711 | x = PatchDivide(x, self.patch_size) 712 | x = self.transformer_block(x) 713 | x = PatchReverse(x, self.patch_size) 714 | return x 715 | 716 | 717 | class CATB_Layer(nn.Module): 718 | """Channel Adaptive Transformer Block Layer, compose with numbers of CATBs. 719 | 720 | input: 721 | x: (B, C, H, W), feature map 722 | 723 | Args: 724 | dim (int): Number of input channels. 725 | input_resolution (int): Input resolution. based on patch 726 | num_heads (int): Number of attention heads. 727 | window_size (int): Local window size. 728 | patch_size (int): the size of the patch 729 | depth (int): Number of CATB, default 1. 730 | depth_tr (int): Number of Swin Transformer blocks in each CATB, default 2. 731 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 732 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 733 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 734 | drop (float, optional): Dropout rate. Default: 0.0 735 | attn_drop (float, optional): Attention dropout rate. 
Default: 0.0 736 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 737 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 738 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 739 | fused_window_process (bool, optional): If True, use one kernel to fused window shift & window partition for acceleration, similar for the reversed part. Default: False 740 | drop_key (bool, optional): Using dropkey or not. Default: False 741 | """ 742 | 743 | def __init__(self, dim, input_resolution, num_heads, window_size, patch_size, depth=1, 744 | depth_tr=2, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 745 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False, 746 | fused_window_process=False, drop_key=False): 747 | 748 | super().__init__() 749 | self.use_checkpoint = use_checkpoint 750 | self.blocks = nn.ModuleList([ 751 | Channel_Adaptive_Transformer_Block(dim, input_resolution, num_heads, window_size, patch_size, depth_tr, mlp_ratio, qkv_bias, 752 | qk_scale, drop, attn_drop, drop_path, norm_layer, use_checkpoint, fused_window_process, drop_key)for i in range(depth)]) 753 | 754 | def forward(self, x): 755 | for blk in self.blocks: 756 | if self.use_checkpoint: 757 | x = checkpoint.checkpoint(blk, x) 758 | else: 759 | x = blk(x) 760 | return x 761 | 762 | 763 | class StegFormer(nn.Module): 764 | """StegFormer 765 | 766 | Args: 767 | img_resolution (int): Resolution of images 768 | input_dim(int): Dim of the input 769 | output_dim (int): Dim of the finnal output 770 | cnn_emb_dim (int): Embedding dim using convolution 771 | output_act (nn.Module): The act function in the end of StegFormer, Default: None 772 | patch_size(int): the size of the patch 773 | num_heads (list): Number of attention heads. 
774 | window_size (int): Local window size 775 | depth (list): Number of CATB in each CATB Layer 776 | depth_tr (list): Number of the Swin Transformer Block in each CATB 777 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 778 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 779 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 780 | drop (float, optional): Dropout rate. Default: 0.0 781 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 782 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 783 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 784 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 785 | drop_key (bool, optional): Using dropkey or not. Default: False 786 | """ 787 | 788 | def __init__(self, img_resolution, input_dim=3, output_dim=3, cnn_emb_dim=16, output_act=None, patch_size=2, num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2], window_size=8, 789 | depth=[1, 1, 1, 1, 2, 1, 1, 1, 1], depth_tr=[2, 2, 2, 2, 2, 2, 2, 2], mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 790 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False, drop_key=False): 791 | super().__init__() 792 | self.dim = cnn_emb_dim 793 | self.token_dim = (patch_size**2)*cnn_emb_dim 794 | self.embedding = nn.Conv2d(input_dim, self.dim, 3, 1, 1) 795 | self.patch_size = patch_size 796 | self.patch_resolution = img_resolution//patch_size 797 | if output_act: 798 | self.output_act = output_act() 799 | else: 800 | self.output_act = None 801 | 802 | # encoder 803 | self.encoderlayer_0 = CATB_Layer(self.token_dim, self.patch_resolution, num_heads[0], 804 | window_size, patch_size, depth[0], depth_tr[0], mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, norm_layer, use_checkpoint, drop_key=drop_key) 805 | self.downsampler_0 = DownSampler(self.dim, self.dim*2) 
806 | 807 | self.encoderlayer_1 = CATB_Layer(self.token_dim*2, self.patch_resolution//(2**1), 808 | num_heads[1], window_size, patch_size, depth[1], depth_tr[1], mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, norm_layer, use_checkpoint, drop_key=drop_key) 809 | self.downsampler_1 = DownSampler(self.dim*2, self.dim*4) 810 | 811 | self.encoderlayer_2 = CATB_Layer(self.token_dim*4, self.patch_resolution//(2**2), 812 | num_heads[2], window_size, patch_size, depth[2], depth_tr[2], mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, norm_layer, use_checkpoint, drop_key=drop_key) 813 | self.downsampler_2 = DownSampler(self.dim*4, self.dim*8) 814 | 815 | self.encoderlayer_3 = CATB_Layer(self.token_dim*8, self.patch_resolution//(2**3), 816 | num_heads[3], window_size, patch_size, depth[3], depth_tr[3], mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, norm_layer, use_checkpoint, drop_key=drop_key) 817 | self.downsampler_3 = DownSampler(self.dim*8, self.dim*16) 818 | 819 | # bottleneck 820 | self.bottleneck = GEB(self.token_dim*16, self.patch_resolution//(2**4), 821 | num_heads[4], patch_size, depth[4], mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, norm_layer, use_checkpoint, drop_key=drop_key) 822 | 823 | # decoder 824 | self.upsampler_0 = UpSampler(self.dim*16, self.dim*8) 825 | self.decoderlayer_0 = CATB_Layer(self.token_dim*16, self.patch_resolution//(2**3), 826 | num_heads[5], window_size, patch_size, depth[5], depth_tr[4], mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, norm_layer, use_checkpoint, drop_key=drop_key) 827 | 828 | self.upsampler_1 = UpSampler(self.dim*16, self.dim*4) 829 | self.decoderlayer_1 = CATB_Layer(self.token_dim*8, self.patch_resolution//(2**2), 830 | num_heads[6], window_size, patch_size, depth[6], depth_tr[5], mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, norm_layer, use_checkpoint, drop_key=drop_key) 831 | 832 | self.upsampler_2 = UpSampler(self.dim*8, self.dim*2) 833 | 
self.decoderlayer_2 = CATB_Layer(self.token_dim*4, self.patch_resolution//(2**1), 834 | num_heads[7], window_size, patch_size, depth[7], depth_tr[6], mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, norm_layer, use_checkpoint, drop_key=drop_key) 835 | 836 | self.upsampler_3 = UpSampler(self.dim*4, self.dim) 837 | self.decoderlayer_3 = CATB_Layer(self.token_dim*2, self.patch_resolution, 838 | num_heads[8], window_size, patch_size, depth[8], depth_tr[7], mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, norm_layer, use_checkpoint, drop_key=drop_key) 839 | 840 | self.output_proj = nn.Conv2d(self.dim*2, output_dim, 3, 1, 1) 841 | 842 | def forward(self, x): 843 | x = self.embedding(x) 844 | 845 | # encode 846 | conv0 = self.encoderlayer_0(x) 847 | pool0 = self.downsampler_0(conv0) 848 | 849 | conv1 = self.encoderlayer_1(pool0) 850 | pool1 = self.downsampler_1(conv1) 851 | 852 | conv2 = self.encoderlayer_2(pool1) 853 | pool2 = self.downsampler_2(conv2) 854 | 855 | conv3 = self.encoderlayer_3(pool2) 856 | pool3 = self.downsampler_3(conv3) 857 | 858 | # bottle neck 859 | bottle = self.bottleneck(pool3) 860 | 861 | # decode 862 | up0 = self.upsampler_0(bottle) 863 | deconv0 = torch.cat([up0, conv3], 1) 864 | deconv0 = self.decoderlayer_0(deconv0) 865 | 866 | up1 = self.upsampler_1(deconv0) 867 | deconv1 = torch.cat([up1, conv2], 1) 868 | deconv1 = self.decoderlayer_1(deconv1) 869 | 870 | up2 = self.upsampler_2(deconv1) 871 | deconv2 = torch.cat([up2, conv1], 1) 872 | deconv2 = self.decoderlayer_2(deconv2) 873 | 874 | up3 = self.upsampler_3(deconv2) 875 | deconv3 = torch.cat([up3, conv0], 1) 876 | deconv3 = self.decoderlayer_3(deconv3) 877 | 878 | output = self.output_proj(deconv3) 879 | if self.output_act: 880 | output = self.output_act(output) 881 | return output 882 | --------------------------------------------------------------------------------