├── LICENSE
├── README.md
├── config
│   ├── __init__.py
│   ├── config.py
│   └── loader.py
├── models
│   ├── __init__.py
│   ├── layers.py
│   └── net.py
├── sample
│   └── rgb.png
├── test.py
├── test
│   └── McM18
│       ├── 1.tif
│       ├── 10.tif
│       ├── 11.tif
│       ├── 12.tif
│       ├── 13.tif
│       ├── 14.tif
│       ├── 15.tif
│       ├── 16.tif
│       ├── 17.tif
│       ├── 18.tif
│       ├── 2.tif
│       ├── 3.tif
│       ├── 4.tif
│       ├── 5.tif
│       ├── 6.tif
│       ├── 7.tif
│       ├── 8.tif
│       └── 9.tif
├── train.py
├── train.sh
└── trained_models
    ├── 1
    │   └── state_dict.pth
    ├── 4
    │   └── state_dict.pth
    ├── 10
    │   └── state_dict.pth
    └── 25
        └── state_dict.pth
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 ICSResearch
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TCS-Net
2 | This repository contains the `PyTorch` code for the paper `"From Patch to Pixel: A Transformer-based Hierarchical Framework for Compressive Image Sensing"`.
3 | ## 1. Introduction ##
4 | **1) Datasets**
5 |
6 | Training set: [`BSDS500`](https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html), testing sets: [`McM18`](https://www4.comp.polyu.edu.hk/~cslzhang/CDM_Dataset.html), [`LIVE29`](http://live.ece.utexas.edu/research/Quality/), [`General100`](http://mmlab.ie.cuhk.edu.hk/projects/FSRCNN.html) and [`OST300`](http://mmlab.ie.cuhk.edu.hk/projects/SFTGAN/).
7 |
8 | **2) Project structure**
9 | ```
10 | (TCS-Net)
11 | |-dataset
12 | | |-train
13 | | |-BSDS500 (.jpg)
14 | | |-test
15 | | |-McM18
16 | | |-LIVE29
17 | | |-General100
18 | | |-OST300
19 | |-reconstructed_images
20 | | |-McM18
21 | | |-grey
22 | | |-... (Testing results .png)
23 | | |-rgb
24 | | |-... (Testing results .png)
25 | | |-... (Testing sets)
26 | | |-Res_(...).txt
27 | |-models
28 | | |-__init__.py
29 | | |-net.py
30 | | |-layers.py
31 | |-trained_models
32 | | |-1
33 | | |-4
34 | | |-... (Sampling rates)
35 | |-config
36 | | |-__init__.py
37 | | |-config.py
38 | | |-loader.py
39 | |-test.py
40 | |-train.py
41 | |-train.sh
42 | ```
43 |
44 | **3) Competing methods**
45 |
46 | |Methods|Sources|Year|
47 | |:----|:----|:----|
48 | | | [Conf. Comput. Vis. Pattern Recog.](https://ieeexplore.ieee.org/document/7780424/) | 2016 |
49 | |  | [Proc. Adv. Neural Inf. Process. Syst.](https://dl.acm.org/doi/10.5555/3294771.3294940) | 2017 |
50 | |  | [Proc. Adv. Neural Inf. Process. Syst.](https://dl.acm.org/doi/10.5555/3294771.3294940) | 2017 |
51 | |  | [Conf. Comput. Vis. Pattern Recog.](https://ieeexplore.ieee.org/document/8578294) | 2018 |
52 | |  | [Proc. Int. Conf. Mach. Learn.](http://proceedings.mlr.press/v97/wu19d.html) | 2019 |
53 | |  | [Trans. Image Process.](https://ieeexplore.ieee.org/document/8765626/) | 2020 |
54 | |  | [Trans. Image Process.](https://ieeexplore.ieee.org/document/9298950) | 2021 |
55 | |CSformer| arXiv | 2022 |
56 |
57 |
58 | **4) Performance demonstration**
59 |
60 | Visual comparisons of reconstructed images (original images are drawn from the `LIVE29` dataset):
61 |
62 |
63 |
64 | ## 2. Usage ##
65 | **1) Re-training TCS-Net.**
66 |
67 | * Put the `BSDS500` and `VOC2012` images into `./dataset/train/`.
68 | * E.g., if you want to train TCS-Net at sampling rate `τ = 0.1` on `GPU No.0`, please run the following command. The training set will be packaged automatically and the model will be trained with its default parameters (please make sure you have enough GPU RAM):
69 | ```
70 | python train.py --rate 0.1 --device 0
71 | ```
72 | * You can also run our shell script directly; it will train the model at all sampling rates in turn, i.e., `τ ∈ {0.01, 0.04, 0.1, 0.25}`:
73 | ```
74 | sh train.sh
75 | ```
76 | * The trained models (.pth) will be saved in the `trained_models` folder.
77 |
78 | **2) Testing TCS-Net.**
79 | * We provide trained models that you can put under `TCS-Net/trained_models/` and use for testing directly; all trained TCS-Net models can be found at this [GoogleDrive link](https://drive.google.com/drive/folders/15dRG29V51i8rVraz8TkHtev7N3jLkx0U?usp=sharing). Please note that the folder names are `100 × the sampling rate`, e.g., the folder named `10` contains the model trained at `sampling rate = 0.1`.
80 |
81 | * Put the testing folders into `./dataset/test/`.
82 | * E.g., if you want to test TCS-Net at sampling rate `τ = 0.1` on `GPU No.0`, please run:
83 | ```
84 | python test.py --rate 0.1 --device 0
85 | ```
86 | * After that, the reconstructed images, PSNR and SSIM results will be saved to `./reconstructed_images/`.
87 | ## End ##
88 |
89 | We appreciate your reading and attention. For more details about TCS-Net, please refer to our paper.
90 |
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import para
2 | from .loader import train_loader
3 | from .loader import TrainDataPackage
4 |
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
4 | parser = argparse.ArgumentParser(description="Args of this repo.")
5 | parser.add_argument("--rate", default=0.1, type=float)
6 | parser.add_argument("--lr", default=0.001, type=float)
7 | parser.add_argument("--device", default="0")
8 | parser.add_argument("--time", default=0, type=int)
9 | parser.add_argument("--block_size", default=96, type=int)
10 | parser.add_argument("--batch_size", default=32, type=int)
11 |
12 | parser.add_argument("--save", default=False)
13 | parser.add_argument("--manner", default="grey")
14 |
15 | parser.add_argument("--save_path", default=f"./trained_models")
16 | parser.add_argument("--folder")
17 | parser.add_argument("--my_state_dict")
18 | parser.add_argument("--my_log")
19 | parser.add_argument("--my_info")
20 | para = parser.parse_args()
21 | para.device = f"cuda:{para.device}"
22 | para.folder = f"{para.save_path}/{str(int(para.rate * 100))}/"
23 | para.my_state_dict = f"{para.folder}/state_dict.pth"
24 | para.my_log = f"{para.folder}/log.txt"
25 | para.my_info = f"{para.folder}/info.pth"
26 |
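27 | # For example, running `python train.py --rate 0.1 --device 0` yields
28 | # para.device == "cuda:0" and para.folder == "./trained_models/10/", so the
29 | # checkpoint, log and info files all land in ./trained_models/10/
30 | # (the folder name is 100x the sampling rate, matching the README).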
--------------------------------------------------------------------------------
/config/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import time
4 | import copy
5 | import torch
6 | import random
7 | import numpy as np
8 | import torchvision
9 | from PIL import Image
10 | from tqdm import tqdm
11 | import torch.utils.data as data
12 |
13 | import config
14 |
15 |
16 | class TrainDataPackage:
17 | r"""
18 |     Packages the images from BSD500/train and BSD500/test (400 images in total) into a single *.pth file.
19 |     Random vertical flips, horizontal flips and random cropping (several patches per *.jpg image)
20 |     are also used to augment the dataset.
21 | """
22 |
23 | def __init__(self, root="./dataset", transform=None, packaged=True):
24 | self.training_file = "train.pt"
25 | self.aug = DataAugment(debug=True)
26 | self.packaged = packaged
27 | self.root = root
28 | self.num = 50 # 50
29 | self.transform = transform or torchvision.transforms.Compose([
30 | torchvision.transforms.ToTensor(),
31 | torchvision.transforms.RandomVerticalFlip(),
32 | torchvision.transforms.RandomHorizontalFlip(),
33 | torchvision.transforms.Grayscale(num_output_channels=1),
34 | ])
35 |
36 | if not (os.path.exists(os.path.join(self.root, self.training_file))):
37 | print("No packaged dataset file (*.pt) in dataset/, now generating...")
38 | self.generate()
39 |
40 | if packaged:
41 | self.train_data = torch.load(os.path.join(self.root, self.training_file))
42 |
43 | def __len__(self):
44 | return len(self.train_data)
45 |
46 | def __getitem__(self, index):
47 | img = self.train_data[index]
48 | return img
49 |
50 | def generate(self):
51 | paths = [
52 | os.path.join(self.root, "BSD500/"),
53 | os.path.join(self.root, "VOCdevkit/VOC2012/JPEGImages/"),
54 | ]
55 | patches_list = []
56 |
57 | start = time.time()
58 | for path in paths:
59 | for roots, dirs, files in os.walk(path):
60 | if roots == path:
61 | print("Image number: {}".format(files.__len__()))
62 | for file in tqdm(files):
63 |                         if file.split('.')[-1] in ("jpg", "png", "tif", "bmp"):  # only keep common image formats
64 | temp = os.path.join(path, file)
65 | tqdm.write("\r=> Processing " + temp)
66 | image = cv2.imread(temp)
67 | patches = self.random_patch(image, self.num)
68 | patches_list.extend(patches)
69 | print("Total patches: {}".format(patches_list.__len__()))
70 | print("Now Packaging...")
71 | with open(os.path.join(self.root, "train.pt"), 'wb') as f:
72 | torch.save(patches_list, f)
73 | end = time.time()
74 |         print("Successfully packaged! Time used: {:.3f}s".format(end - start))
75 |
76 | def random_patch(self, image, num):
77 | size = config.para.block_size
78 | image = np.array(image, dtype=np.float32) / 255.
79 | h, w = image.shape[0], image.shape[1]
80 | if h <= config.para.block_size or w <= config.para.block_size:
81 | return []
82 | patches = []
83 | for n in range(num):
84 | max_h = random.randint(0, h - size)
85 | max_w = random.randint(0, w - size)
86 | patch = image[max_h:max_h + size, max_w:max_w + size, :]
87 | patch = self.transform(patch)
88 | patches.append(patch)
89 | return patches
90 |
91 |
92 | def train_loader():
93 | train_dataset = TrainDataPackage()
94 | dst = torch.utils.data.DataLoader(train_dataset, batch_size=config.para.batch_size, drop_last=True, shuffle=True,
95 | pin_memory=True, num_workers=8, prefetch_factor=8)
96 | return dst
97 |
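98 | # Typical usage (as in train.py): each batch yielded by config.train_loader() is a
99 | # float tensor of shape (batch_size, 1, block_size, block_size) with grayscale
100 | # patch values in [0, 1].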
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .net import TCS_Net
2 | from .layers import OneUnit, Units
3 | # from .modules import TransformerEncoder, TransformerDecoder  # modules.py is not included in this repository
4 |
--------------------------------------------------------------------------------
/models/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | from einops.layers.torch import Rearrange
5 |
6 |
7 | class FeedForward(nn.Module):
8 | def __init__(self, dim, hidden_dim, dropout=0.):
9 | super().__init__()
10 | self.net = nn.Sequential(
11 | nn.Linear(dim, hidden_dim),
12 | nn.ReLU(),
13 | nn.Dropout(dropout),
14 | nn.Linear(hidden_dim, dim),
15 | )
16 |
17 | def forward(self, x):
18 | return self.net(x)
19 |
20 |
21 | class ScaledDotProduct(nn.Module):
22 | def __init__(self):
23 | super(ScaledDotProduct, self).__init__()
24 | self.softmax = nn.Softmax(dim=-1)
25 |
26 | def forward(self, q, k, v, scale=None):
27 | attention = torch.bmm(q, k.transpose(-2, -1))
28 | if scale:
29 | attention = attention * scale
30 | attention = self.softmax(attention)
31 | context = torch.bmm(attention, v)
32 | return context, attention
33 |
34 |
35 | class MultiHeadAttention(nn.Module):
36 | def __init__(self, model_dim=16 ** 2, num_heads=8, dropout=0., out_dim=None):
37 | super(MultiHeadAttention, self).__init__()
38 | self.dim_per_head = model_dim // num_heads
39 | self.num_heads = num_heads
40 | self.linear_k = nn.Linear(model_dim, model_dim)
41 | self.linear_v = nn.Linear(model_dim, model_dim)
42 | self.linear_q = nn.Linear(model_dim, model_dim)
43 |
44 | self.dot_product_attention = ScaledDotProduct()
45 | self.linear_out = nn.Linear(model_dim, out_dim if out_dim is not None else model_dim)
46 | self.dropout = nn.Dropout(dropout)
47 |
48 | def _reshape_to_heads(self, x):
49 | batch_size, seq_len, in_feature = x.size()
50 | sub_dim = in_feature // self.num_heads
51 | return x.reshape(batch_size, seq_len, self.num_heads, sub_dim) \
52 | .permute(0, 2, 1, 3) \
53 | .reshape(batch_size * self.num_heads, seq_len, sub_dim)
54 |
55 | def _reshape_from_heads(self, x):
56 | batch_size, seq_len, in_feature = x.size()
57 | batch_size //= self.num_heads
58 | out_dim = in_feature * self.num_heads
59 | return x.reshape(batch_size, self.num_heads, seq_len, in_feature) \
60 | .permute(0, 2, 1, 3) \
61 | .reshape(batch_size, seq_len, out_dim)
62 |
63 | def forward(self, x):
64 | key = self._reshape_to_heads(self.linear_k(x))
65 | value = self._reshape_to_heads(self.linear_v(x))
66 | query = self._reshape_to_heads(self.linear_q(x))
67 |
68 | scale = key.size(-1) ** -0.5
69 | context, attention = self.dot_product_attention(query, key, value, scale)
70 |
71 | context = self._reshape_from_heads(context)
72 |
73 | output = self.linear_out(context)
74 | output = self.dropout(output)
75 | return output, attention
76 |
77 |
78 | class OneUnit(nn.Module):
79 | def __init__(self, dim=16 ** 2, heads=8, dropout=0., out_dim=None):
80 | super().__init__()
81 | self.dim = dim
82 | self.flag = out_dim
83 | # self.attn = nn.MultiheadAttention(dim, heads, dropout=dropout)
84 | self.attn = MultiHeadAttention(model_dim=dim, num_heads=heads, dropout=dropout, out_dim=out_dim)
85 | self.ffn = FeedForward(dim=dim if out_dim is None else out_dim, hidden_dim=int(dim * 4), dropout=dropout)
86 | self.norm1 = nn.LayerNorm(dim if out_dim is None else out_dim)
87 | self.norm2 = nn.LayerNorm(dim if out_dim is None else out_dim)
88 |
89 | def forward(self, x):
90 | context, attn = self.attn(x)
91 | x = self.norm1(context + x) if self.flag is None else self.norm1(context)
92 | ffn = self.ffn(x)
93 | x = self.norm2(ffn + x)
94 | return x, attn
95 |
96 |
97 | class Units(nn.Module):
98 | def __init__(self, dim=16, depth=1, heads=8, dropout=0., out_dim=None):
99 | super().__init__()
100 | self.patch_size = dim
101 | self.dim = self.patch_size ** 2
102 | self.depth = depth
103 | self.layers = nn.ModuleList()
104 | for _ in range(depth):
105 | self.layers.append(OneUnit(dim=self.dim, heads=heads, dropout=dropout, out_dim=out_dim))
106 |
107 | def forward(self, x):
108 | x = Rearrange('b l h w -> b l (h w)')(x)
109 | x = x * math.sqrt(self.dim)
110 | attn_maps = []
111 | for i in range(self.depth):
112 | x, attn = self.layers[i](x)
113 | attn_maps.append(attn)
114 | x = Rearrange('b l (h w)-> b l h w', h=self.patch_size, w=self.patch_size)(x)
115 | return x, attn_maps
116 |
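117 | 
118 | if __name__ == "__main__":
119 |     # Minimal shape check (illustrative sketch, not part of the original module;
120 |     # run from the repository root as `python -m models.layers`): a batch of 4
121 |     # sequences of 16 flattened 8x8 patches keeps its shape through a 2-layer stack.
122 |     units = Units(dim=8, depth=2, heads=8, dropout=0.)
123 |     dummy = torch.randn(4, 16, 8, 8)
124 |     out, attn_maps = units(dummy)
125 |     print(out.shape, len(attn_maps))  # torch.Size([4, 16, 8, 8]) 2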
--------------------------------------------------------------------------------
/models/net.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 |
5 | import config
6 | import models
7 |
8 |
9 | class TCS_Net(nn.Module):
10 | def __init__(self):
11 | super(TCS_Net, self).__init__()
12 | self.block_size = config.para.block_size
13 |
14 | self.NUM_UNITS_PATCH = 6
15 | self.NUM_UNITS_PIXEL = 1
16 |
17 | self.DIM = 32
18 | self.PIXEL_EMBED = 32
19 | self.PATCH_DIM = 8 # todo
20 | self.idx = config.para.block_size // self.DIM
21 |
22 | # pixel-wise
23 | pixel_embedding = np.random.normal(0.0, (1 / self.PIXEL_EMBED) ** 0.5, size=(1, self.PIXEL_EMBED))
24 | self.pixel_embedding = nn.Parameter(torch.from_numpy(pixel_embedding).float(), requires_grad=True)
25 | self.transform_pixel_wise = nn.ModuleList()
26 | for i in range(self.NUM_UNITS_PIXEL):
27 | self.transform_pixel_wise.append(models.OneUnit(dim=self.PIXEL_EMBED, dropout=0.5, heads=1))
28 | pixel_detaching = np.random.normal(0.0, (1 / self.PIXEL_EMBED) ** 0.5, size=(self.PIXEL_EMBED, 1))
29 | self.pixel_detaching = nn.Parameter(torch.from_numpy(pixel_detaching).float(), requires_grad=True)
30 |
31 | # patch-wise
32 | self.transform_patch_wise = models.Units(dim=self.PATCH_DIM, depth=self.NUM_UNITS_PATCH, dropout=0.5, heads=8)
33 | # self.transform_patch_wise = PatchTransform(
34 | # dim=self.PATCH_DIM ** 2, num_layers=self.NUM_UNITS_PATCH, drop_out=0.5)
35 |
36 | # sampling and init recon
37 | points = self.DIM ** 2
38 | p_init = np.random.normal(0.0, (1 / points) ** 0.5, size=(points, int(config.para.rate * points)))
39 | self.P = nn.Parameter(torch.from_numpy(p_init).float(), requires_grad=True)
40 | self.R = nn.Parameter(torch.from_numpy(np.transpose(p_init)).float(), requires_grad=True)
41 |
42 | w_init = 1e-3 * np.ones(self.DIM)
43 | self.w = nn.Parameter(torch.from_numpy(w_init).float(), requires_grad=True)
44 |
45 | def forward(self, inputs):
46 | batch_size = inputs.size(0)
47 | y = self.sampling(inputs)
48 | recon, tmp = self.recon(y, batch_size)
49 | return recon, tmp
50 |
51 | def sampling(self, inputs):
52 | inputs = inputs.to(config.para.device)
53 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=self.DIM, dim=3), dim=0)
54 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=self.DIM, dim=2), dim=0)
55 | inputs = inputs.reshape(-1, self.DIM ** 2)
56 | y = torch.matmul(inputs, self.P.to(config.para.device))
57 | return y
58 |
59 | def recon(self, y, batch_size):
60 | # init
61 | init = torch.matmul(y, self.R.to(config.para.device)).reshape(-1, 1, self.DIM, self.DIM)
62 |
63 | recon = torch.cat(torch.split(init, split_size_or_sections=batch_size * self.idx, dim=0), dim=2)
64 | recon = torch.cat(torch.split(recon, split_size_or_sections=batch_size, dim=0), dim=3)
65 |
66 | # patch
67 | recon = self.block2patch(recon, self.PATCH_DIM)
68 | patch = self.transform_patch_wise(recon)
69 | recon = recon - patch[0]
70 | recon = self.patch2block(recon, self.block_size, self.PATCH_DIM)
71 |
72 | recon = torch.cat(torch.split(recon, split_size_or_sections=self.DIM, dim=3), dim=0)
73 | recon = torch.cat(torch.split(recon, split_size_or_sections=self.DIM, dim=2), dim=0).reshape(-1, self.DIM ** 2, 1)
74 |
75 | # pixel
76 | recon_pixel = torch.matmul(recon, self.pixel_embedding)
77 | for i in range(self.NUM_UNITS_PIXEL):
78 | recon_pixel, _ = self.transform_pixel_wise[i](recon_pixel)
79 | recon_pixel = torch.matmul(recon_pixel, self.pixel_detaching)
80 | recon = recon.reshape(-1, 1, self.DIM, self.DIM) - recon_pixel.reshape(-1, 1, self.DIM, self.DIM) * self.w
81 |
82 | recon = torch.cat(torch.split(recon, split_size_or_sections=batch_size * self.idx, dim=0), dim=2)
83 | recon = torch.cat(torch.split(recon, split_size_or_sections=batch_size, dim=0), dim=3)
84 |
85 | return recon, None
86 |
87 | @staticmethod
88 | def block2patch(inputs, size):
89 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=size, dim=3), dim=1)
90 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=size, dim=2), dim=1)
91 | return inputs
92 |
93 | @staticmethod
94 | def patch2block(inputs, block_size, patch_size, ori_channel=1):
95 | assert block_size % patch_size == 0, f"block size {block_size} should be divided by patch size {patch_size}."
96 | idx = int(block_size / patch_size)
97 | outputs = torch.cat(torch.split(inputs, split_size_or_sections=ori_channel * idx, dim=1), dim=2)
98 | outputs = torch.cat(torch.split(outputs, split_size_or_sections=ori_channel, dim=1), dim=3)
99 | return outputs
100 |
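101 | 
102 | if __name__ == "__main__":
103 |     # Minimal smoke test (illustrative sketch, not part of the original module; run
104 |     # from the repository root as `python -m models.net`). Like the rest of the repo,
105 |     # it assumes a CUDA device selected via config.para.device.
106 |     net = TCS_Net().to(config.para.device)
107 |     dummy = torch.randn(2, 1, config.para.block_size, config.para.block_size)
108 |     recon, _ = net(dummy)
109 |     print(recon.shape)  # expected: torch.Size([2, 1, 96, 96]) with the default block_size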
--------------------------------------------------------------------------------
/sample/rgb.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/sample/rgb.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import cv2 as cv
4 | import numpy as np
5 | from skimage.metrics import structural_similarity as SSIM
6 | from skimage.metrics import peak_signal_noise_ratio as PSNR
7 |
8 | import config
9 | import models
10 |
11 |
12 | def save_log(recon_root, name_dataset, name_image, psnr, ssim, manner, consecutive=True):
13 | if not os.path.isfile(f"{recon_root}/Res_{name_dataset}_{manner}.txt"):
14 | log = open(f"{recon_root}/Res_{name_dataset}_{manner}.txt", 'w')
15 | log.write("=" * 120 + "\n")
16 | log.close()
17 | log = open(f"{recon_root}/Res_{name_dataset}_{manner}.txt", 'r+')
18 | if consecutive:
19 | old = log.read()
20 | log.seek(0)
21 | log.write(old)
22 | log.write(
23 | f"Res {name_image}: PSNR, {round(psnr, 2)}, SSIM, {round(ssim, 4)}\n")
24 | log.close()
25 |
26 |
27 | def testing(network, val, manner=config.para.manner, save_img=config.para.save):
28 | """
29 | The pre-processing before TCS-Net's forward propagation and the testing platform.
30 | """
31 | recon_root = "reconstructed_images"
32 | if not os.path.isdir(recon_root):
33 | os.mkdir(recon_root)
34 | datasets = ["Set11"] if val else ["McM18", "LIVE29", "General100", "OST300"] # Names of folders (testing datasets)
35 | with torch.no_grad():
36 | for one_dataset in datasets:
37 | if not os.path.isdir(f"{recon_root}/{one_dataset}"):
38 | os.mkdir(f"{recon_root}/{one_dataset}")
39 |
40 | test_dataset_path = f"dataset/test/{one_dataset}"
41 |
42 | # Grey manner
43 | if manner == "grey":
44 | recon_dataset_path_grey = f"{recon_root}/{one_dataset}/grey/"
45 | recon_dataset_path_grey_rate = f"{recon_root}/{one_dataset}/grey/{config.para.rate}"
46 | if not os.path.isdir(recon_dataset_path_grey):
47 | os.mkdir(recon_dataset_path_grey)
48 | if not os.path.isdir(recon_dataset_path_grey_rate):
49 | os.mkdir(recon_dataset_path_grey_rate)
50 | sum_psnr, sum_ssim = 0., 0.
51 | for _, _, images in os.walk(f"{test_dataset_path}/rgb/"):
52 | for one_image in images:
53 | name_image = one_image.split('.')[0]
54 | x = cv.imread(f"{test_dataset_path}/rgb/{one_image}", flags=cv.IMREAD_GRAYSCALE)
55 | x_ori = x
56 | x = torch.from_numpy(x / 255.).float()
57 | h, w = x.size()
58 |
59 | lack = config.para.block_size - h % config.para.block_size if h % config.para.block_size != 0 else 0
60 | padding_h = torch.zeros(lack, w)
61 | expand_h = h + lack
62 | inputs = torch.cat((x, padding_h), 0)
63 |
64 | lack = config.para.block_size - w % config.para.block_size if w % config.para.block_size != 0 else 0
65 | expand_w = w + lack
66 | padding_w = torch.zeros(expand_h, lack)
67 | inputs = torch.cat((inputs, padding_w), 1).unsqueeze(0).unsqueeze(0)
68 |
69 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=config.para.block_size, dim=3), dim=0)
70 | inputs = torch.cat(torch.split(inputs, split_size_or_sections=config.para.block_size, dim=2), dim=0)
71 |
72 | reconstruction, _ = network(inputs)
73 |
74 | idx = expand_w // config.para.block_size
75 | reconstruction = torch.cat(torch.split(reconstruction, split_size_or_sections=1 * idx, dim=0), dim=2)
76 | reconstruction = torch.cat(torch.split(reconstruction, split_size_or_sections=1, dim=0), dim=3)
77 | reconstruction = reconstruction.squeeze()[:h, :w]
78 |
79 | x_hat = reconstruction.cpu().numpy() * 255.
80 | x_hat = np.rint(np.clip(x_hat, 0, 255))
81 |
82 | psnr = PSNR(x_ori, x_hat, data_range=255)
83 | ssim = SSIM(x_ori, x_hat, data_range=255, multichannel=False)
84 |
85 | sum_psnr += psnr
86 | sum_ssim += ssim
87 |
88 | if save_img:
89 | cv.imwrite(f"{recon_dataset_path_grey_rate}/{name_image}.png", x_hat)
90 | save_log(recon_root, one_dataset, name_image, psnr, ssim, f"_{config.para.rate}_{manner}")
91 | save_log(recon_root, one_dataset, None,
92 | sum_psnr / len(images), sum_ssim / len(images), f"_{config.para.rate}_{manner}_AVG", False)
93 | print(
94 | f"AVG RES: PSNR, {round(sum_psnr / len(images), 2)}, SSIM, {round(sum_ssim / len(images), 4)}")
95 | if val:
96 | return round(sum_psnr / len(images), 2), round(sum_ssim / len(images), 4)
97 |
98 | # RGB manner
99 | elif manner == "rgb":
100 | recon_dataset_path_rgb = f"{recon_root}/{one_dataset}/rgb/"
101 | recon_dataset_path_rgb_rate = f"{recon_root}/{one_dataset}/rgb/{config.para.rate}"
102 | if not os.path.isdir(recon_dataset_path_rgb):
103 | os.mkdir(recon_dataset_path_rgb)
104 | if not os.path.isdir(recon_dataset_path_rgb_rate):
105 | os.mkdir(recon_dataset_path_rgb_rate)
106 | sum_psnr, sum_ssim = 0., 0.
107 | for _, _, images in os.walk(f"{test_dataset_path}/rgb/"):
108 | for one_image in images:
109 | name_image = one_image.split('.')[0]
110 | x = cv.imread(f"{test_dataset_path}/rgb/{one_image}")
111 | x_ori = x
112 | r, g, b = cv.split(x)
113 | r = torch.from_numpy(np.asarray(r)).squeeze().float() / 255.
114 | g = torch.from_numpy(np.asarray(g)).squeeze().float() / 255.
115 | b = torch.from_numpy(np.asarray(b)).squeeze().float() / 255.
116 |
117 | x = torch.from_numpy(x).float()
118 | h, w = x.size()[0], x.size()[1]
119 |
120 | lack = config.para.block_size - h % config.para.block_size if h % config.para.block_size != 0 else 0
121 | padding_h = torch.zeros(lack, w)
122 | expand_h = h + lack
123 | inputs_r = torch.cat((r, padding_h), 0)
124 | inputs_g = torch.cat((g, padding_h), 0)
125 | inputs_b = torch.cat((b, padding_h), 0)
126 |
127 | lack = config.para.block_size - w % config.para.block_size if w % config.para.block_size != 0 else 0
128 | expand_w = w + lack
129 | padding_w = torch.zeros(expand_h, lack)
130 | inputs_r = torch.cat((inputs_r, padding_w), 1).unsqueeze(0).unsqueeze(0)
131 | inputs_g = torch.cat((inputs_g, padding_w), 1).unsqueeze(0).unsqueeze(0)
132 | inputs_b = torch.cat((inputs_b, padding_w), 1).unsqueeze(0).unsqueeze(0)
133 |
134 | inputs_r = torch.cat(torch.split(inputs_r, split_size_or_sections=config.para.block_size, dim=3),
135 | dim=0)
136 | inputs_r = torch.cat(torch.split(inputs_r, split_size_or_sections=config.para.block_size, dim=2),
137 | dim=0)
138 |
139 | inputs_g = torch.cat(torch.split(inputs_g, split_size_or_sections=config.para.block_size, dim=3),
140 | dim=0)
141 | inputs_g = torch.cat(torch.split(inputs_g, split_size_or_sections=config.para.block_size, dim=2),
142 | dim=0)
143 |
144 | inputs_b = torch.cat(torch.split(inputs_b, split_size_or_sections=config.para.block_size, dim=3),
145 | dim=0)
146 | inputs_b = torch.cat(torch.split(inputs_b, split_size_or_sections=config.para.block_size, dim=2),
147 | dim=0)
148 |
149 | r_hat, _ = network(inputs_r.to(config.para.device))
150 | g_hat, _ = network(inputs_g.to(config.para.device))
151 | b_hat, _ = network(inputs_b.to(config.para.device))
152 |
153 | idx = expand_w // config.para.block_size
154 | r_hat = torch.cat(torch.split(r_hat, split_size_or_sections=1 * idx, dim=0), dim=2)
155 | r_hat = torch.cat(torch.split(r_hat, split_size_or_sections=1, dim=0), dim=3)
156 | r_hat = r_hat.squeeze()[:h, :w].cpu().numpy() * 255.
157 |
158 | g_hat = torch.cat(torch.split(g_hat, split_size_or_sections=1 * idx, dim=0), dim=2)
159 | g_hat = torch.cat(torch.split(g_hat, split_size_or_sections=1, dim=0), dim=3)
160 | g_hat = g_hat.squeeze()[:h, :w].cpu().numpy() * 255.
161 |
162 | b_hat = torch.cat(torch.split(b_hat, split_size_or_sections=1 * idx, dim=0), dim=2)
163 | b_hat = torch.cat(torch.split(b_hat, split_size_or_sections=1, dim=0), dim=3)
164 | b_hat = b_hat.squeeze()[:h, :w].cpu().numpy() * 255.
165 |
166 | r_hat, g_hat, b_hat = np.rint(np.clip(r_hat, 0, 255)), \
167 | np.rint(np.clip(g_hat, 0, 255)), \
168 | np.rint(np.clip(b_hat, 0, 255))
169 | reconstruction = cv.merge([r_hat, g_hat, b_hat])
170 |
171 | psnr = PSNR(x_ori, reconstruction, data_range=255)
172 | ssim = SSIM(x_ori, reconstruction, data_range=255, multichannel=True)
173 |
174 | sum_psnr += psnr
175 | sum_ssim += ssim
176 |
177 | if save_img:
178 | cv.imwrite(f"{recon_dataset_path_rgb_rate}/{name_image}.png",
179 | (reconstruction))
180 | save_log(recon_root, one_dataset, name_image, psnr, ssim, f"_{config.para.rate}_{manner}")
181 | save_log(recon_root, one_dataset, None,
182 | sum_psnr / len(images), sum_ssim / len(images), f"_{config.para.rate}_{manner}_AVG", False)
183 | print(
184 | f"AVG RES: PSNR, {round(sum_psnr / len(images), 2)}, SSIM, {round(sum_ssim / len(images), 4)}")
185 | if val:
186 | return round(sum_psnr / len(images), 2), round(sum_ssim / len(images), 4)
187 | else:
188 |                 raise NotImplementedError(f"Unknown manner: {manner}.")
189 |
190 |
191 | if __name__ == "__main__":
192 | my_state_dict = config.para.my_state_dict
193 | device = config.para.device
194 |
195 | net = models.TCS_Net().eval().to(device)
196 | if os.path.exists(my_state_dict):
197 | if torch.cuda.is_available():
198 | trained_model = torch.load(my_state_dict, map_location=device)
199 | else:
200 | raise Exception(f"No GPU.")
201 | net.load_state_dict(trained_model)
202 | else:
203 | raise FileNotFoundError(f"Missing trained model of rate {config.para.rate}.")
204 | testing(net, val=False, manner=config.para.manner, save_img=config.para.save)
205 |
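206 | # Padding example (illustrative): for a 500x481 image with block_size = 96, the code
207 | # above pads with 96 - 500 % 96 = 76 zero rows and 96 - 481 % 96 = 95 zero columns,
208 | # reconstructs the resulting 576x576 image as a 6x6 grid of 96x96 blocks, and then
209 | # crops back to the original 500x481 size before computing PSNR/SSIM.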
--------------------------------------------------------------------------------
/test/McM18/1.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/1.tif
--------------------------------------------------------------------------------
/test/McM18/10.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/10.tif
--------------------------------------------------------------------------------
/test/McM18/11.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/11.tif
--------------------------------------------------------------------------------
/test/McM18/12.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/12.tif
--------------------------------------------------------------------------------
/test/McM18/13.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/13.tif
--------------------------------------------------------------------------------
/test/McM18/14.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/14.tif
--------------------------------------------------------------------------------
/test/McM18/15.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/15.tif
--------------------------------------------------------------------------------
/test/McM18/16.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/16.tif
--------------------------------------------------------------------------------
/test/McM18/17.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/17.tif
--------------------------------------------------------------------------------
/test/McM18/18.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/18.tif
--------------------------------------------------------------------------------
/test/McM18/2.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/2.tif
--------------------------------------------------------------------------------
/test/McM18/3.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/3.tif
--------------------------------------------------------------------------------
/test/McM18/4.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/4.tif
--------------------------------------------------------------------------------
/test/McM18/5.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/5.tif
--------------------------------------------------------------------------------
/test/McM18/6.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/6.tif
--------------------------------------------------------------------------------
/test/McM18/7.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/7.tif
--------------------------------------------------------------------------------
/test/McM18/8.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/8.tif
--------------------------------------------------------------------------------
/test/McM18/9.tif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/test/McM18/9.tif
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | from tqdm import tqdm
5 | import torch.utils.data
6 | import torch.optim as optim
7 | import torch.optim.lr_scheduler as LS
8 |
9 | import config
10 | import models
11 | from test import testing
12 |
13 |
14 | def set_seed(seed):
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed_all(seed)
17 |
18 |
19 | def check_path(path):
20 | if not os.path.isdir(path):
21 | os.mkdir(path)
22 | print(f"checking paths, mkdir: {path}")
23 |
24 |
25 | def main():
26 | check_path(config.para.save_path)
27 | check_path(config.para.folder)
28 | set_seed(996007)
29 |
30 | net = models.TCS_Net().train().to(config.para.device)
31 | optimizer = optim.Adam(filter(lambda x: x.requires_grad, net.parameters()), lr=1e-3)
32 | scheduler = LS.MultiStepLR(optimizer, milestones=[51, 101], gamma=0.1)
33 | if os.path.exists(config.para.my_state_dict):
34 | if torch.cuda.is_available():
35 | net.load_state_dict(torch.load(config.para.my_state_dict, map_location=config.para.device))
36 | info = torch.load(config.para.my_info, map_location=config.para.device)
37 | else:
38 | raise Exception(f"No GPU.")
39 |
40 | start_epoch = info["epoch"]
41 | current_best = info["res"]
42 | print(f"Loaded trained model of epoch {start_epoch}, res: {current_best}.")
43 | else:
44 | start_epoch = 1
45 | current_best = 0
46 | print("No saved model, start epoch = 1.")
47 |
48 | print("Data loading...")
49 | # train_set = utils.TrainDatasetFromFolder('./dataset/train/', block_size=utils.para.block_size)
50 | # dataset_train = torch.utils.data.DataLoader(
51 | # dataset=train_set, num_workers=8, batch_size=utils.para.batch_size, shuffle=True)
52 | dataset_train = config.train_loader()
53 |
54 | over_all_time = time.time()
55 | for epoch in range(start_epoch, int(200)):
56 | print("Lr: {}.".format(optimizer.param_groups[0]['lr']))
57 |
58 | epoch_loss = 0.
59 | dic = {"epoch": epoch, "device": config.para.device, "rate": config.para.rate, "lr": optimizer.param_groups[0]['lr']}
60 | for idx, xi in enumerate(tqdm(dataset_train, desc="Now training: ", postfix=dic)):
61 | xi = xi.to(config.para.device)
62 |
63 | optimizer.zero_grad()
64 | xo, _ = net(xi)
65 | batch_loss = torch.mean(torch.pow(xo - xi, 2)).to(config.para.device)
66 |
67 | if epoch != 1 and batch_loss > 2:
68 | print("\nWarning: your loss > 2 !")
69 |
70 | epoch_loss += batch_loss.item()
71 |
72 | batch_loss.backward()
73 | optimizer.step()
74 |
75 | if idx % 10 == 0:
76 | tqdm.write("\r[{:5}/{:5}], Loss: [{:8.6f}]".format(
77 | config.para.batch_size * (idx + 1),
78 | dataset_train.__len__() * config.para.batch_size,
79 | batch_loss.item()))
80 |
81 | avg_loss = epoch_loss / dataset_train.__len__()
82 | print("\n=> Epoch of {:2}, Epoch Loss: [{:8.6f}]".format(epoch, avg_loss))
83 |
84 | # Make a log note.
85 | if epoch == 1:
86 | if not os.path.isfile(config.para.my_log):
87 | output_file = open(config.para.my_log, 'w')
88 | output_file.write("=" * 120 + "\n")
89 | output_file.close()
90 | output_file = open(config.para.my_log, 'r+')
91 | old = output_file.read()
92 | output_file.seek(0)
93 | output_file.write("\nAbove is {} test. Note:{}.\n"
94 | .format("???", None) + "=" * 120 + "\n")
95 | output_file.write(old)
96 | output_file.close()
97 |
98 | with torch.no_grad():
99 | p, s = testing(net.eval(), val=True, manner="grey", save_img=False)
100 | print("{:5.3f}".format(p))
101 | if p > current_best:
102 | epoch_info = {"epoch": epoch, "res": p}
103 | torch.save(net.state_dict(), config.para.my_state_dict)
104 | torch.save(epoch_info, config.para.my_info)
105 | print("Check point saved\n")
106 | current_best = p
107 | output_file = open(config.para.my_log, 'r+')
108 | old = output_file.read()
109 | output_file.seek(0)
110 |
111 | output_file.write(f"Epoch {epoch}, Loss of train {round(avg_loss, 6)}, Res {round(current_best, 2)}, {round(s, 4)}\n")
112 | output_file.write(old)
113 | output_file.close()
114 |
115 | scheduler.step()
116 | print("Over all time: {:.3f}s".format(time.time() - over_all_time))
117 |
118 | print("Train end.")
119 |
120 |
121 | if __name__ == "__main__":
122 | main()
123 |
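124 | # Note: to resume training at a given rate, simply re-run `python train.py --rate <r>
125 | # --device <d>`; if trained_models/<100*r>/state_dict.pth and info.pth already exist,
126 | # main() reloads the checkpoint together with its best-PSNR record and continues from
127 | # the saved epoch.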
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | python train.py --device 0 --rate 0.01 --batch_size 64
2 | python train.py --device 0 --rate 0.04 --batch_size 64
3 | python train.py --device 0 --rate 0.1 --batch_size 64
4 | python train.py --device 0 --rate 0.25 --batch_size 64
5 |
--------------------------------------------------------------------------------
/trained_models/1/state_dict.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/trained_models/1/state_dict.pth
--------------------------------------------------------------------------------
/trained_models/10/state_dict.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/trained_models/10/state_dict.pth
--------------------------------------------------------------------------------
/trained_models/25/state_dict.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/trained_models/25/state_dict.pth
--------------------------------------------------------------------------------
/trained_models/4/state_dict.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ICSResearch/TCS-Net/1361d0c3da0488b2d25e66cbf34b2a1265347985/trained_models/4/state_dict.pth
--------------------------------------------------------------------------------