├── LICENSE ├── README.md ├── config ├── __init__.py ├── config.py └── loader.py ├── models ├── __init__.py ├── layers.py └── net.py ├── sample └── rgb.png ├── test.py ├── test └── McM18 │ ├── 1.tif │ ├── 10.tif │ ├── 11.tif │ ├── 12.tif │ ├── 13.tif │ ├── 14.tif │ ├── 15.tif │ ├── 16.tif │ ├── 17.tif │ ├── 18.tif │ ├── 2.tif │ ├── 3.tif │ ├── 4.tif │ ├── 5.tif │ ├── 6.tif │ ├── 7.tif │ ├── 8.tif │ └── 9.tif ├── train.py ├── train.sh └── trained_models ├── 1 └── state_dict.pth ├── 4 └── state_dict.pth ├── 10 └── state_dict.pth └── 25 └── state_dict.pth /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 ICSResearch 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TCS-Net 2 | This repository is the PyTorch code for the paper "From Patch to Pixel: A Transformer-based Hierarchical Framework for Compressive Image Sensing". 3 | ## 1. Introduction ## 4 | **1) Datasets** 5 | 6 | Training set: [`BSDS500`](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html), testing sets: [`McM18`](https://www4.comp.polyu.edu.hk/~cslzhang/CDM_Dataset.html), [`LIVE29`](http://live.ece.utexas.edu/research/Quality/), [`General100`](http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html) and [`OST300`](http://mmlab.ie.cuhk.edu.hk/projects/SFTGAN/). 7 | 8 | **2) Project structure** 9 | ``` 10 | (TCS-Net) 11 | |-dataset 12 | | |-train 13 | | |-BSDS500 (.jpg) 14 | | |-test 15 | | |-McM18 16 | | |-LIVE29 17 | | |-General100 18 | | |-OST300 19 | |-reconstructed_images 20 | | |-McM18 21 | | |-grey 22 | | |-... (Testing results .png) 23 | | |-rgb 24 | | |-... (Testing results .png) 25 | | |-... (Testing sets) 26 | | |-Res_(...).txt 27 | |-models 28 | | |-__init__.py 29 | | |-net.py 30 | | |-layers.py 31 | |-trained_models 32 | | |-1 33 | | |-4 34 | | |-... (Sampling rates) 35 | |-config 36 | | |-__init__.py 37 | | |-config.py 38 | | |-loader.py 39 | |-test.py 40 | |-train.py 41 | |-train.sh 42 | ``` 43 | 44 | **3) Competing methods** 45 | 46 | |Methods|Sources|Year| 47 | |:----|:----|:----| 48 | | ![ReconNet](https://latex.codecogs.com/svg.image?\textbf{ReconNet})| [Conf. Comput. Vis. Pattern Recog.](https://ieeexplore.ieee.org/document/7780424/) | 2016 |
49 | | ![LDIT](https://latex.codecogs.com/svg.image?\textbf{LDIT}) | [Proc. Adv. Neural Inf. Process. Syst.](https://dl.acm.org/doi/10.5555/3294771.3294940) | 2017 | 50 | | ![LDAMP](https://latex.codecogs.com/svg.image?\textbf{LDAMP}) | [Proc. Adv. Neural Inf. Process. Syst.](https://dl.acm.org/doi/10.5555/3294771.3294940) | 2017 | 51 | | ![ISTA-Net (plus)](https://latex.codecogs.com/svg.image?\textbf{ISTA-Net}^{+}) | [Conf. Comput. Vis. Pattern Recog.](https://ieeexplore.ieee.org/document/8578294) | 2018 | 52 | | ![CSGAN](https://latex.codecogs.com/svg.image?\textbf{CSGAN}) | [Proc. Int. Conf. Mach. Learn.](http://proceedings.mlr.press/v97/wu19d.html) | 2019 | 53 | | ![CSNet (plus)](https://latex.codecogs.com/svg.image?\textbf{CSNet}^{+}) | [Trans. Image Process.](https://ieeexplore.ieee.org/document/8765626/) | 2020 | 54 | | ![AMP-Net](https://latex.codecogs.com/svg.image?\textbf{AMP-Net}) | [Trans. Image Process.](https://ieeexplore.ieee.org/document/9298950) | 2021 | 55 | |CSformer| arXiv | 2022 | 56 | 57 | 58 | **4) Performance demonstration** 59 | 60 | Visual comparisons of reconstructed images (original images are drawn from the `LIVE29` dataset): 61 | 62 |
63 | 64 | ## 2. Usage ## 65 | **1) Re-training TCS-Net.** 66 | 67 | * Put the `BSDS500` and `VOC2012` images into `./dataset/train/`. 68 | * e.g., if you want to train TCS-Net at sampling rate `τ = 0.1` on GPU No. 0, run the following command. The training set will be packaged automatically, and the model will be trained with its default parameters (please make sure you have enough GPU RAM): 69 | ``` 70 | python train.py --rate 0.1 --device 0 71 | ``` 72 | * You can also run our shell script directly; it will train the model at all sampling rates, i.e., `τ ∈ {0.01, 0.04, 0.1, 0.25}`: 73 | ``` 74 | sh train.sh 75 | ``` 76 | * The trained models (.pth) will be saved in the `trained_models` folder. 77 | 78 | **2) Testing TCS-Net.** 79 | * We provide trained models that you can put under `TCS-Net/trained_models/` and use for testing directly; all trained TCS-Net models can be found at this [GoogleDrive link](https://drive.google.com/drive/folders/15dRG29V51i8rVraz8TkHtev7N3jLkx0U?usp=sharing). Please note that the folder names are the sampling rates multiplied by 100, e.g., the folder named `10` contains the model trained at sampling rate 0.1. 80 | 81 | * Put the testing folders into `./dataset/test/`. 82 | * e.g., if you want to test TCS-Net at sampling rate τ = 0.1 on GPU No. 0, run: 83 | ``` 84 | python test.py --rate 0.1 --device 0 85 | ``` 86 | * After that, the reconstructed images and the PSNR/SSIM results will be saved to `./reconstructed_images/`. 87 | ## End ## 88 | 89 | We appreciate your reading and attention. For more details about TCS-Net, please refer to our paper. 90 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import para 2 | from .loader import train_loader 3 | # from .loader import TrainDatasetFromFolder  # TrainDatasetFromFolder is not defined in loader.py 4 | -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | parser = argparse.ArgumentParser(description="Args of this repo.") 5 | parser.add_argument("--rate", default=0.1, type=float) 6 | parser.add_argument("--lr", default=0.001, type=float) 7 | parser.add_argument("--device", default="0") 8 | parser.add_argument("--time", default=0, type=int) 9 | parser.add_argument("--block_size", default=96, type=int) 10 | parser.add_argument("--batch_size", default=32, type=int) 11 | 12 | parser.add_argument("--save", default=False) 13 | parser.add_argument("--manner", default="grey") 14 | 15 | parser.add_argument("--save_path", default="./trained_models") 16 | parser.add_argument("--folder") 17 | parser.add_argument("--my_state_dict") 18 | parser.add_argument("--my_log") 19 | parser.add_argument("--my_info") 20 | para = parser.parse_args() 21 | para.device = f"cuda:{para.device}" 22 | para.folder = f"{para.save_path}/{str(int(para.rate * 100))}/" 23 | para.my_state_dict = f"{para.folder}/state_dict.pth" 24 | para.my_log = f"{para.folder}/log.txt" 25 | para.my_info = f"{para.folder}/info.pth" 26 | --------------------------------------------------------------------------------
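A minimal sketch of the values these arguments resolve to at import time (assuming the defaults above with `--rate 0.1`):

```python
import config  # parse_args() runs at import time; defaults apply if no CLI flags are given

# With --rate 0.1 and --device "0":
config.para.device          # "cuda:0"
config.para.folder          # "./trained_models/10/"
config.para.my_state_dict   # "./trained_models/10//state_dict.pth"
config.para.my_log          # "./trained_models/10//log.txt"
```

The doubled slash comes from joining `para.folder`, which already ends in `/`, with another `/`; it is harmless on common filesystems. This derivation (`int(rate * 100)`) is also why the pretrained-model folders are named `1`, `4`, `10`, and `25`.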
/config/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import time 4 | import copy 5 | import torch 6 | import random 7 | import numpy as np 8 | import torchvision 9 | from PIL import Image 10 | from tqdm import tqdm 11 | import torch.utils.data as data 12 | 13 | import config 14 | 15 | 16 | class TrainDataPackage: 17 | r""" 18 | Packages the images from BSD500/train and BSD500/test (400 images in total) into a *.pth file. 19 | Each source image (*.jpg) is augmented into `self.num` random patches (tensors) via random cropping, 20 | random vertical/horizontal flips, and grayscale conversion. 21 | """ 22 | 23 | def __init__(self, root="./dataset", transform=None, packaged=True): 24 | self.training_file = "train.pt" 25 | # self.aug = DataAugment(debug=True)  # DataAugment is not defined anywhere in this repo 26 | self.packaged = packaged 27 | self.root = root 28 | self.num = 50  # patches per image 29 | self.transform = transform or torchvision.transforms.Compose([ 30 | torchvision.transforms.ToTensor(), 31 | torchvision.transforms.RandomVerticalFlip(), 32 | torchvision.transforms.RandomHorizontalFlip(), 33 | torchvision.transforms.Grayscale(num_output_channels=1), 34 | ]) 35 | 36 | if not os.path.exists(os.path.join(self.root, self.training_file)): 37 | print("No packaged dataset file (*.pt) in dataset/, now generating...") 38 | self.generate() 39 | 40 | if packaged: 41 | self.train_data = torch.load(os.path.join(self.root, self.training_file)) 42 | 43 | def __len__(self): 44 | return len(self.train_data) 45 | 46 | def __getitem__(self, index): 47 | img = self.train_data[index] 48 | return img 49 | 50 | def generate(self): 51 | paths = [ 52 | os.path.join(self.root, "BSD500/"), 53 | os.path.join(self.root, "VOCdevkit/VOC2012/JPEGImages/"), 54 | ] 55 | patches_list = [] 56 | 57 | start = time.time() 58 | for path in paths: 59 | for roots, dirs, files in os.walk(path): 60 | if roots == path: 61 | print("Image number: {}".format(len(files))) 62 | for file in tqdm(files): 63 | if file.split('.')[-1] in ("jpg", "png", "tif", "bmp"): 64 | temp = os.path.join(path, file) 65 | tqdm.write("\r=> Processing " + temp) 66 | image = cv2.imread(temp) 67 | patches = self.random_patch(image, self.num) 68 | patches_list.extend(patches) 69 | print("Total patches: {}".format(len(patches_list))) 70 | print("Now Packaging...") 71 | with open(os.path.join(self.root, "train.pt"), 'wb') as f: 72 | torch.save(patches_list, f) 73 | end = time.time() 74 | print("Successfully packaged! Time used: {:.3f}s".format(end - start)) 75 | 76 | def random_patch(self, image, num): 77 | size = config.para.block_size 78 | image = np.array(image, dtype=np.float32) / 255.
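# Draw `num` random size x size crops at uniformly random offsets; each crop then passes through
# the transform chain defined in __init__ (ToTensor, random flips, grayscale).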
79 | h, w = image.shape[0], image.shape[1] 80 | if h <= config.para.block_size or w <= config.para.block_size: 81 | return [] 82 | patches = [] 83 | for n in range(num): 84 | max_h = random.randint(0, h - size) 85 | max_w = random.randint(0, w - size) 86 | patch = image[max_h:max_h + size, max_w:max_w + size, :] 87 | patch = self.transform(patch) 88 | patches.append(patch) 89 | return patches 90 | 91 | 92 | def train_loader(): 93 | train_dataset = TrainDataPackage() 94 | dst = torch.utils.data.DataLoader(train_dataset, batch_size=config.para.batch_size, drop_last=True, shuffle=True, 95 | pin_memory=True, num_workers=8, prefetch_factor=8) 96 | return dst 97 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .net import TCS_Net 2 | from .layers import OneUnit, Units 3 | # from .modules import TransformerEncoder, TransformerDecoder  # modules.py is not included in this repo 4 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | from einops.layers.torch import Rearrange 5 | 6 | 7 | class FeedForward(nn.Module): 8 | def __init__(self, dim, hidden_dim, dropout=0.): 9 | super().__init__() 10 | self.net = nn.Sequential( 11 | nn.Linear(dim, hidden_dim), 12 | nn.ReLU(), 13 | nn.Dropout(dropout), 14 | nn.Linear(hidden_dim, dim), 15 | ) 16 | 17 | def forward(self, x): 18 | return self.net(x) 19 | 20 | 21 | class ScaledDotProduct(nn.Module): 22 | def __init__(self): 23 | super(ScaledDotProduct, self).__init__() 24 | self.softmax = nn.Softmax(dim=-1) 25 | 26 | def forward(self, q, k, v, scale=None): 27 | attention = torch.bmm(q, k.transpose(-2, -1)) 28 | if scale: 29 | attention = attention * scale 30 | attention = self.softmax(attention) 31 | context = torch.bmm(attention, v) 32 | return context, attention 33 | 34 | 35 | class MultiHeadAttention(nn.Module): 36 | def __init__(self, model_dim=16 ** 2, num_heads=8, dropout=0., out_dim=None): 37 | super(MultiHeadAttention, self).__init__() 38 | self.dim_per_head = model_dim // num_heads 39 | self.num_heads = num_heads 40 | self.linear_k = nn.Linear(model_dim, model_dim) 41 | self.linear_v = nn.Linear(model_dim, model_dim) 42 | self.linear_q = nn.Linear(model_dim, model_dim) 43 | 44 | self.dot_product_attention = ScaledDotProduct() 45 | self.linear_out = nn.Linear(model_dim, out_dim if out_dim is not None else model_dim) 46 | self.dropout = nn.Dropout(dropout) 47 | 48 | def _reshape_to_heads(self, x): 49 | batch_size, seq_len, in_feature = x.size() 50 | sub_dim = in_feature // self.num_heads 51 | return x.reshape(batch_size, seq_len, self.num_heads, sub_dim) \ 52 | .permute(0, 2, 1, 3) \ 53 | .reshape(batch_size * self.num_heads, seq_len, sub_dim) 54 | 55 | def _reshape_from_heads(self, x): 56 | batch_size, seq_len, in_feature = x.size() 57 | batch_size //= self.num_heads 58 | out_dim = in_feature * self.num_heads 59 | return x.reshape(batch_size, self.num_heads, seq_len, in_feature) \ 60 | .permute(0, 2, 1, 3) \ 61 | .reshape(batch_size, seq_len, out_dim) 62 | 63 | def forward(self, x): 64 | key = self._reshape_to_heads(self.linear_k(x)) 65 | value = self._reshape_to_heads(self.linear_v(x)) 66 | query = self._reshape_to_heads(self.linear_q(x)) 67 | 68 | scale = key.size(-1) ** -0.5 69 | context, attention = self.dot_product_attention(query, key, value, scale) 70 | 71 |
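# Merge the heads back: (batch*heads, seq_len, dim_per_head) -> (batch, seq_len, heads*dim_per_head).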
context = self._reshape_from_heads(context) 72 | 73 | output = self.linear_out(context) 74 | output = self.dropout(output) 75 | return output, attention 76 | 77 | 78 | class OneUnit(nn.Module): 79 | def __init__(self, dim=16 ** 2, heads=8, dropout=0., out_dim=None): 80 | super().__init__() 81 | self.dim = dim 82 | self.flag = out_dim 83 | # self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout) 84 | self.attn = MultiHeadAttention(model_dim=dim, num_heads=heads, dropout=dropout, out_dim=out_dim) 85 | self.ffn = FeedForward(dim=dim if out_dim is None else out_dim, hidden_dim=int(dim * 4), dropout=dropout) 86 | self.norm1 = nn.LayerNorm(dim if out_dim is None else out_dim) 87 | self.norm2 = nn.LayerNorm(dim if out_dim is None else out_dim) 88 | 89 | def forward(self, x): 90 | context, attn = self.attn(x) 91 | x = self.norm1(context + x) if self.flag is None else self.norm1(context) 92 | ffn = self.ffn(x) 93 | x = self.norm2(ffn + x) 94 | return x, attn 95 | 96 | 97 | class Units(nn.Module): 98 | def __init__(self, dim=16, depth=1, heads=8, dropout=0., out_dim=None): 99 | super().__init__() 100 | self.patch_size = dim 101 | self.dim = self.patch_size ** 2 102 | self.depth = depth 103 | self.layers = nn.ModuleList() 104 | for _ in range(depth): 105 | self.layers.append(OneUnit(dim=self.dim, heads=heads, dropout=dropout, out_dim=out_dim)) 106 | 107 | def forward(self, x): 108 | x = Rearrange('b l h w -> b l (h w)')(x) 109 | x = x * math.sqrt(self.dim) 110 | attn_maps = [] 111 | for i in range(self.depth): 112 | x, attn = self.layers[i](x) 113 | attn_maps.append(attn) 114 | x = Rearrange('b l (h w)-> b l h w', h=self.patch_size, w=self.patch_size)(x) 115 | return x, attn_maps 116 | -------------------------------------------------------------------------------- /models/net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | 5 | import config 6 | import models 7 | 8 | 9 | class TCS_Net(nn.Module): 10 | def __init__(self): 11 | super(TCS_Net, self).__init__() 12 | self.block_size = config.para.block_size 13 | 14 | self.NUM_UNITS_PATCH = 6 15 | self.NUM_UNITS_PIXEL = 1 16 | 17 | self.DIM = 32 18 | self.PIXEL_EMBED = 32 19 | self.PATCH_DIM = 8 # todo 20 | self.idx = config.para.block_size // self.DIM 21 | 22 | # pixel-wise 23 | pixel_embedding = np.random.normal(0.0, (1 / self.PIXEL_EMBED) ** 0.5, size=(1, self.PIXEL_EMBED)) 24 | self.pixel_embedding = nn.Parameter(torch.from_numpy(pixel_embedding).float(), requires_grad=True) 25 | self.transform_pixel_wise = nn.ModuleList() 26 | for i in range(self.NUM_UNITS_PIXEL): 27 | self.transform_pixel_wise.append(models.OneUnit(dim=self.PIXEL_EMBED, dropout=0.5, heads=1)) 28 | pixel_detaching = np.random.normal(0.0, (1 / self.PIXEL_EMBED) ** 0.5, size=(self.PIXEL_EMBED, 1)) 29 | self.pixel_detaching = nn.Parameter(torch.from_numpy(pixel_detaching).float(), requires_grad=True) 30 | 31 | # patch-wise 32 | self.transform_patch_wise = models.Units(dim=self.PATCH_DIM, depth=self.NUM_UNITS_PATCH, dropout=0.5, heads=8) 33 | # self.transform_patch_wise = PatchTransform( 34 | # dim=self.PATCH_DIM ** 2, num_layers=self.NUM_UNITS_PATCH, drop_out=0.5) 35 | 36 | # sampling and init recon 37 | points = self.DIM ** 2 38 | p_init = np.random.normal(0.0, (1 / points) ** 0.5, size=(points, int(config.para.rate * points))) 39 | self.P = nn.Parameter(torch.from_numpy(p_init).float(), requires_grad=True) 40 | self.R = 
nn.Parameter(torch.from_numpy(np.transpose(p_init)).float(), requires_grad=True) 41 | 42 | w_init = 1e-3 * np.ones(self.DIM) 43 | self.w = nn.Parameter(torch.from_numpy(w_init).float(), requires_grad=True) 44 | 45 | def forward(self, inputs): 46 | batch_size = inputs.size(0) 47 | y = self.sampling(inputs) 48 | recon, tmp = self.recon(y, batch_size) 49 | return recon, tmp 50 | 51 | def sampling(self, inputs): 52 | inputs = inputs.to(config.para.device) 53 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=self.DIM, dim=3), dim=0) 54 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=self.DIM, dim=2), dim=0) 55 | inputs = inputs.reshape(-1, self.DIM ** 2) 56 | y = torch.matmul(inputs, self.P.to(config.para.device)) 57 | return y 58 | 59 | def recon(self, y, batch_size): 60 | # init 61 | init = torch.matmul(y, self.R.to(config.para.device)).reshape(-1, 1, self.DIM, self.DIM) 62 | 63 | recon = torch.cat(torch.split(init, split_size_or_sections=batch_size * self.idx, dim=0), dim=2) 64 | recon = torch.cat(torch.split(recon, split_size_or_sections=batch_size, dim=0), dim=3) 65 | 66 | # patch 67 | recon = self.block2patch(recon, self.PATCH_DIM) 68 | patch = self.transform_patch_wise(recon) 69 | recon = recon - patch[0] 70 | recon = self.patch2block(recon, self.block_size, self.PATCH_DIM) 71 | 72 | recon = torch.cat(torch.split(recon, split_size_or_sections=self.DIM, dim=3), dim=0) 73 | recon = torch.cat(torch.split(recon, split_size_or_sections=self.DIM, dim=2), dim=0).reshape(-1, self.DIM ** 2, 1) 74 | 75 | # pixel 76 | recon_pixel = torch.matmul(recon, self.pixel_embedding) 77 | for i in range(self.NUM_UNITS_PIXEL): 78 | recon_pixel, _ = self.transform_pixel_wise[i](recon_pixel) 79 | recon_pixel = torch.matmul(recon_pixel, self.pixel_detaching) 80 | recon = recon.reshape(-1, 1, self.DIM, self.DIM) - recon_pixel.reshape(-1, 1, self.DIM, self.DIM) * self.w 81 | 82 | recon = torch.cat(torch.split(recon, split_size_or_sections=batch_size * self.idx, dim=0), dim=2) 83 | recon = torch.cat(torch.split(recon, split_size_or_sections=batch_size, dim=0), dim=3) 84 | 85 | return recon, None 86 | 87 | @staticmethod 88 | def block2patch(inputs, size): 89 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=size, dim=3), dim=1) 90 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=size, dim=2), dim=1) 91 | return inputs 92 | 93 | @staticmethod 94 | def patch2block(inputs, block_size, patch_size, ori_channel=1): 95 | assert block_size % patch_size == 0, f"block size {block_size} must be divisible by patch size {patch_size}."
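# Inverse of block2patch: fold the patch-stacked channel dimension back into spatial rows, then columns.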
96 | idx = int(block_size / patch_size) 97 | outputs = torch.cat(torch.split(inputs, split_size_or_sections=ori_channel * idx, dim=1), dim=2) 98 | outputs = torch.cat(torch.split(outputs, split_size_or_sections=ori_channel, dim=1), dim=3) 99 | return outputs 100 | -------------------------------------------------------------------------------- /sample/rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/sample/rgb.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import cv2 as cv 4 | import numpy as np 5 | from skimage.metrics import structural_similarity as SSIM 6 | from skimage.metrics import peak_signal_noise_ratio as PSNR 7 | 8 | import config 9 | import models 10 | 11 | 12 | def save_log(recon_root, name_dataset, name_image, psnr, ssim, manner, consecutive=True): 13 | if not os.path.isfile(f"{recon_root}/Res_{name_dataset}_{manner}.txt"): 14 | log = open(f"{recon_root}/Res_{name_dataset}_{manner}.txt", 'w') 15 | log.write("=" * 120 + "\n") 16 | log.close() 17 | log = open(f"{recon_root}/Res_{name_dataset}_{manner}.txt", 'r+') 18 | if consecutive: 19 | old = log.read() 20 | log.seek(0) 21 | log.write(old) 22 | log.write( 23 | f"Res {name_image}: PSNR, {round(psnr, 2)}, SSIM, {round(ssim, 4)}\n") 24 | log.close() 25 | 26 | 27 | def testing(network, val, manner=config.para.manner, save_img=config.para.save): 28 | """ 29 | The pre-processing before TCS-Net's forward propagation and the testing platform. 30 | """ 31 | recon_root = "reconstructed_images" 32 | if not os.path.isdir(recon_root): 33 | os.mkdir(recon_root) 34 | datasets = ["Set11"] if val else ["McM18", "LIVE29", "General100", "OST300"] # Names of folders (testing datasets) 35 | with torch.no_grad(): 36 | for one_dataset in datasets: 37 | if not os.path.isdir(f"{recon_root}/{one_dataset}"): 38 | os.mkdir(f"{recon_root}/{one_dataset}") 39 | 40 | test_dataset_path = f"dataset/test/{one_dataset}" 41 | 42 | # Grey manner 43 | if manner == "grey": 44 | recon_dataset_path_grey = f"{recon_root}/{one_dataset}/grey/" 45 | recon_dataset_path_grey_rate = f"{recon_root}/{one_dataset}/grey/{config.para.rate}" 46 | if not os.path.isdir(recon_dataset_path_grey): 47 | os.mkdir(recon_dataset_path_grey) 48 | if not os.path.isdir(recon_dataset_path_grey_rate): 49 | os.mkdir(recon_dataset_path_grey_rate) 50 | sum_psnr, sum_ssim = 0., 0. 
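# Each image below is zero-padded up to a multiple of block_size, split into block_size x block_size
# tiles for the network, then re-tiled and cropped back to the original h x w before scoring.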
51 | for _, _, images in os.walk(f"{test_dataset_path}/rgb/"): 52 | for one_image in images: 53 | name_image = one_image.split('.')[0] 54 | x = cv.imread(f"{test_dataset_path}/rgb/{one_image}", flags=cv.IMREAD_GRAYSCALE) 55 | x_ori = x 56 | x = torch.from_numpy(x / 255.).float() 57 | h, w = x.size() 58 | 59 | lack = config.para.block_size - h % config.para.block_size if h % config.para.block_size != 0 else 0 60 | padding_h = torch.zeros(lack, w) 61 | expand_h = h + lack 62 | inputs = torch.cat((x, padding_h), 0) 63 | 64 | lack = config.para.block_size - w % config.para.block_size if w % config.para.block_size != 0 else 0 65 | expand_w = w + lack 66 | padding_w = torch.zeros(expand_h, lack) 67 | inputs = torch.cat((inputs, padding_w), 1).unsqueeze(0).unsqueeze(0) 68 | 69 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=config.para.block_size, dim=3), dim=0) 70 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=config.para.block_size, dim=2), dim=0) 71 | 72 | reconstruction, _ = network(inputs) 73 | 74 | idx = expand_w // config.para.block_size 75 | reconstruction = torch.cat(torch.split(reconstruction, split_size_or_sections=1 * idx, dim=0), dim=2) 76 | reconstruction = torch.cat(torch.split(reconstruction, split_size_or_sections=1, dim=0), dim=3) 77 | reconstruction = reconstruction.squeeze()[:h, :w] 78 | 79 | x_hat = reconstruction.cpu().numpy() * 255. 80 | x_hat = np.rint(np.clip(x_hat, 0, 255)) 81 | 82 | psnr = PSNR(x_ori, x_hat, data_range=255) 83 | ssim = SSIM(x_ori, x_hat, data_range=255, multichannel=False) 84 | 85 | sum_psnr += psnr 86 | sum_ssim += ssim 87 | 88 | if save_img: 89 | cv.imwrite(f"{recon_dataset_path_grey_rate}/{name_image}.png", x_hat) 90 | save_log(recon_root, one_dataset, name_image, psnr, ssim, f"_{config.para.rate}_{manner}") 91 | save_log(recon_root, one_dataset, None, 92 | sum_psnr / len(images), sum_ssim / len(images), f"_{config.para.rate}_{manner}_AVG", False) 93 | print( 94 | f"AVG RES: PSNR, {round(sum_psnr / len(images), 2)}, SSIM, {round(sum_ssim / len(images), 4)}") 95 | if val: 96 | return round(sum_psnr / len(images), 2), round(sum_ssim / len(images), 4) 97 | 98 | # RGB manner 99 | elif manner == "rgb": 100 | recon_dataset_path_rgb = f"{recon_root}/{one_dataset}/rgb/" 101 | recon_dataset_path_rgb_rate = f"{recon_root}/{one_dataset}/rgb/{config.para.rate}" 102 | if not os.path.isdir(recon_dataset_path_rgb): 103 | os.mkdir(recon_dataset_path_rgb) 104 | if not os.path.isdir(recon_dataset_path_rgb_rate): 105 | os.mkdir(recon_dataset_path_rgb_rate) 106 | sum_psnr, sum_ssim = 0., 0. 107 | for _, _, images in os.walk(f"{test_dataset_path}/rgb/"): 108 | for one_image in images: 109 | name_image = one_image.split('.')[0] 110 | x = cv.imread(f"{test_dataset_path}/rgb/{one_image}") 111 | x_ori = x 112 | r, g, b = cv.split(x) 113 | r = torch.from_numpy(np.asarray(r)).squeeze().float() / 255. 114 | g = torch.from_numpy(np.asarray(g)).squeeze().float() / 255. 115 | b = torch.from_numpy(np.asarray(b)).squeeze().float() / 255. 
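# Note: cv.imread returns channels in BGR order, so the r/g/b names above are nominal; the
# per-channel pipeline is symmetric, and cv.merge below preserves the original order for cv.imwrite.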
116 | 117 | x = torch.from_numpy(x).float() 118 | h, w = x.size()[0], x.size()[1] 119 | 120 | lack = config.para.block_size - h % config.para.block_size if h % config.para.block_size != 0 else 0 121 | padding_h = torch.zeros(lack, w) 122 | expand_h = h + lack 123 | inputs_r = torch.cat((r, padding_h), 0) 124 | inputs_g = torch.cat((g, padding_h), 0) 125 | inputs_b = torch.cat((b, padding_h), 0) 126 | 127 | lack = config.para.block_size - w % config.para.block_size if w % config.para.block_size != 0 else 0 128 | expand_w = w + lack 129 | padding_w = torch.zeros(expand_h, lack) 130 | inputs_r = torch.cat((inputs_r, padding_w), 1).unsqueeze(0).unsqueeze(0) 131 | inputs_g = torch.cat((inputs_g, padding_w), 1).unsqueeze(0).unsqueeze(0) 132 | inputs_b = torch.cat((inputs_b, padding_w), 1).unsqueeze(0).unsqueeze(0) 133 | 134 | inputs_r = torch.cat(torch.split(inputs_r, split_size_or_sections=config.para.block_size, dim=3), 135 | dim=0) 136 | inputs_r = torch.cat(torch.split(inputs_r, split_size_or_sections=config.para.block_size, dim=2), 137 | dim=0) 138 | 139 | inputs_g = torch.cat(torch.split(inputs_g, split_size_or_sections=config.para.block_size, dim=3), 140 | dim=0) 141 | inputs_g = torch.cat(torch.split(inputs_g, split_size_or_sections=config.para.block_size, dim=2), 142 | dim=0) 143 | 144 | inputs_b = torch.cat(torch.split(inputs_b, split_size_or_sections=config.para.block_size, dim=3), 145 | dim=0) 146 | inputs_b = torch.cat(torch.split(inputs_b, split_size_or_sections=config.para.block_size, dim=2), 147 | dim=0) 148 | 149 | r_hat, _ = network(inputs_r.to(config.para.device)) 150 | g_hat, _ = network(inputs_g.to(config.para.device)) 151 | b_hat, _ = network(inputs_b.to(config.para.device)) 152 | 153 | idx = expand_w // config.para.block_size 154 | r_hat = torch.cat(torch.split(r_hat, split_size_or_sections=1 * idx, dim=0), dim=2) 155 | r_hat = torch.cat(torch.split(r_hat, split_size_or_sections=1, dim=0), dim=3) 156 | r_hat = r_hat.squeeze()[:h, :w].cpu().numpy() * 255. 157 | 158 | g_hat = torch.cat(torch.split(g_hat, split_size_or_sections=1 * idx, dim=0), dim=2) 159 | g_hat = torch.cat(torch.split(g_hat, split_size_or_sections=1, dim=0), dim=3) 160 | g_hat = g_hat.squeeze()[:h, :w].cpu().numpy() * 255. 161 | 162 | b_hat = torch.cat(torch.split(b_hat, split_size_or_sections=1 * idx, dim=0), dim=2) 163 | b_hat = torch.cat(torch.split(b_hat, split_size_or_sections=1, dim=0), dim=3) 164 | b_hat = b_hat.squeeze()[:h, :w].cpu().numpy() * 255. 
165 | 166 | r_hat, g_hat, b_hat = np.rint(np.clip(r_hat, 0, 255)), \ 167 | np.rint(np.clip(g_hat, 0, 255)), \ 168 | np.rint(np.clip(b_hat, 0, 255)) 169 | reconstruction = cv.merge([r_hat, g_hat, b_hat]) 170 | 171 | psnr = PSNR(x_ori, reconstruction, data_range=255) 172 | ssim = SSIM(x_ori, reconstruction, data_range=255, multichannel=True) 173 | 174 | sum_psnr += psnr 175 | sum_ssim += ssim 176 | 177 | if save_img: 178 | cv.imwrite(f"{recon_dataset_path_rgb_rate}/{name_image}.png", 179 | reconstruction) 180 | save_log(recon_root, one_dataset, name_image, psnr, ssim, f"_{config.para.rate}_{manner}") 181 | save_log(recon_root, one_dataset, None, 182 | sum_psnr / len(images), sum_ssim / len(images), f"_{config.para.rate}_{manner}_AVG", False) 183 | print( 184 | f"AVG RES: PSNR, {round(sum_psnr / len(images), 2)}, SSIM, {round(sum_ssim / len(images), 4)}") 185 | if val: 186 | return round(sum_psnr / len(images), 2), round(sum_ssim / len(images), 4) 187 | else: 188 | raise NotImplementedError(f"Unknown manner: {manner}.") 189 | 190 | 191 | if __name__ == "__main__": 192 | my_state_dict = config.para.my_state_dict 193 | device = config.para.device 194 | 195 | net = models.TCS_Net().eval().to(device) 196 | if os.path.exists(my_state_dict): 197 | if torch.cuda.is_available(): 198 | trained_model = torch.load(my_state_dict, map_location=device) 199 | else: 200 | raise Exception(f"No GPU.") 201 | net.load_state_dict(trained_model) 202 | else: 203 | raise FileNotFoundError(f"Missing trained model of rate {config.para.rate}.") 204 | testing(net, val=False, manner=config.para.manner, save_img=config.para.save) 205 | -------------------------------------------------------------------------------- /test/McM18/1.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/1.tif -------------------------------------------------------------------------------- /test/McM18/10.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/10.tif -------------------------------------------------------------------------------- /test/McM18/11.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/11.tif -------------------------------------------------------------------------------- /test/McM18/12.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/12.tif -------------------------------------------------------------------------------- /test/McM18/13.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/13.tif -------------------------------------------------------------------------------- /test/McM18/14.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/14.tif -------------------------------------------------------------------------------- /test/McM18/15.tif:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/15.tif -------------------------------------------------------------------------------- /test/McM18/16.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/16.tif -------------------------------------------------------------------------------- /test/McM18/17.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/17.tif -------------------------------------------------------------------------------- /test/McM18/18.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/18.tif -------------------------------------------------------------------------------- /test/McM18/2.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/2.tif -------------------------------------------------------------------------------- /test/McM18/3.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/3.tif -------------------------------------------------------------------------------- /test/McM18/4.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/4.tif -------------------------------------------------------------------------------- /test/McM18/5.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/5.tif -------------------------------------------------------------------------------- /test/McM18/6.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/6.tif -------------------------------------------------------------------------------- /test/McM18/7.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/7.tif -------------------------------------------------------------------------------- /test/McM18/8.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/8.tif -------------------------------------------------------------------------------- /test/McM18/9.tif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/9.tif -------------------------------------------------------------------------------- /train.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from tqdm import tqdm 5 | import torch.utils.data 6 | import torch.optim as optim 7 | import torch.optim.lr_scheduler as LS 8 | 9 | import config 10 | import models 11 | from test import testing 12 | 13 | 14 | def set_seed(seed): 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed_all(seed) 17 | 18 | 19 | def check_path(path): 20 | if not os.path.isdir(path): 21 | os.mkdir(path) 22 | print(f"checking paths, mkdir: {path}") 23 | 24 | 25 | def main(): 26 | check_path(config.para.save_path) 27 | check_path(config.para.folder) 28 | set_seed(996007) 29 | 30 | net = models.TCS_Net().train().to(config.para.device) 31 | optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=1e-3) 32 | scheduler = LS.MultiStepLR(optimizer, milestones=[51, 101], gamma=0.1) 33 | if os.path.exists(config.para.my_state_dict): 34 | if torch.cuda.is_available(): 35 | net.load_state_dict(torch.load(config.para.my_state_dict, map_location=config.para.device)) 36 | info = torch.load(config.para.my_info, map_location=config.para.device) 37 | else: 38 | raise Exception(f"No GPU.") 39 | 40 | start_epoch = info["epoch"] 41 | current_best = info["res"] 42 | print(f"Loaded trained model of epoch {start_epoch}, res: {current_best}.") 43 | else: 44 | start_epoch = 1 45 | current_best = 0 46 | print("No saved model, start epoch = 1.") 47 | 48 | print("Data loading...") 49 | # train_set = utils.TrainDatasetFromFolder('./dataset/train/', block_size=utils.para.block_size) 50 | # dataset_train = torch.utils.data.DataLoader( 51 | # dataset=train_set, num_workers=8, batch_size=utils.para.batch_size, shuffle=True) 52 | dataset_train = config.train_loader() 53 | 54 | over_all_time = time.time() 55 | for epoch in range(start_epoch, int(200)): 56 | print("Lr: {}.".format(optimizer.param_groups[0]['lr'])) 57 | 58 | epoch_loss = 0. 59 | dic = {"epoch": epoch, "device": config.para.device, "rate": config.para.rate, "lr": optimizer.param_groups[0]['lr']} 60 | for idx, xi in enumerate(tqdm(dataset_train, desc="Now training: ", postfix=dic)): 61 | xi = xi.to(config.para.device) 62 | 63 | optimizer.zero_grad() 64 | xo, _ = net(xi) 65 | batch_loss = torch.mean(torch.pow(xo - xi, 2)).to(config.para.device) 66 | 67 | if epoch != 1 and batch_loss > 2: 68 | print("\nWarning: your loss > 2 !") 69 | 70 | epoch_loss += batch_loss.item() 71 | 72 | batch_loss.backward() 73 | optimizer.step() 74 | 75 | if idx % 10 == 0: 76 | tqdm.write("\r[{:5}/{:5}], Loss: [{:8.6f}]".format( 77 | config.para.batch_size * (idx + 1), 78 | dataset_train.__len__() * config.para.batch_size, 79 | batch_loss.item())) 80 | 81 | avg_loss = epoch_loss / dataset_train.__len__() 82 | print("\n=> Epoch of {:2}, Epoch Loss: [{:8.6f}]".format(epoch, avg_loss)) 83 | 84 | # Make a log note. 85 | if epoch == 1: 86 | if not os.path.isfile(config.para.my_log): 87 | output_file = open(config.para.my_log, 'w') 88 | output_file.write("=" * 120 + "\n") 89 | output_file.close() 90 | output_file = open(config.para.my_log, 'r+') 91 | old = output_file.read() 92 | output_file.seek(0) 93 | output_file.write("\nAbove is {} test. 
Note:{}.\n" 94 | .format("???", None) + "=" * 120 + "\n") 95 | output_file.write(old) 96 | output_file.close() 97 | 98 | with torch.no_grad(): 99 | p, s = testing(net.eval(), val=True, manner="grey", save_img=False) 100 | print("{:5.3f}".format(p)) 101 | if p > current_best: 102 | epoch_info = {"epoch": epoch, "res": p} 103 | torch.save(net.state_dict(), config.para.my_state_dict) 104 | torch.save(epoch_info, config.para.my_info) 105 | print("Check point saved\n") 106 | current_best = p 107 | output_file = open(config.para.my_log, 'r+') 108 | old = output_file.read() 109 | output_file.seek(0) 110 | 111 | output_file.write(f"Epoch {epoch}, Loss of train {round(avg_loss, 6)}, Res {round(current_best, 2)}, {round(s, 4)}\n") 112 | output_file.write(old) 113 | output_file.close() 114 | 115 | scheduler.step() 116 | print("Over all time: {:.3f}s".format(time.time() - over_all_time)) 117 | 118 | print("Train end.") 119 | 120 | 121 | if __name__ == "__main__": 122 | main() 123 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python train.py --device 0 --rate 0.01 --batch_size 64 2 | python train.py --device 0 --rate 0.04 --batch_size 64 3 | python train.py --device 0 --rate 0.1 --batch_size 64 4 | python train.py --device 0 --rate 0.25 --batch_size 64 5 | -------------------------------------------------------------------------------- /trained_models/1/state_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/trained_models/1/state_dict.pth -------------------------------------------------------------------------------- /trained_models/10/state_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/trained_models/10/state_dict.pth -------------------------------------------------------------------------------- /trained_models/25/state_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/trained_models/25/state_dict.pth -------------------------------------------------------------------------------- /trained_models/4/state_dict.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/trained_models/4/state_dict.pth --------------------------------------------------------------------------------