├── image └── TCNet.png ├── utils.py ├── README.md ├── test.py ├── train.py ├── loss.py ├── token_transformer.py ├── solver.py ├── token_performer.py ├── dataset.py ├── Transformer.py ├── transformer_block.py ├── t2t_vit.py └── network.py /image/TCNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhangqiao970914/TCNet/HEAD/image/TCNet.png -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import time 4 | 5 | """ 6 | mkdir: 7 | Create the folder "path" if it does not already exist. 8 | """ 9 | def mkdir(path): 10 | if not os.path.exists(path): 11 | os.makedirs(path) 12 | 13 | """ 14 | write_doc: 15 | Append "content" to the ".txt" file at "path". 16 | """ 17 | def write_doc(path, content): 18 | with open(path, 'a') as file: 19 | file.write(content) 20 | 21 | """ 22 | get_time: 23 | Obtain the current time (after synchronizing pending CUDA kernels). 24 | """ 25 | def get_time(): 26 | torch.cuda.synchronize() 27 | return time.time() -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TCNet 2 | * This project provides the code and results for 'TCNet: Co-salient Object Detection via Parallel Interaction of Transformers and CNNs', [IEEE TCSVT](https://ieeexplore.ieee.org/abstract/document/9968016) 3 | # Network Architecture 4 | ![image](https://github.com/zhangqiao970914/TCNet/blob/main/image/TCNet.png) 5 | # Result 6 | * [Result](https://pan.baidu.com/s/1L7s1Gi1RADzaKLwuSFITRg), extraction code: fn1p 7 | # Model 8 | * [vgg16](https://pan.baidu.com/s/1jiTLv8oO8R7eVsdWPOf2ZQ), extraction code: aap0; 9 | * [t2t-vit-14](https://pan.baidu.com/s/1fejkFf_bRvTJkzJxfWQsYg), extraction code: yrlr 10 | # Evaluation Toolbox 11 | * [eval-co-sod](https://github.com/zzhanghub/eval-co-sod) 12 | # Others 13 | The code is based on [ICNet](https://github.com/blanclist/ICNet) and [VST](https://github.com/nnizhang/VST). 14 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from solver import Solver 3 | 4 | 5 | test_device = '0' 6 | test_batch_size = 10 7 | pred_root = './pred/' 8 | ckpt_path = '/user-data/T/ICNet/ckpt/Weights_16.pth' 9 | original_size = False 10 | test_num_thread = 4 11 | 12 | # An example to build "test_roots".
13 | test_roots = dict() 14 | datasets = ['CoSal2015','CoSOD3K','CoCA'] 15 | 16 | for dataset in datasets: 17 | roots = {'img': '/user-data/ICNet_Depth/Dataset/dataset_rgb/{}/images/'.format(dataset)} 18 | test_roots[dataset] = roots 19 | # ------------- end ------------- 20 | 21 | if __name__ == '__main__': 22 | os.environ['CUDA_VISIBLE_DEVICES'] = test_device 23 | solver = Solver() 24 | solver.test(roots=test_roots, 25 | ckpt_path=ckpt_path, 26 | pred_root=pred_root, 27 | num_thread=test_num_thread, 28 | batch_size=test_batch_size, 29 | original_size=original_size, 30 | pin=False) 31 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from solver import Solver 3 | 4 | vgg_path = '/user-data/T/ICNet/vgg16_feat.pth' 5 | ckpt_root = './ckpt/' 6 | train_init_epoch = 0 7 | train_end_epoch = 20 8 | train_device = '0' 9 | train_doc_path = './training.txt' 10 | learning_rate = 1e-5 11 | weight_decay = 1e-4 12 | train_batch_size = 10 13 | train_num_thread = 4 14 | 15 | # An example to build "train_roots". 16 | train_roots = {'img': '/user-data/COCO+Class/images/', 17 | 'gt': '/user-data/COCO+Class/gts/'} 18 | # ------------- end ------------- 19 | 20 | if __name__ == '__main__': 21 | os.environ['CUDA_VISIBLE_DEVICES'] = train_device 22 | solver = Solver() 23 | solver.train(roots=train_roots, 24 | vgg_path=vgg_path, 25 | init_epoch=train_init_epoch, 26 | end_epoch=train_end_epoch, 27 | learning_rate=learning_rate, 28 | batch_size=train_batch_size, 29 | weight_decay=weight_decay, 30 | ckpt_root=ckpt_root, 31 | doc_path=train_doc_path, 32 | num_thread=train_num_thread, 33 | pin=False) 34 | 35 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | """ 4 | IoU_loss: 5 | Compute IoU loss between predictions and ground-truths for training [Equation 3]. 6 | """ 7 | def IoU_loss(preds_list, gt): 8 | preds = torch.cat(preds_list, dim=1) 9 | N, C, H, W = preds.shape 10 | min_tensor = torch.where(preds < gt, preds, gt) # shape=[N, C, H, W] 11 | max_tensor = torch.where(preds > gt, preds, gt) # shape=[N, C, H, W] 12 | min_sum = min_tensor.view(N, C, H * W).sum(dim=2) # shape=[N, C] 13 | max_sum = max_tensor.view(N, C, H * W).sum(dim=2) # shape=[N, C] 14 | loss = 1 - (min_sum / max_sum).mean() 15 | return loss 16 | 17 | 18 | def structure_loss(pred, mask): 19 | """ 20 | Weighted BCE + weighted IoU loss (ref: F3Net, AAAI-2020). 21 | """ 22 | weit = 1 + 5 * torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15) - mask) 23 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none') 24 | wbce = (weit * wbce).sum(dim=(2, 3)) / weit.sum(dim=(2, 3)) 25 | 26 | pred = torch.sigmoid(pred) 27 | inter = ((pred * mask) * weit).sum(dim=(2, 3)) 28 | union = ((pred + mask) * weit).sum(dim=(2, 3)) 29 | wiou = 1 - (inter + 1) / (union - inter + 1) 30 | return (wbce + wiou).mean() 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /token_transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd.
2 | # 3 | # This source code is licensed under the Clear BSD License 4 | # LICENSE file in the root directory of this file 5 | # All rights reserved. 6 | """ 7 | Take the standard Transformer as T2T Transformer 8 | """ 9 | import torch.nn as nn 10 | from timm.models.layers import DropPath 11 | from transformer_block import Mlp 12 | 13 | 14 | class Attention(nn.Module): 15 | def __init__(self, dim, num_heads=8, in_dim = None, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 16 | super().__init__() 17 | self.num_heads = num_heads 18 | self.in_dim = in_dim 19 | head_dim = dim // num_heads 20 | self.scale = qk_scale or head_dim ** -0.5 21 | 22 | self.qkv = nn.Linear(dim, in_dim * 3, bias=qkv_bias) 23 | self.attn_drop = nn.Dropout(attn_drop) 24 | self.proj = nn.Linear(in_dim, in_dim) 25 | self.proj_drop = nn.Dropout(proj_drop) 26 | 27 | def forward(self, x): 28 | B, N, C = x.shape 29 | 30 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.in_dim).permute(2, 0, 3, 1, 4) 31 | q, k, v = qkv[0], qkv[1], qkv[2] 32 | 33 | attn = (q @ k.transpose(-2, -1)) * self.scale 34 | attn = attn.softmax(dim=-1) 35 | attn = self.attn_drop(attn) 36 | 37 | x = (attn @ v).transpose(1, 2).reshape(B, N, self.in_dim) 38 | x = self.proj(x) 39 | x = self.proj_drop(x) 40 | 41 | # skip connection 42 | x = v.squeeze(1) + x # because the original x has different size with current x, use v to do skip connection 43 | 44 | return x 45 | 46 | 47 | class Token_transformer(nn.Module): 48 | 49 | def __init__(self, dim, in_dim, num_heads, mlp_ratio=1., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 50 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 51 | super().__init__() 52 | self.norm1 = norm_layer(dim) 53 | self.attn = Attention( 54 | dim, in_dim=in_dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 55 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 56 | self.norm2 = norm_layer(in_dim) 57 | self.mlp = Mlp(in_features=in_dim, hidden_features=int(in_dim*mlp_ratio), out_features=in_dim, act_layer=act_layer, drop=drop) 58 | 59 | def forward(self, x): 60 | x = self.attn(self.norm1(x)) 61 | x = x + self.drop_path(self.mlp(self.norm2(x))) 62 | return x 63 | -------------------------------------------------------------------------------- /solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from torch.optim import Adam 4 | import network 5 | from loss import structure_loss 6 | import numpy as np 7 | import cv2 8 | from numpy import mean 9 | from torch import nn 10 | from dataset import get_loader 11 | from os.path import join 12 | import random 13 | from utils import mkdir, write_doc, get_time 14 | 15 | 16 | class Solver(object): 17 | def __init__(self): 18 | self.ICNet = network.ICNet().cuda() 19 | 20 | def train(self, roots, init_epoch, end_epoch, learning_rate, batch_size, weight_decay, ckpt_root, doc_path, num_thread, pin, vgg_path=None): 21 | # Define Adam optimizer. 22 | optimizer = Adam(self.ICNet.parameters(), 23 | lr=learning_rate, 24 | weight_decay=weight_decay) 25 | 26 | # Load ".pth" to initialize model. 27 | if init_epoch == 0: 28 | # From pre-trained VGG16. 29 | self.ICNet.apply(network.weights_init) 30 | self.ICNet.vgg.vgg.load_state_dict(torch.load(vgg_path)) 31 | #self.ICNet.rgb.load_state_dict(torch.load('/media/lab509-2/data/ZQ/T/ICNet/80.7_T2T_ViT_t_14.pth.tar')) 32 | else: 33 | # From the existed checkpoint file. 
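# The ".pth" file written by torch.save() later in this method is a dict holding 'optimizer' and 'state_dict' entries, which is exactly what gets unpacked below.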
34 | ckpt = torch.load(join(ckpt_root, 'Weights_{}.pth'.format(init_epoch))) 35 | self.ICNet.load_state_dict(ckpt['state_dict']) 36 | optimizer.load_state_dict(ckpt['optimizer']) 37 | 38 | # Define training dataloader. 39 | train_dataloader = get_loader(roots=roots, 40 | request=('img', 'gt'), 41 | shuffle=True, 42 | batch_size=batch_size, 43 | data_aug=True, 44 | num_thread=num_thread, 45 | pin=pin) 46 | 47 | # Train. 48 | self.ICNet.train() 49 | for epoch in range(init_epoch + 1, end_epoch): 50 | start_time = get_time() 51 | loss_sum = 0.0 52 | 53 | for data_batch in train_dataloader: 54 | self.ICNet.zero_grad() 55 | 56 | # Obtain a batch of data. 57 | img, gt = data_batch['img'], data_batch['gt'] 58 | img, gt = img.cuda(), gt.cuda() 59 | 60 | if len(img) == 1: 61 | # Skip this iteration when the training batch size is 1, due to Batch Normalization. 62 | continue 63 | 64 | # Forward. 65 | S_3_pred, S_4_pred, S_5_pred, S_g_pred = self.ICNet(image_group=img, is_training=True) 66 | 67 | 68 | # Compute the structure loss for each prediction. 69 | loss1 = structure_loss(S_3_pred, gt) 70 | loss2 = structure_loss(S_4_pred, gt) 71 | loss3 = structure_loss(S_5_pred, gt) 72 | loss4 = structure_loss(S_g_pred, gt) 73 | loss = loss1 + loss2 + loss3 + loss4 # sum the structure losses of all four predictions 74 | # Backward. 75 | loss.backward() 76 | optimizer.step() 77 | loss_sum = loss_sum + loss.detach().item() 78 | 79 | # Save the checkpoint file (".pth") after each epoch. 80 | mkdir(ckpt_root) 81 | torch.save({'optimizer': optimizer.state_dict(), 82 | 'state_dict': self.ICNet.state_dict()}, join(ckpt_root, 'Weights_{}.pth'.format(epoch))) 83 | 84 | # Compute average loss over the training dataset approximately. 85 | loss_mean = loss_sum / len(train_dataloader) 86 | end_time = get_time() 87 | 88 | # Record training information (".txt"). 89 | content = 'CkptIndex={}: TrainLoss={} LR={} Time={}\n'.format(epoch, loss_mean, learning_rate, end_time - start_time) 90 | write_doc(doc_path, content) 91 | 92 | def test(self, roots, ckpt_path, pred_root, num_thread, batch_size, original_size, pin): 93 | with torch.no_grad(): 94 | # Load the specified checkpoint file(".pth"). 95 | state_dict = torch.load(ckpt_path)['state_dict'] 96 | self.ICNet.load_state_dict(state_dict) 97 | self.ICNet.eval() 98 | 99 | # Get names of the test datasets. 100 | datasets = roots.keys() 101 | 102 | # Test ICNet on each dataset. 103 | for dataset in datasets: 104 | # Define test dataloader for the current test dataset. 105 | test_dataloader = get_loader(roots=roots[dataset], 106 | request=('img', 'file_name', 'group_name', 'size'), 107 | shuffle=False, 108 | data_aug=False, 109 | num_thread=num_thread, 110 | batch_size=batch_size, 111 | pin=pin) 112 | 113 | # Create a folder for the current test dataset for saving predictions. 114 | mkdir(pred_root) 115 | cur_dataset_pred_root = join(pred_root, dataset) 116 | mkdir(cur_dataset_pred_root) 117 | 118 | for data_batch in test_dataloader: 119 | # Obtain a batch of data. 120 | img = data_batch['img'].cuda() 121 | 122 | 123 | 124 | 125 | time_list = [] 126 | start_each = time.time() 127 | # Forward. 128 | preds = self.ICNet(image_group=img, 129 | 130 | is_training=False) 131 | time_each = time.time() - start_each 132 | 133 | print(time_each) 134 | #print("{}'s average Time Is : {:.1f} fps".format(1 / mean(time_list))) 135 | 136 | # Create a folder for the current batch according to its "group_name" for saving predictions.
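# Cosal_Sampler (see dataset.py) guarantees that all samples in a batch come from the same image group, so the first sample's "group_name" identifies the whole batch.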
137 | group_name = data_batch['group_name'][0] 138 | cur_group_pred_root = join(cur_dataset_pred_root, group_name) 139 | mkdir(cur_group_pred_root) 140 | 141 | # preds.shape: [N, 1, H, W]->[N, H, W, 1] 142 | preds = preds.permute(0, 2, 3, 1).cpu().numpy() 143 | 144 | # Make paths where predictions will be saved. 145 | pred_paths = list(map(lambda file_name: join(cur_group_pred_root, file_name + '.png'), data_batch['file_name'])) 146 | 147 | # For each prediction: 148 | for i, pred_path in enumerate(pred_paths): 149 | # Resize the prediction to the original size when "original_size == True". 150 | H, W = data_batch['size'][0][i], data_batch['size'][1][i] 151 | pred = cv2.resize(preds[i], (W, H)) if original_size else preds[i] 152 | 153 | # Save the prediction. 154 | cv2.imwrite(pred_path, np.array(pred * 255)) 155 | -------------------------------------------------------------------------------- /token_performer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Take Performer as T2T Transformer 3 | """ 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | 8 | 9 | class crosstask_performer(nn.Module): 10 | def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1): 11 | super().__init__() 12 | self.emb = in_dim * head_cnt # we use 1, so it is no need here 13 | # self.kqv = nn.Linear(dim, 3 * self.emb) 14 | self.q_s = nn.Linear(dim, self.emb) 15 | self.k_s = nn.Linear(dim, self.emb) 16 | self.v_s = nn.Linear(dim, self.emb) 17 | 18 | self.q_c = nn.Linear(dim, self.emb) 19 | self.k_c = nn.Linear(dim, self.emb) 20 | self.v_c = nn.Linear(dim, self.emb) 21 | self.dp = nn.Dropout(dp1) 22 | self.proj_s = nn.Linear(self.emb, self.emb) 23 | self.proj_c = nn.Linear(self.emb, self.emb) 24 | self.head_cnt = head_cnt 25 | self.norm1_s = nn.LayerNorm(dim) 26 | self.norm1_c = nn.LayerNorm(dim) 27 | self.norm2_s = nn.LayerNorm(self.emb) 28 | self.norm2_c = nn.LayerNorm(self.emb) 29 | self.epsilon = 1e-8 # for stable in division 30 | 31 | self.mlp_s = nn.Sequential( 32 | nn.Linear(self.emb, 1 * self.emb), 33 | nn.GELU(), 34 | nn.Linear(1 * self.emb, self.emb), 35 | nn.Dropout(dp2), 36 | ) 37 | self.mlp_c = nn.Sequential( 38 | nn.Linear(self.emb, 1 * self.emb), 39 | nn.GELU(), 40 | nn.Linear(1 * self.emb, self.emb), 41 | nn.Dropout(dp2), 42 | ) 43 | 44 | self.m_s = int(self.emb * kernel_ratio) 45 | self.w_s = torch.randn(self.m_s, self.emb) 46 | self.w_s = nn.Parameter(nn.init.orthogonal_(self.w_s) * math.sqrt(self.m_s), requires_grad=False) 47 | 48 | self.m_c = int(self.emb * kernel_ratio) 49 | self.w_c = torch.randn(self.m_c, self.emb) 50 | self.w_c = nn.Parameter(nn.init.orthogonal_(self.w_c) * math.sqrt(self.m_c), requires_grad=False) 51 | 52 | def prm_exp_s(self, x): 53 | # part of the function is borrow from https://github.com/lucidrains/performer-pytorch 54 | # and Simo Ryu (https://github.com/cloneofsimo) 55 | # ==== positive random features for gaussian kernels ==== 56 | # x = (B, T, hs) 57 | # w = (m, hs) 58 | # return : x : B, T, m 59 | # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)] 60 | # therefore return exp(w^Tx - |x|/2)/sqrt(m) 61 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m_s) / 2 62 | wtx = torch.einsum('bti,mi->btm', x.float(), self.w_s) 63 | 64 | return torch.exp(wtx - xd) / math.sqrt(self.m_s) 65 | 66 | def prm_exp_c(self, x): 67 | # part of the function is borrow from https://github.com/lucidrains/performer-pytorch 68 | # and Simo Ryu (https://github.com/cloneofsimo) 69 | # ==== 
positive random features for gaussian kernels ==== 70 | # x = (B, T, hs) 71 | # w = (m, hs) 72 | # return : x : B, T, m 73 | # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)] 74 | # therefore return exp(w^Tx - |x|/2)/sqrt(m) 75 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m_c) / 2 76 | wtx = torch.einsum('bti,mi->btm', x.float(), self.w_c) 77 | 78 | return torch.exp(wtx - xd) / math.sqrt(self.m_c) 79 | 80 | def cross_attn(self, saliency_fea, contour_fea): 81 | k_s, q_s, v_s = self.k_s(saliency_fea), self.q_s(saliency_fea), self.v_s(saliency_fea) 82 | k_c, q_c, v_c = self.k_c(contour_fea), self.q_c(contour_fea), self.v_c(contour_fea) 83 | 84 | kp_s, qp_s = self.prm_exp_s(k_c), self.prm_exp_s(q_s) # (B, T, m), (B, T, m) 85 | D_s = torch.einsum('bti,bi->bt', qp_s, kp_s.sum(dim=1)).unsqueeze(dim=2) # (B, T, m) * (B, m) -> (B, T, 1) 86 | kptv_s = torch.einsum('bin,bim->bnm', v_c.float(), kp_s) # (B, emb, m) 87 | y_s = torch.einsum('bti,bni->btn', qp_s, kptv_s) / (D_s.repeat(1, 1, self.emb) + self.epsilon) # (B, T, emb)/Diag 88 | # skip connection 89 | # y_s = saliency_fea + self.dp(self.proj_s(y_s)) # same as token_transformer in T2T layer, use v as skip connection 90 | y_s = self.dp(self.proj_s(y_s)) # same as token_transformer in T2T layer, use v as skip connection 91 | 92 | kp_c, qp_c = self.prm_exp_c(k_s), self.prm_exp_c(q_c) # (B, T, m), (B, T, m) 93 | D_c = torch.einsum('bti,bi->bt', qp_c, kp_c.sum(dim=1)).unsqueeze(dim=2) # (B, T, m) * (B, m) -> (B, T, 1) 94 | kptv_c = torch.einsum('bin,bim->bnm', v_s.float(), kp_c) # (B, emb, m) 95 | y_c = torch.einsum('bti,bni->btn', qp_c, kptv_c) / (D_c.repeat(1, 1, self.emb) + self.epsilon) # (B, T, emb)/Diag 96 | # skip connection 97 | # y_c = contour_fea + self.dp(self.proj_c(y_c)) # same as token_transformer in T2T layer, use v as skip connection 98 | y_c = self.dp(self.proj_c(y_c)) # same as token_transformer in T2T layer, use v as skip connection 99 | 100 | return y_s, y_c 101 | 102 | def forward(self, saliency_fea, contour_fea): 103 | # cross task attention 104 | saliency_fea_fuse, contour_fea_fuse = self.cross_attn(self.norm1_s(saliency_fea), self.norm1_c(contour_fea)) 105 | 106 | saliency_fea = saliency_fea + saliency_fea_fuse 107 | contour_fea = contour_fea + contour_fea_fuse 108 | 109 | saliency_fea = saliency_fea + self.mlp_s(self.norm2_s(saliency_fea)) 110 | contour_fea = contour_fea + self.mlp_c(self.norm2_c(contour_fea)) 111 | 112 | return saliency_fea, contour_fea 113 | 114 | 115 | class Token_performer(nn.Module): 116 | def __init__(self, dim, in_dim, head_cnt=1, kernel_ratio=0.5, dp1=0.1, dp2 = 0.1): 117 | super().__init__() 118 | self.emb = in_dim * head_cnt # we use 1, so it is no need here 119 | self.kqv = nn.Linear(dim, 3 * self.emb) 120 | self.dp = nn.Dropout(dp1) 121 | self.proj = nn.Linear(self.emb, self.emb) 122 | self.head_cnt = head_cnt 123 | self.norm1 = nn.LayerNorm(dim) 124 | self.norm2 = nn.LayerNorm(self.emb) 125 | self.epsilon = 1e-8 # for stable in division 126 | 127 | self.mlp = nn.Sequential( 128 | nn.Linear(self.emb, 1 * self.emb), 129 | nn.GELU(), 130 | nn.Linear(1 * self.emb, self.emb), 131 | nn.Dropout(dp2), 132 | ) 133 | 134 | self.m = int(self.emb * kernel_ratio) 135 | self.w = torch.randn(self.m, self.emb) 136 | self.w = nn.Parameter(nn.init.orthogonal_(self.w) * math.sqrt(self.m), requires_grad=False) 137 | 138 | def prm_exp(self, x): 139 | # part of the function is borrow from https://github.com/lucidrains/performer-pytorch 140 | # and Simo Ryu (https://github.com/cloneofsimo) 141 | # 
==== positive random features for gaussian kernels ==== 142 | # x = (B, T, hs) 143 | # w = (m, hs) 144 | # return : x : B, T, m 145 | # SM(x, y) = E_w[exp(w^T x - |x|/2) exp(w^T y - |y|/2)] 146 | # therefore return exp(w^Tx - |x|/2)/sqrt(m) 147 | xd = ((x * x).sum(dim=-1, keepdim=True)).repeat(1, 1, self.m) / 2 148 | wtx = torch.einsum('bti,mi->btm', x.float(), self.w) 149 | 150 | return torch.exp(wtx - xd) / math.sqrt(self.m) 151 | 152 | def single_attn(self, x): 153 | k, q, v = torch.split(self.kqv(x), self.emb, dim=-1) 154 | kp, qp = self.prm_exp(k), self.prm_exp(q) # (B, T, m), (B, T, m) 155 | D = torch.einsum('bti,bi->bt', qp, kp.sum(dim=1)).unsqueeze(dim=2) # (B, T, m) * (B, m) -> (B, T, 1) 156 | kptv = torch.einsum('bin,bim->bnm', v.float(), kp) # (B, emb, m) 157 | y = torch.einsum('bti,bni->btn', qp, kptv) / (D.repeat(1, 1, self.emb) + self.epsilon) # (B, T, emb)/Diag 158 | # skip connection 159 | # y = v + self.dp(self.proj(y)) # same as token_transformer in T2T layer, use v as skip connection 160 | y = self.dp(self.proj(y)) 161 | return y 162 | 163 | def forward(self, x): 164 | x = x + self.single_attn(self.norm1(x)) 165 | x = x + self.mlp(self.norm2(x)) 166 | return x 167 | 168 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image, ImageFile 3 | ImageFile.LOAD_TRUNCATED_IMAGES = True 4 | import torch 5 | import random 6 | import numpy as np 7 | from torch.utils import data 8 | import PIL.ImageOps 9 | import cv2 10 | import torchvision.transforms as transforms 11 | import matplotlib.pyplot as plt 12 | from os.path import join 13 | from os import listdir 14 | 15 | """ 16 | build_file_paths: 17 | When "file_names == None and group_names == None", 18 | traverse file folder to build "file_paths", "group_names", "file_names" and "indices". 19 | Otherwise, build "file_paths" based on given "file_names" and "group_names". 20 | """ 21 | def build_file_paths(base, group_names=None, file_names=None, suffix='.png'): 22 | if file_names == None and group_names == None: 23 | file_paths = [] 24 | group_names = [] 25 | file_names = [] 26 | indices = [] 27 | cur_group_end_index = 0 28 | for group_name in listdir(base): 29 | group_path = join(base, group_name) 30 | group_file_names = listdir(group_path) 31 | cur_group_end_index += len(group_file_names) 32 | 33 | # Save the ending index of current group into "indices", which is prepared for "Cosal_Sampler". 34 | indices.append(cur_group_end_index) 35 | 36 | for file_name in group_file_names: 37 | file_path = join(group_path, file_name) 38 | file_paths.append(file_path) 39 | group_names.append(group_name) 40 | file_names.append(file_name[:str(file_name).rfind('.')]) 41 | return file_paths, group_names, file_names, indices 42 | else: 43 | file_paths = list(map(lambda i: join(base, group_names[i], file_names[i] + suffix), range(len(file_names)))) 44 | return file_paths 45 | 46 | """ 47 | random_flip: 48 | Flip inputs horizontally with a possibility of 0.5. 
49 | """ 50 | def random_flip(img, gt): 51 | datas = (img, gt) 52 | if random.random() > 0.5: 53 | datas = tuple(map(lambda data: transforms.functional.hflip(data) if data is not None else None, datas)) 54 | return datas 55 | 56 | 57 | class ImageData(data.Dataset): 58 | def __init__(self, roots, request, aug_transform=None, rgb_transform=None, gray_transform=None): 59 | if 'img' in request == False: 60 | raise Exception('\'img\' must be contained in \'request\'.') 61 | 62 | self.need_gt = True if 'gt' in request else False 63 | self.need_file_name = True if 'file_name' in request else False 64 | self.need_group_name = True if 'group_name' in request else False 65 | #self.need_sism = True if 'sism' in request else False 66 | self.need_size = True if 'size' in request else False 67 | 68 | img_paths, group_names, file_names, indices = build_file_paths(roots['img']) 69 | gt_paths = build_file_paths(roots['gt'], group_names, file_names) if self.need_gt else None 70 | #sism_paths = build_file_paths(roots['sism'], group_names, file_names) if self.need_sism else None 71 | 72 | self.img_paths = img_paths 73 | self.gt_paths = gt_paths 74 | #self.sism_paths = sism_paths 75 | self.file_names = file_names 76 | self.group_names = group_names 77 | self.indices = indices 78 | self.aug_transform = aug_transform 79 | self.rgb_transform = rgb_transform 80 | self.gray_transform = gray_transform 81 | 82 | def __getitem__(self, item): 83 | img = Image.open(self.img_paths[item]).convert('RGB') 84 | W, H = img.size 85 | gt = Image.open(self.gt_paths[item]).convert('L') if self.need_gt else None 86 | #sism = Image.open(self.sism_paths[item]).convert('L') if self.need_sism else None 87 | group_name = self.group_names[item] if self.need_group_name else None 88 | file_name = self.file_names[item] if self.need_file_name else None 89 | 90 | if self.aug_transform is not None: 91 | img, gt = self.aug_transform(img, gt) 92 | 93 | if self.rgb_transform is not None: 94 | img = self.rgb_transform(img) 95 | if self.gray_transform is not None and self.need_gt: 96 | gt = self.gray_transform(gt) 97 | #if self.gray_transform is not None and self.need_sism: 98 | # sism = self.gray_transform(sism) 99 | 100 | data_item = {} 101 | data_item['img'] = img 102 | if self.need_gt: data_item['gt'] = gt 103 | #if self.need_sism: data_item['sism'] = sism 104 | if self.need_file_name: data_item['file_name'] = file_name 105 | if self.need_group_name: data_item['group_name'] = group_name 106 | if self.need_size: data_item['size'] = (H, W) 107 | return data_item 108 | 109 | def __len__(self): 110 | return len(self.img_paths) 111 | 112 | 113 | """ 114 | Cosal_Sampler: 115 | Provide indices of each batch, ensuring that each batch data is extracted from the same image group (with the same category). 116 | """ 117 | class Cosal_Sampler(data.Sampler): 118 | def __init__(self, indices, shuffle, batch_size): 119 | self.indices = indices 120 | self.shuffle = shuffle 121 | self.batch_size = batch_size 122 | self.len = None 123 | self.batches_indices = None 124 | self.reset_batches_indices() 125 | 126 | def reset_batches_indices(self): 127 | batches_indices = [] 128 | start_idx = 0 129 | # For each image group (with same category): 130 | for end_idx in self.indices: 131 | # Initalize "group_indices". 132 | group_indices = list(range(start_idx, end_idx)) 133 | 134 | # Shuffle "group_indices" if needed. 135 | if self.shuffle: 136 | np.random.shuffle(group_indices) 137 | 138 | # Get the size of current image group. 
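# For example, indices=[5, 9] describes two groups covering samples 0-4 and 5-8; with batch_size=None each whole group becomes one batch, while a group of 23 samples and batch_size=10 is split into batches of sizes 10, 10 and 3.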
139 | num = end_idx - start_idx 140 | 141 | # Split "group_indices" to multiple batches according to "self.batch_size", 142 | # then append the splited indices ("batch_indices") to "batches_indices". 143 | # Note that, when "self.batch_size == None", each image group is regarded as a batch ("batch_size = num"). 144 | idx = 0 145 | while idx < num: 146 | batch_size = num if self.batch_size == None else self.batch_size 147 | batch_indices = group_indices[idx:idx + batch_size] 148 | batches_indices.append(batch_indices) 149 | idx += batch_size 150 | start_idx = end_idx 151 | 152 | # Each entry of "batches_indices" is a list indicating indices of a specific batch, 153 | # but neighbouring entries basically belongs to the same image group (with same category). 154 | # Thus, shuffle "batches_indices" if needed. 155 | if self.shuffle: 156 | np.random.shuffle(batches_indices) 157 | 158 | self.len = len(batches_indices) 159 | self.batches_indices = batches_indices 160 | 161 | def __iter__(self): 162 | if self.shuffle: 163 | self.reset_batches_indices() 164 | return iter(self.batches_indices) 165 | 166 | def __len__(self): 167 | return self.len 168 | 169 | 170 | def get_loader(roots, request, batch_size, data_aug, shuffle, num_thread=4, pin=True): 171 | aug_transform = random_flip if data_aug else None 172 | rgb_transform = transforms.Compose([ 173 | transforms.Resize([224, 224]), 174 | transforms.ToTensor(), 175 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 176 | ]) 177 | gray_transform = transforms.Compose([ 178 | transforms.Resize([224, 224]), 179 | transforms.ToTensor() 180 | ]) 181 | dataset = ImageData(roots, request, aug_transform=aug_transform, rgb_transform=rgb_transform, gray_transform=gray_transform) 182 | cosal_sampler = Cosal_Sampler(indices=dataset.indices, shuffle=shuffle, batch_size=batch_size) 183 | data_loader = data.DataLoader(dataset=dataset, batch_sampler=cosal_sampler, num_workers=num_thread, pin_memory=pin) 184 | return data_loader 185 | -------------------------------------------------------------------------------- /Transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from .transformer_block import Block 4 | from timm.models.layers import trunc_normal_ 5 | 6 | 7 | class TransformerEncoder(nn.Module): 8 | def __init__(self, depth, num_heads, embed_dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 9 | drop_path_rate=0., norm_layer=nn.LayerNorm): 10 | super(TransformerEncoder, self).__init__() 11 | 12 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 13 | self.blocks = nn.ModuleList([ 14 | Block( 15 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 16 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 17 | for i in range(depth)]) 18 | 19 | self.rgb_norm = norm_layer(embed_dim) 20 | 21 | self.apply(self._init_weights) 22 | 23 | def _init_weights(self, m): 24 | if isinstance(m, nn.Linear): 25 | trunc_normal_(m.weight, std=.02) 26 | if isinstance(m, nn.Linear) and m.bias is not None: 27 | nn.init.constant_(m.bias, 0) 28 | elif isinstance(m, nn.LayerNorm): 29 | nn.init.constant_(m.bias, 0) 30 | nn.init.constant_(m.weight, 1.0) 31 | 32 | def forward(self, rgb_fea): 33 | 34 | for block in self.blocks: 35 | rgb_fea = block(rgb_fea) 36 | 37 | rgb_fea = self.rgb_norm(rgb_fea) 38 | 39 | return rgb_fea 40 | 
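# Shape reference for the encoder above (a hedged sketch, assuming the 14x14 token map of dim 384 used elsewhere in this repo):
#   encoder = TransformerEncoder(depth=2, num_heads=6, embed_dim=384)
#   out = encoder(torch.zeros(2, 14 * 14, 384))   # token sequence [B, N, C] -> [B, N, C], i.e. [2, 196, 384]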
41 | 42 | class token_TransformerEncoder(nn.Module): 43 | def __init__(self, depth, num_heads, embed_dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 44 | drop_path_rate=0., norm_layer=nn.LayerNorm): 45 | super(token_TransformerEncoder, self).__init__() 46 | 47 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 48 | self.blocks = nn.ModuleList([ 49 | Block( 50 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 51 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 52 | for i in range(depth)]) 53 | 54 | self.norm = norm_layer(embed_dim) 55 | self.apply(self._init_weights) 56 | 57 | def _init_weights(self, m): 58 | if isinstance(m, nn.Linear): 59 | trunc_normal_(m.weight, std=.02) 60 | if isinstance(m, nn.Linear) and m.bias is not None: 61 | nn.init.constant_(m.bias, 0) 62 | elif isinstance(m, nn.LayerNorm): 63 | nn.init.constant_(m.bias, 0) 64 | nn.init.constant_(m.weight, 1.0) 65 | 66 | def forward(self, fea): 67 | 68 | for block in self.blocks: 69 | fea = block(fea) 70 | 71 | fea = self.norm(fea) 72 | 73 | return fea 74 | 75 | class Transformer(nn.Module): 76 | def __init__(self, embed_dim=384, depth=14, num_heads=6, mlp_ratio=3.): 77 | super(Transformer, self).__init__() 78 | 79 | self.encoderlayer = TransformerEncoder(embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio) 80 | 81 | def forward(self, rgb_fea): 82 | 83 | rgb_memory = self.encoderlayer(rgb_fea) 84 | 85 | return rgb_memory 86 | 87 | 88 | class saliency_token_inference(nn.Module): 89 | def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 90 | super().__init__() 91 | 92 | self.norm = nn.LayerNorm(dim) 93 | self.num_heads = num_heads 94 | head_dim = dim // num_heads 95 | 96 | self.scale = qk_scale or head_dim ** -0.5 97 | 98 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 99 | self.k = nn.Linear(dim, dim, bias=qkv_bias) 100 | self.v = nn.Linear(dim, dim, bias=qkv_bias) 101 | self.attn_drop = nn.Dropout(attn_drop) 102 | self.proj = nn.Linear(dim, dim) 103 | self.proj_drop = nn.Dropout(proj_drop) 104 | 105 | self.sigmoid = nn.Sigmoid() 106 | 107 | def forward(self, fea): 108 | B, N, C = fea.shape 109 | x = self.norm(fea) 110 | T_s, F_s = x[:, 0, :].unsqueeze(1), x[:, 1:-1, :] 111 | # T_s [B, 1, 384] F_s [B, 14*14, 384] 112 | 113 | q = self.q(F_s).reshape(B, N-2, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 114 | k = self.k(T_s).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 115 | v = self.v(T_s).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 116 | 117 | attn = (q @ k.transpose(-2, -1)) * self.scale 118 | 119 | attn = self.sigmoid(attn) 120 | attn = self.attn_drop(attn) 121 | 122 | infer_fea = (attn @ v).transpose(1, 2).reshape(B, N-2, C) 123 | infer_fea = self.proj(infer_fea) 124 | infer_fea = self.proj_drop(infer_fea) 125 | 126 | infer_fea = infer_fea + fea[:, 1:-1, :] 127 | return infer_fea 128 | 129 | 130 | class contour_token_inference(nn.Module): 131 | def __init__(self, dim, num_heads=1, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 132 | super().__init__() 133 | 134 | self.norm = nn.LayerNorm(dim) 135 | self.num_heads = num_heads 136 | head_dim = dim // num_heads 137 | 138 | self.scale = qk_scale or head_dim ** -0.5 139 | 140 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 141 | self.k = nn.Linear(dim, dim, bias=qkv_bias) 142 | 
self.v = nn.Linear(dim, dim, bias=qkv_bias) 143 | self.attn_drop = nn.Dropout(attn_drop) 144 | self.proj = nn.Linear(dim, dim) 145 | self.proj_drop = nn.Dropout(proj_drop) 146 | 147 | self.sigmoid = nn.Sigmoid() 148 | 149 | def forward(self, fea): 150 | B, N, C = fea.shape 151 | x = self.norm(fea) 152 | T_s, F_s = x[:, -1, :].unsqueeze(1), x[:, 1:-1, :] 153 | # T_s [B, 1, 384] F_s [B, 14*14, 384] 154 | 155 | q = self.q(F_s).reshape(B, N-2, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 156 | k = self.k(T_s).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 157 | v = self.v(T_s).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 158 | 159 | attn = (q @ k.transpose(-2, -1)) * self.scale 160 | 161 | # attn = attn.softmax(dim=-1) 162 | attn = self.sigmoid(attn) 163 | attn = self.attn_drop(attn) 164 | 165 | infer_fea = (attn @ v).transpose(1, 2).reshape(B, N-2, C) 166 | infer_fea = self.proj(infer_fea) 167 | infer_fea = self.proj_drop(infer_fea) 168 | 169 | infer_fea = infer_fea + fea[:, 1:-1, :] 170 | return infer_fea 171 | 172 | 173 | class token_Transformer(nn.Module): 174 | def __init__(self, embed_dim=384, depth=14, num_heads=6, mlp_ratio=3.): 175 | super(token_Transformer, self).__init__() 176 | 177 | self.norm = nn.LayerNorm(embed_dim) 178 | self.mlp_s = nn.Sequential( 179 | nn.Linear(embed_dim, embed_dim), 180 | nn.GELU(), 181 | nn.Linear(embed_dim, embed_dim), 182 | ) 183 | self.saliency_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 184 | self.contour_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 185 | self.encoderlayer = token_TransformerEncoder(embed_dim=embed_dim, depth=depth, num_heads=num_heads, mlp_ratio=mlp_ratio) 186 | self.saliency_token_pre = saliency_token_inference(dim=embed_dim, num_heads=1) 187 | self.contour_token_pre = contour_token_inference(dim=embed_dim, num_heads=1) 188 | 189 | def forward(self, rgb_fea): 190 | B, _, _ = rgb_fea.shape 191 | fea_1_16 = self.mlp_s(self.norm(rgb_fea)) # [B, 14*14, 384] 192 | 193 | saliency_tokens = self.saliency_token.expand(B, -1, -1) 194 | fea_1_16 = torch.cat((saliency_tokens, fea_1_16), dim=1) 195 | 196 | contour_tokens = self.contour_token.expand(B, -1, -1) 197 | fea_1_16 = torch.cat((fea_1_16, contour_tokens), dim=1) 198 | # fea_1_16 [B, 1 + 14*14 + 1, 384] 199 | 200 | fea_1_16 = self.encoderlayer(fea_1_16) 201 | # fea_1_16 [B, 1 + 14*14 + 1, 384] 202 | saliency_tokens = fea_1_16[:, 0, :].unsqueeze(1) 203 | contour_tokens = fea_1_16[:, -1, :].unsqueeze(1) 204 | 205 | saliency_fea_1_16 = self.saliency_token_pre(fea_1_16) 206 | contour_fea_1_16 = self.contour_token_pre(fea_1_16) 207 | return saliency_fea_1_16, fea_1_16, saliency_tokens, contour_fea_1_16, contour_tokens 208 | 209 | -------------------------------------------------------------------------------- /transformer_block.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. 2 | # 3 | # This source code is licensed under the Clear BSD License 4 | # LICENSE file in the root directory of this file 5 | # All rights reserved. 
6 | """ 7 | Borrow from timm(https://github.com/rwightman/pytorch-image-models) 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | import numpy as np 12 | from timm.models.layers import DropPath 13 | 14 | 15 | class Mlp(nn.Module): 16 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 17 | super().__init__() 18 | out_features = out_features or in_features 19 | hidden_features = hidden_features or in_features 20 | self.fc1 = nn.Linear(in_features, hidden_features) 21 | self.act = act_layer() 22 | self.fc2 = nn.Linear(hidden_features, out_features) 23 | self.drop = nn.Dropout(drop) 24 | 25 | def forward(self, x): 26 | x = self.fc1(x) 27 | x = self.act(x) 28 | x = self.drop(x) 29 | x = self.fc2(x) 30 | x = self.drop(x) 31 | return x 32 | 33 | 34 | class Attention(nn.Module): 35 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 36 | super().__init__() 37 | self.num_heads = num_heads 38 | head_dim = dim // num_heads 39 | 40 | self.scale = qk_scale or head_dim ** -0.5 41 | 42 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 43 | self.attn_drop = nn.Dropout(attn_drop) 44 | self.proj = nn.Linear(dim, dim) 45 | self.proj_drop = nn.Dropout(proj_drop) 46 | 47 | def forward(self, x): 48 | B, N, C = x.shape 49 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 50 | q, k, v = qkv[0], qkv[1], qkv[2] 51 | 52 | attn = (q @ k.transpose(-2, -1)) * self.scale 53 | attn = attn.softmax(dim=-1) 54 | attn = self.attn_drop(attn) 55 | 56 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 57 | x = self.proj(x) 58 | x = self.proj_drop(x) 59 | return x 60 | 61 | 62 | class MutualAttention(nn.Module): 63 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 64 | super().__init__() 65 | self.num_heads = num_heads 66 | head_dim = dim // num_heads 67 | 68 | self.scale = qk_scale or head_dim ** -0.5 69 | 70 | self.rgb_q = nn.Linear(dim, dim, bias=qkv_bias) 71 | self.rgb_k = nn.Linear(dim, dim, bias=qkv_bias) 72 | self.rgb_v = nn.Linear(dim, dim, bias=qkv_bias) 73 | self.rgb_proj = nn.Linear(dim, dim) 74 | 75 | self.depth_q = nn.Linear(dim, dim, bias=qkv_bias) 76 | self.depth_k = nn.Linear(dim, dim, bias=qkv_bias) 77 | self.depth_v = nn.Linear(dim, dim, bias=qkv_bias) 78 | self.depth_proj = nn.Linear(dim, dim) 79 | 80 | self.attn_drop = nn.Dropout(attn_drop) 81 | self.proj_drop = nn.Dropout(proj_drop) 82 | 83 | def forward(self, rgb_fea, depth_fea): 84 | B, N, C = rgb_fea.shape 85 | 86 | rgb_q = self.rgb_q(rgb_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 87 | rgb_k = self.rgb_k(rgb_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 88 | rgb_v = self.rgb_v(rgb_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 89 | # q [B, nhead, N, C//nhead] 90 | 91 | depth_q = self.depth_q(depth_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 92 | depth_k = self.depth_k(depth_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 93 | depth_v = self.depth_v(depth_fea).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) 94 | 95 | # rgb branch 96 | rgb_attn = (rgb_q @ depth_k.transpose(-2, -1)) * self.scale 97 | rgb_attn = rgb_attn.softmax(dim=-1) 98 | rgb_attn = self.attn_drop(rgb_attn) 99 | 100 | rgb_fea = (rgb_attn @ depth_v).transpose(1, 2).reshape(B, N, C) 101 | rgb_fea = 
self.rgb_proj(rgb_fea) 102 | rgb_fea = self.proj_drop(rgb_fea) 103 | 104 | # depth branch 105 | depth_attn = (depth_q @ rgb_k.transpose(-2, -1)) * self.scale 106 | depth_attn = depth_attn.softmax(dim=-1) 107 | depth_attn = self.attn_drop(depth_attn) 108 | 109 | depth_fea = (depth_attn @ rgb_v).transpose(1, 2).reshape(B, N, C) 110 | depth_fea = self.depth_proj(depth_fea) 111 | depth_fea = self.proj_drop(depth_fea) 112 | 113 | return rgb_fea, depth_fea 114 | 115 | 116 | class Block(nn.Module): 117 | 118 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 119 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 120 | super().__init__() 121 | self.norm1 = norm_layer(dim) 122 | self.attn = Attention( 123 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 124 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 125 | self.norm2 = norm_layer(dim) 126 | mlp_hidden_dim = int(dim * mlp_ratio) 127 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 128 | 129 | def forward(self, x): 130 | x = x + self.drop_path(self.attn(self.norm1(x))) 131 | x = x + self.drop_path(self.mlp(self.norm2(x))) 132 | return x 133 | 134 | 135 | class MutualSelfBlock(nn.Module): 136 | 137 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 138 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 139 | super().__init__() 140 | 141 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 142 | mlp_hidden_dim = int(dim * mlp_ratio) 143 | 144 | # mutual attention 145 | self.norm1_rgb_ma = norm_layer(dim) 146 | self.norm2_depth_ma = norm_layer(dim) 147 | self.mutualAttn = MutualAttention( 148 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 149 | self.norm3_rgb_ma = norm_layer(dim) 150 | self.norm4_depth_ma = norm_layer(dim) 151 | self.mlp_rgb_ma = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 152 | self.mlp_depth_ma = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 153 | 154 | # rgb self attention 155 | self.norm1_rgb_sa = norm_layer(dim) 156 | self.selfAttn_rgb = Attention( 157 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 158 | self.norm2_rgb_sa = norm_layer(dim) 159 | self.mlp_rgb_sa = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 160 | 161 | # depth self attention 162 | self.norm1_depth_sa = norm_layer(dim) 163 | self.selfAttn_depth = Attention( 164 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 165 | self.norm2_depth_sa = norm_layer(dim) 166 | self.mlp_depth_sa = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 167 | 168 | def forward(self, rgb_fea, depth_fea): 169 | 170 | # mutual attention 171 | rgb_fea_fuse, depth_fea_fuse = self.drop_path(self.mutualAttn(self.norm1_rgb_ma(rgb_fea), self.norm2_depth_ma(depth_fea))) 172 | 173 | rgb_fea = rgb_fea + rgb_fea_fuse 174 | depth_fea = depth_fea + depth_fea_fuse 175 | 176 | rgb_fea = rgb_fea + self.drop_path(self.mlp_rgb_ma(self.norm3_rgb_ma(rgb_fea))) 177 | depth_fea = depth_fea + self.drop_path(self.mlp_depth_ma(self.norm4_depth_ma(depth_fea))) 178 | 179 | # rgb self attention 180 | rgb_fea = rgb_fea + 
self.drop_path(self.selfAttn_rgb(self.norm1_rgb_sa(rgb_fea))) 181 | rgb_fea = rgb_fea + self.drop_path(self.mlp_rgb_sa(self.norm2_rgb_sa(rgb_fea))) 182 | 183 | # depth self attention 184 | depth_fea = depth_fea + self.drop_path(self.selfAttn_depth(self.norm1_depth_sa(depth_fea))) 185 | depth_fea = depth_fea + self.drop_path(self.mlp_depth_sa(self.norm2_depth_sa(depth_fea))) 186 | 187 | return rgb_fea, depth_fea 188 | 189 | 190 | def get_sinusoid_encoding(n_position, d_hid): 191 | ''' Sinusoid position encoding table ''' 192 | 193 | def get_position_angle_vec(position): 194 | return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)] 195 | 196 | sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)]) 197 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 198 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 199 | 200 | return torch.FloatTensor(sinusoid_table).unsqueeze(0) 201 | -------------------------------------------------------------------------------- /t2t_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) [2012]-[2021] Shanghai Yitu Technology Co., Ltd. 2 | # 3 | # This source code is licensed under the Clear BSD License 4 | # LICENSE file in the root directory of this file 5 | # All rights reserved. 6 | """ 7 | T2T-ViT 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | 12 | from timm.models.helpers import load_pretrained 13 | from timm.models.registry import register_model 14 | from timm.models.layers import trunc_normal_ 15 | import numpy as np 16 | from token_transformer import Token_transformer 17 | from token_performer import Token_performer 18 | from transformer_block import Block, get_sinusoid_encoding 19 | from timm.models import load_checkpoint 20 | 21 | 22 | def _cfg(url='', **kwargs): 23 | return { 24 | 'url': url, 25 | 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 26 | 'crop_pct': .9, 'interpolation': 'bicubic', 27 | 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 28 | 'classifier': 'head', 29 | **kwargs 30 | } 31 | 32 | default_cfgs = { 33 | 'T2t_vit_t_14': _cfg(), 34 | 'T2t_vit_t_19': _cfg(), 35 | 'T2t_vit_t_24': _cfg(), 36 | 'T2t_vit_14': _cfg(), 37 | 'T2t_vit_19': _cfg(), 38 | 'T2t_vit_24': _cfg(), 39 | 'T2t_vit_7': _cfg(), 40 | 'T2t_vit_10': _cfg(), 41 | 'T2t_vit_12': _cfg(), 42 | 'T2t_vit_14_resnext': _cfg(), 43 | 'T2t_vit_14_wide': _cfg(), 44 | } 45 | 46 | 47 | class T2T_module(nn.Module): 48 | """ 49 | Tokens-to-Token encoding module 50 | """ 51 | def __init__(self, img_size=224, tokens_type='performer', in_chans=3, embed_dim=768, token_dim=64): 52 | super().__init__() 53 | 54 | if tokens_type == 'transformer': 55 | print('adopt transformer encoder for tokens-to-token') 56 | self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) 57 | self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 58 | self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 59 | 60 | self.attention1 = Token_transformer(dim=in_chans * 7 * 7, in_dim=token_dim, num_heads=1, mlp_ratio=1.0) 61 | self.attention2 = Token_transformer(dim=token_dim * 3 * 3, in_dim=token_dim, num_heads=1, mlp_ratio=1.0) 62 | self.project = nn.Linear(token_dim * 3 * 3, embed_dim) 63 | 64 | elif tokens_type == 'performer': 65 | print('adopt performer encoder for tokens-to-token') 66 | self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), 
padding=(2, 2)) 67 | self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 68 | self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) 69 | 70 | # self.attention1 = Token_performer(dim=token_dim, in_dim=in_chans*7*7, kernel_ratio=0.5) 71 | # self.attention2 = Token_performer(dim=token_dim, in_dim=token_dim*3*3, kernel_ratio=0.5) 72 | self.attention1 = Token_performer(dim=in_chans*7*7, in_dim=token_dim, kernel_ratio=0.5) 73 | self.attention2 = Token_performer(dim=token_dim*3*3, in_dim=token_dim, kernel_ratio=0.5) 74 | self.project = nn.Linear(token_dim * 3 * 3, embed_dim) 75 | 76 | elif tokens_type == 'convolution': # just for comparison with conolution, not our model 77 | # for this tokens type, you need change forward as three convolution operation 78 | print('adopt convolution layers for tokens-to-token') 79 | self.soft_split0 = nn.Conv2d(3, token_dim, kernel_size=(7, 7), stride=(4, 4), padding=(2, 2)) # the 1st convolution 80 | self.soft_split1 = nn.Conv2d(token_dim, token_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 2nd convolution 81 | self.project = nn.Conv2d(token_dim, embed_dim, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)) # the 3rd convolution 82 | 83 | self.num_patches = (img_size // (4 * 2 * 2)) * (img_size // (4 * 2 * 2)) # there are 3 sfot split, stride are 4,2,2 seperately 84 | 85 | def forward(self, x): 86 | # step0: soft split 87 | x = self.soft_split0(x).transpose(1, 2) 88 | 89 | # x [B, 56*56, 147=7*7*3] 90 | # iteration1: restricturization/reconstruction 91 | x_1_4 = self.attention1(x) 92 | B, new_HW, C = x_1_4.shape 93 | x = x_1_4.transpose(1,2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 94 | # iteration1: soft split 95 | x = self.soft_split1(x).transpose(1, 2) 96 | 97 | # iteration2: restricturization/reconstruction 98 | x_1_8 = self.attention2(x) 99 | B, new_HW, C = x_1_8.shape 100 | x = x_1_8.transpose(1, 2).reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW))) 101 | # iteration2: soft split 102 | x = self.soft_split2(x).transpose(1, 2) 103 | 104 | # final tokens 105 | x = self.project(x) 106 | 107 | return x, x_1_8, x_1_4 108 | 109 | 110 | class T2T_ViT(nn.Module): 111 | def __init__(self, img_size=224, tokens_type='performer', in_chans=3, num_classes=1000, embed_dim=768, depth=12, 112 | num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., 113 | drop_path_rate=0., norm_layer=nn.LayerNorm): 114 | super().__init__() 115 | self.num_classes = num_classes 116 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 117 | 118 | self.tokens_to_token = T2T_module( 119 | img_size=img_size, tokens_type=tokens_type, in_chans=in_chans, embed_dim=embed_dim) 120 | num_patches = self.tokens_to_token.num_patches 121 | 122 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 123 | self.pos_embed = nn.Parameter(data=get_sinusoid_encoding(n_position=num_patches + 1, d_hid=embed_dim), requires_grad=False) 124 | self.pos_drop = nn.Dropout(p=drop_rate) 125 | 126 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 127 | self.blocks = nn.ModuleList([ 128 | Block( 129 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 130 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 131 | for i in range(depth)]) 132 | self.norm = norm_layer(embed_dim) 133 | 134 | # Classifier head 135 | self.head = 
nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() 136 | 137 | trunc_normal_(self.cls_token, std=.02) 138 | self.apply(self._init_weights) 139 | 140 | def _init_weights(self, m): 141 | if isinstance(m, nn.Linear): 142 | trunc_normal_(m.weight, std=.02) 143 | if isinstance(m, nn.Linear) and m.bias is not None: 144 | nn.init.constant_(m.bias, 0) 145 | elif isinstance(m, nn.LayerNorm): 146 | nn.init.constant_(m.bias, 0) 147 | nn.init.constant_(m.weight, 1.0) 148 | 149 | @torch.jit.ignore 150 | def no_weight_decay(self): 151 | return {'cls_token'} 152 | 153 | def get_classifier(self): 154 | return self.head 155 | 156 | def reset_classifier(self, num_classes, global_pool=''): 157 | self.num_classes = num_classes 158 | self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() 159 | 160 | def forward_features(self, x): 161 | B = x.shape[0] 162 | x, x_1_8, x_1_4 = self.tokens_to_token(x) 163 | 164 | cls_tokens = self.cls_token.expand(B, -1, -1) 165 | x = torch.cat((cls_tokens, x), dim=1) 166 | x = x + self.pos_embed 167 | x = self.pos_drop(x) 168 | 169 | # T2T-ViT backbone 170 | for blk in self.blocks: 171 | x = blk(x) 172 | 173 | x = self.norm(x) 174 | # return x[:, 0] 175 | return x[:, 1:, :], x_1_8, x_1_4 176 | 177 | def forward(self, x): 178 | x, x_1_8, x_1_4 = self.forward_features(x) 179 | #x_pred = self.head(x) 180 | return x, x_1_8, x_1_4 181 | 182 | 183 | @register_model 184 | def T2t_vit_t_14(pretrained=True, **kwargs): # adopt transformers for tokens to token 185 | # if pretrained: 186 | # kwargs.setdefault('qk_scale', 384 ** -0.5) 187 | 188 | # model = T2T_ViT(tokens_type='transformer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs) 189 | model = T2T_ViT(tokens_type='transformer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3.) 
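# T2T-ViT-t-14 configuration: transformer-type tokens-to-token module, 14 encoder blocks, 384-dim tokens, 6 heads. With pretrained=True the weights are restored just below via timm's load_checkpoint from args.pretrained_model (presumably the t2t-vit-14 checkpoint linked in the README).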
190 | model.default_cfg = default_cfgs['T2t_vit_t_14'] 191 | args = kwargs['args'] 192 | if pretrained: 193 | load_checkpoint(model, args.pretrained_model, use_ema=True) 194 | print('Model loaded from {}'.format(args.pretrained_model)) 195 | # load_pretrained( 196 | # model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 197 | return model 198 | 199 | @register_model 200 | def T2t_vit_t_19(pretrained=False, **kwargs): # adopt transformers for tokens to token 201 | if pretrained: 202 | kwargs.setdefault('qk_scale', 448 ** -0.5) 203 | model = T2T_ViT(tokens_type='transformer', embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs) 204 | model.default_cfg = default_cfgs['T2t_vit_t_19'] 205 | if pretrained: 206 | load_pretrained( 207 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 208 | return model 209 | 210 | @register_model 211 | def T2t_vit_t_24(pretrained=False, **kwargs): # adopt transformers for tokens to token 212 | if pretrained: 213 | kwargs.setdefault('qk_scale', 512 ** -0.5) 214 | model = T2T_ViT(tokens_type='transformer', embed_dim=512, depth=24, num_heads=8, mlp_ratio=3., **kwargs) 215 | model.default_cfg = default_cfgs['T2t_vit_t_24'] 216 | if pretrained: 217 | load_pretrained( 218 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 219 | return model 220 | 221 | 222 | @register_model 223 | def T2t_vit_7(pretrained=False, **kwargs): # adopt performer for tokens to token 224 | if pretrained: 225 | kwargs.setdefault('qk_scale', 256 ** -0.5) 226 | model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=7, num_heads=4, mlp_ratio=2., **kwargs) 227 | model.default_cfg = default_cfgs['T2t_vit_7'] 228 | if pretrained: 229 | load_pretrained( 230 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 231 | return model 232 | 233 | @register_model 234 | def T2t_vit_10(pretrained=False, **kwargs): # adopt performer for tokens to token 235 | if pretrained: 236 | kwargs.setdefault('qk_scale', 256 ** -0.5) 237 | model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=10, num_heads=4, mlp_ratio=2., **kwargs) 238 | model.default_cfg = default_cfgs['T2t_vit_10'] 239 | if pretrained: 240 | load_pretrained( 241 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 242 | return model 243 | 244 | @register_model 245 | def T2t_vit_12(pretrained=False, **kwargs): # adopt performer for tokens to token 246 | if pretrained: 247 | kwargs.setdefault('qk_scale', 256 ** -0.5) 248 | model = T2T_ViT(tokens_type='performer', embed_dim=256, depth=12, num_heads=4, mlp_ratio=2., **kwargs) 249 | model.default_cfg = default_cfgs['T2t_vit_12'] 250 | if pretrained: 251 | load_pretrained( 252 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 253 | return model 254 | 255 | 256 | @register_model 257 | def T2t_vit_t_14(pretrained=True, **kwargs): # adopt transformers for tokens to token 258 | # if pretrained: 259 | # kwargs.setdefault('qk_scale', 384 ** -0.5) 260 | 261 | # model = T2T_ViT(tokens_type='transformer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3., **kwargs) 262 | model = T2T_ViT(tokens_type='transformer', embed_dim=384, depth=14, num_heads=6, mlp_ratio=3.) 
263 | model.default_cfg = default_cfgs['T2t_vit_t_14'] 264 | args = kwargs['args'] 265 | if pretrained: 266 | load_checkpoint(model, args.pretrained_model, use_ema=True) 267 | print('Model loaded from {}'.format(args.pretrained_model)) 268 | # load_pretrained( 269 | # model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 270 | return model 271 | @register_model 272 | def T2t_vit_19(pretrained=False, **kwargs): # adopt performer for tokens to token 273 | if pretrained: 274 | kwargs.setdefault('qk_scale', 448 ** -0.5) 275 | model = T2T_ViT(tokens_type='performer', embed_dim=448, depth=19, num_heads=7, mlp_ratio=3., **kwargs) 276 | model.default_cfg = default_cfgs['T2t_vit_19'] 277 | if pretrained: 278 | load_pretrained( 279 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 280 | return model 281 | 282 | @register_model 283 | def T2t_vit_24(pretrained=False, **kwargs): # adopt performer for tokens to token 284 | if pretrained: 285 | kwargs.setdefault('qk_scale', 512 ** -0.5) 286 | model = T2T_ViT(tokens_type='performer', embed_dim=512, depth=24, num_heads=8, mlp_ratio=3., **kwargs) 287 | model.default_cfg = default_cfgs['T2t_vit_24'] 288 | if pretrained: 289 | load_pretrained( 290 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 291 | return model 292 | 293 | 294 | # rexnext and wide structure 295 | @register_model 296 | def T2t_vit_14_resnext(pretrained=False, **kwargs): 297 | if pretrained: 298 | kwargs.setdefault('qk_scale', 384 ** -0.5) 299 | model = T2T_ViT(tokens_type='performer', embed_dim=384, depth=14, num_heads=32, mlp_ratio=3., **kwargs) 300 | model.default_cfg = default_cfgs['T2t_vit_14_resnext'] 301 | if pretrained: 302 | load_pretrained( 303 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 304 | return model 305 | 306 | @register_model 307 | def T2t_vit_14_wide(pretrained=False, **kwargs): 308 | if pretrained: 309 | kwargs.setdefault('qk_scale', 512 ** -0.5) 310 | model = T2T_ViT(tokens_type='performer', embed_dim=768, depth=4, num_heads=12, mlp_ratio=3., **kwargs) 311 | model.default_cfg = default_cfgs['T2t_vit_14_wide'] 312 | if pretrained: 313 | load_pretrained( 314 | model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) 315 | return model -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | import math 4 | import numpy as np 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from torch.nn import init 8 | from os.path import join 9 | from t2t_vit import T2t_vit_t_14 10 | from transformer_block import Block, get_sinusoid_encoding 11 | from timm.models.layers import trunc_normal_ 12 | from token_performer import Token_performer 13 | 14 | np.set_printoptions(suppress=True, threshold=1e5) 15 | import argparse 16 | 17 | 18 | def resize(input, target_size=(224, 224)): 19 | return F.interpolate(input, (target_size[0], target_size[1]), mode='bilinear', align_corners=True) 20 | 21 | 22 | """ 23 | weights_init: 24 | Weights initialization. 
25 | """ 26 | 27 | 28 | def weights_init(module): 29 | if isinstance(module, nn.Conv2d): 30 | init.normal_(module.weight, 0, 0.01) 31 | if module.bias is not None: 32 | init.constant_(module.bias, 0) 33 | elif isinstance(module, nn.BatchNorm2d): 34 | init.constant_(module.weight, 1) 35 | init.constant_(module.bias, 0) 36 | 37 | 38 | """" 39 | VGG16: 40 | VGG16 backbone. 41 | """ 42 | 43 | 44 | class VGG16(nn.Module): 45 | def __init__(self): 46 | super(VGG16, self).__init__() 47 | layers = [] 48 | in_channel = 3 49 | vgg_out_channels = (64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M') 50 | for out_channel in vgg_out_channels: 51 | if out_channel == 'M': 52 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 53 | else: 54 | conv2d = nn.Conv2d(in_channel, out_channel, 3, 1, 1) 55 | layers += [conv2d, nn.ReLU(inplace=True)] 56 | in_channel = out_channel 57 | self.vgg = nn.ModuleList(layers) 58 | self.table = {'conv1_1': 0, 'conv1_2': 2, 'conv1_2_mp': 4, 59 | 'conv2_1': 5, 'conv2_2': 7, 'conv2_2_mp': 9, 60 | 'conv3_1': 10, 'conv3_2': 12, 'conv3_3': 14, 'conv3_3_mp': 16, 61 | 'conv4_1': 17, 'conv4_2': 19, 'conv4_3': 21, 'conv4_3_mp': 23, 62 | 'conv5_1': 24, 'conv5_2': 26, 'conv5_3': 28, 'conv5_3_mp': 30, 'final': 31} 63 | 64 | def forward(self, feats, start_layer_name, end_layer_name): 65 | start_idx = self.table[start_layer_name] 66 | end_idx = self.table[end_layer_name] 67 | for idx in range(start_idx, end_idx): 68 | feats = self.vgg[idx](feats) 69 | return feats 70 | 71 | 72 | class token_TransformerEncoder(nn.Module): 73 | def __init__(self, depth, num_heads, embed_dim, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., 74 | attn_drop_rate=0., 75 | drop_path_rate=0., norm_layer=nn.LayerNorm): 76 | super(token_TransformerEncoder, self).__init__() 77 | 78 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 79 | self.blocks = nn.ModuleList([ 80 | Block( 81 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 82 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 83 | for i in range(depth)]) 84 | 85 | self.norm = norm_layer(embed_dim) 86 | self.apply(self._init_weights) 87 | 88 | def _init_weights(self, m): 89 | if isinstance(m, nn.Linear): 90 | trunc_normal_(m.weight, std=.02) 91 | if isinstance(m, nn.Linear) and m.bias is not None: 92 | nn.init.constant_(m.bias, 0) 93 | elif isinstance(m, nn.LayerNorm): 94 | nn.init.constant_(m.bias, 0) 95 | nn.init.constant_(m.weight, 1.0) 96 | 97 | def forward(self, fea): 98 | 99 | for block in self.blocks: 100 | fea = block(fea) 101 | 102 | fea = self.norm(fea) 103 | 104 | return fea 105 | 106 | 107 | class token_Transformer(nn.Module): 108 | def __init__(self, embed_dim=384, depth=14, num_heads=6, mlp_ratio=3.): 109 | super(token_Transformer, self).__init__() 110 | 111 | self.norm = nn.LayerNorm(embed_dim) 112 | self.mlp_s = nn.Sequential( 113 | nn.Linear(embed_dim, embed_dim), 114 | nn.GELU(), 115 | nn.Linear(embed_dim, embed_dim), 116 | ) 117 | 118 | self.encoderlayer = token_TransformerEncoder(embed_dim=embed_dim, depth=depth, num_heads=num_heads, 119 | mlp_ratio=mlp_ratio) 120 | 121 | def forward(self, rgb_fea): 122 | B, _, _ = rgb_fea.shape 123 | fea_1_16 = self.mlp_s(self.norm(rgb_fea)) # [B, 14*14, 384] 124 | fea_1_16 = self.encoderlayer(fea_1_16) 125 | return fea_1_16 126 | 127 | 128 | class decoder_module(nn.Module): 129 | def __init__(self, dim=384, token_dim=64, img_size=224, 
130 |                  fuse=True):
131 |         super(decoder_module, self).__init__()
132 | 
133 |         self.project = nn.Linear(token_dim, token_dim * kernel_size[0] * kernel_size[1])
134 |         self.upsample = nn.Fold(output_size=(img_size // ratio, img_size // ratio), kernel_size=kernel_size,
135 |                                 stride=stride, padding=padding)
136 |         self.fuse = fuse
137 |         if self.fuse:
138 |             self.concatFuse = nn.Sequential(
139 |                 nn.Linear(token_dim * 2, token_dim),
140 |                 nn.GELU(),
141 |                 nn.Linear(token_dim, token_dim),
142 |             )
143 |             self.att = Token_performer(dim=token_dim, in_dim=token_dim, kernel_ratio=0.5)
144 | 
145 |             # project input feature to 64 dim
146 |             self.norm = nn.LayerNorm(dim)
147 |             self.mlp = nn.Sequential(
148 |                 nn.Linear(dim, token_dim),
149 |                 nn.GELU(),
150 |                 nn.Linear(token_dim, token_dim),
151 |             )
152 | 
153 |     def forward(self, dec_fea, enc_fea=None):
154 | 
155 |         if self.fuse:
156 |             # from 384 to 64
157 |             dec_fea = self.mlp(self.norm(dec_fea))
158 | 
159 |         # [1] token upsampling by the proposed reverse T2T module
160 |         dec_fea = self.project(dec_fea)
161 |         # [B, H*W, token_dim*kernel_size*kernel_size]
162 |         dec_fea = self.upsample(dec_fea.transpose(1, 2))
163 |         B, C, _, _ = dec_fea.shape
164 |         dec_fea = dec_fea.view(B, C, -1).transpose(1, 2)
165 |         # [B, HW, C]
166 | 
167 |         if self.fuse:
168 |             # [2] fuse encoder fea and decoder fea
169 |             dec_fea = self.concatFuse(torch.cat([dec_fea, enc_fea], dim=2))
170 |             dec_fea = self.att(dec_fea)
171 | 
172 |         return dec_fea
173 | 
174 | 
175 | 
176 | class MCM(nn.Module):
177 |     def __init__(self, in_dim):
178 |         super(MCM, self).__init__()
179 |         # Co-attention
180 |         self.query_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1, stride=1, padding=0)
181 |         self.key_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1, stride=1, padding=0)
182 |         self.scale = 1.0 / (in_dim ** 0.5)
183 |         self.value_conv = nn.Conv2d(in_dim, in_dim, kernel_size=1, stride=1, padding=0)
184 |         self.conv6 = nn.Conv2d(in_dim, in_dim, kernel_size=1, stride=1, padding=0)
185 | 
186 |         self.query = nn.Linear(512, 512)
187 |         self.key = nn.Linear(512, 512)
188 |         self.value = nn.Linear(512, 512)
189 | 
190 |         self.conv512_64 = nn.Conv2d(512, 64, 1)
191 |         self.conv1 = nn.Sequential(
192 |             nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1, bias=True),
193 |             nn.ReLU(inplace=True),
194 |         )
195 |         self.conv2 = nn.Sequential(
196 |             nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1, bias=True),
197 |             nn.ReLU(inplace=True),
198 |         )
199 |         self.conv3 = nn.Sequential(
200 |             nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1, bias=True),
201 |             nn.ReLU(inplace=True),
202 |         )
203 | 
204 |     def forward(self, xc, xt):
205 |         # xc: CNN feature [B, 512, 14, 14]; xt: transformer tokens [B, 196, 512]
206 |         B, C, H, W = xc.size()
207 |         c_value = self.value_conv(xc)
208 |         c_query = self.query_conv(xc).view(B, -1, W * H).permute(0, 2, 1)  # [B, HW, C]
209 |         c_key = self.key_conv(xc).view(B, -1, W * H).permute(1, 0, 2)  # [C, B, HW]
210 |         c_query = c_query.contiguous().view(-1, C)  # [BHW, C]
211 |         c_key = c_key.contiguous().view(C, -1)  # [C, BHW]
212 |         c_xw = torch.matmul(c_query, c_key)  # [BHW, BHW]
213 |         c_xw = c_xw.view(B * H * W, B, H * W)  # [BHW, B, HW]
214 |         c_max = torch.max(c_xw, dim=-1)[0]  # [BHW, B]
215 |         c_avg = torch.mean(c_xw, dim=-1)  # [BHW, B]
216 |         c_co = c_max + c_avg  # [BHW, B]
217 |         c_co = c_co.sum(-1)  # [BHW]
218 |         c_co = c_co.view(B, -1) * self.scale
219 |         c_co = F.softmax(c_co, dim=-1)  # [B, HW]
220 |         c_co = c_co.view(B, H, W).unsqueeze(1)  # [B, 1, H, W]
221 |         # Local self-attention over the transformer tokens
222 |         Bt, HW, Ct = xt.size()  # xt: [B, 196, 512]
223 |         t_value = xt.transpose(1, 2).reshape(Bt, Ct, int(np.sqrt(HW)), int(np.sqrt(HW)))
224 |         t_query = self.query(xt).contiguous().view(-1, Ct)
225 |         t_key = self.key(xt).contiguous().view(-1, Ct)
226 |         t_xw = torch.matmul(t_query, t_key.permute(1, 0))
227 |         t_xw = t_xw.view(Bt * HW, Bt, HW)
228 |         t_max = torch.max(t_xw, dim=-1)[0]  # [BHW, B]
229 |         t_avg = torch.mean(t_xw, dim=-1)  # [BHW, B]
230 |         t_co = t_max + t_avg  # [BHW, B]
231 |         t_co = t_co.sum(-1)  # [BHW]
232 |         t_co = t_co.view(Bt, -1) * self.scale
233 |         t_co = F.softmax(t_co, dim=-1)
234 |         t_co = t_co.view(Bt, int(np.sqrt(HW)), int(np.sqrt(HW))).unsqueeze(1)  # [B, 1, H, W]
235 |         # Global-local attention: CNN queries attend to transformer keys
236 |         ct_xw = torch.matmul(c_query, t_key.permute(1, 0))
237 |         ct_xw = ct_xw.view(B * H * W, B, H * W)
238 | 
239 |         ct_max = torch.max(ct_xw, dim=-1)[0]  # [BHW, B]
240 |         ct_avg = torch.mean(ct_xw, dim=-1)  # [BHW, B]
241 |         ct_co = ct_max + ct_avg  # [BHW, B]
242 |         ct_co = ct_co.sum(-1)  # [BHW]
243 |         ct_co = ct_co.view(B, -1) * self.scale
244 |         ct_co = F.softmax(ct_co, dim=-1)  # [B, HW]
245 |         ct_co = ct_co.view(B, H, W).unsqueeze(1)  # [B, 1, H, W]
246 |         # Local-global attention: transformer queries attend to CNN keys
247 |         tc_xw = torch.matmul(t_query, c_key)
248 |         tc_xw = tc_xw.view(Bt * HW, Bt, HW)
249 |         tc_max = torch.max(tc_xw, dim=-1)[0]  # [BHW, B]
250 |         tc_avg = torch.mean(tc_xw, dim=-1)  # [BHW, B]
251 |         tc_co = tc_max + tc_avg  # [BHW, B]
252 |         tc_co = tc_co.sum(-1)  # [BHW]
253 |         tc_co = tc_co.view(Bt, -1) * self.scale
254 |         tc_co = F.softmax(tc_co, dim=-1)
255 |         tc_co = tc_co.view(Bt, int(np.sqrt(HW)), int(np.sqrt(HW))).unsqueeze(1)  # [B, 1, H, W]
256 |         # Weight the value features with the co-attention maps
257 |         c_co = self.conv512_64(c_co * c_value)  # [B, 64, H, W]
258 |         t_co = self.conv512_64(t_co * t_value)  # [B, 64, H, W]
259 | 
260 |         ct_co = self.conv512_64(ct_co * c_value)  # [B, 64, H, W]
261 | 
262 |         tc_co = self.conv512_64(tc_co * t_value)  # [B, 64, H, W]
263 | 
264 |         c_ct_cat = self.conv1(torch.cat([c_co, ct_co], 1))
265 |         t_ct_cat = self.conv2(torch.cat([t_co, tc_co], 1))
266 | 
267 |         out_final = self.conv3(torch.cat([c_ct_cat, t_ct_cat], 1))
268 | 
269 |         return out_final
270 | 
271 | 
272 | 
273 | 
274 | class SpatialAttention(nn.Module):
275 |     def __init__(self, kernel_size=7):
276 |         super(SpatialAttention, self).__init__()
277 | 
278 |         assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
279 |         padding = 3 if kernel_size == 7 else 1
280 | 
281 |         self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
282 |         self.sigmoid = nn.Sigmoid()
283 | 
284 |     def forward(self, x):
285 |         avg_out = torch.mean(x, dim=1, keepdim=True)
286 |         max_out, _ = torch.max(x, dim=1, keepdim=True)
287 |         x = torch.cat([avg_out, max_out], dim=1)
288 |         x = self.conv1(x)
289 |         return self.sigmoid(x)
290 | 
291 | 
292 | class ChannelAttention(nn.Module):
293 |     def __init__(self, in_planes, ratio=16):
294 |         super(ChannelAttention, self).__init__()
295 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)
296 |         self.max_pool = nn.AdaptiveMaxPool2d(1)
297 | 
298 |         self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
299 |         self.relu1 = nn.ReLU()
300 |         self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
301 | 
302 |         self.sigmoid = nn.Sigmoid()
303 | 
304 |     def forward(self, x):
305 |         avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
306 |         max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
307 |         out = avg_out + max_out
308 |         return self.sigmoid(out)
309 | 
310 | 
311 | class CCM(nn.Module):
312 |     def __init__(self):
313 |         super(CCM, self).__init__()
314 | 
315 |         self.attention_feature0 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64),
316 |                                                  nn.PReLU(),
317 |                                                  nn.Conv2d(64, 2, kernel_size=3, padding=1))
318 |         self.conv2 = nn.Sequential(nn.Conv2d(64 * 2, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64),
319 |                                    nn.PReLU(),
320 |                                    nn.Conv2d(64, 64, kernel_size=3, padding=1))
321 |         self.ca = ChannelAttention(64)
322 |         self.sa = SpatialAttention()
323 |         self.sigmoid = nn.Sigmoid()
324 | 
325 |     def forward(self, x, y, z):
326 |         # x: CNN branch, y: Transformer branch, z: co-attention map
327 |         z = self.sigmoid(z)
328 | 
329 |         Gx = x * z
330 |         Gy = y * z
331 | 
332 |         G0 = self.attention_feature0(torch.cat((Gx, Gy), dim=1))
333 |         G0 = F.adaptive_avg_pool2d(torch.sigmoid(G0), 1)
334 |         c0_Gx = G0[:, 0, :, :].unsqueeze(1).repeat(1, 64, 1, 1) * Gx
335 |         c0_Gy = G0[:, 1, :, :].unsqueeze(1).repeat(1, 64, 1, 1) * Gy
336 | 
337 |         temp_y = c0_Gy.mul(self.ca(c0_Gy))
338 |         temp_x = c0_Gx.mul(self.sa(c0_Gx))
339 |         final = self.conv2(torch.cat([temp_y, temp_x], 1))
340 |         return final
341 | 
342 | 
343 | class GCPD(nn.Module):
344 |     def __init__(self, embed_dim=384, token_dim=64, img_size=224):
345 | 
346 |         super(GCPD, self).__init__()
347 |         self.img_size = img_size
348 | 
349 |         self.decoder0 = decoder_module(dim=embed_dim, token_dim=token_dim, img_size=img_size, ratio=8,
350 |                                        kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), fuse=False)
351 |         self.decoder1 = decoder_module(dim=embed_dim, token_dim=token_dim, img_size=img_size, ratio=8,
352 |                                        kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), fuse=True)
353 |         self.decoder2 = decoder_module(dim=embed_dim, token_dim=token_dim, img_size=img_size, ratio=4,
354 |                                        kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), fuse=True)
355 |         self.token_trans6 = token_Transformer(embed_dim=384, depth=8, num_heads=8, mlp_ratio=3.)
356 |         self.token_trans5 = token_Transformer(embed_dim=384, depth=6, num_heads=6, mlp_ratio=3.)
357 |         self.token_trans4 = token_Transformer(embed_dim=384, depth=4, num_heads=4, mlp_ratio=3.)
358 |         self.token_trans3 = token_Transformer(embed_dim=384, depth=2, num_heads=2, mlp_ratio=3.)
359 | 
360 |         self.fc = nn.Linear(64, 384)
361 |         self.fc_192_384 = nn.Linear(192, 384)
362 |         self.fc_448_384 = nn.Linear(448, 384)
363 |         self.upsample16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=False)
364 |         self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)
365 |         self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
366 |         self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
367 |         self.downsample2 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False)
368 |         self.downsample4 = nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=False)
369 | 
370 |         self.pre_1_16 = nn.Linear(384, 1)
371 |         self.pre_1_8 = nn.Linear(384, 1)
372 |         self.pre_1_4 = nn.Linear(384, 1)
373 | 
374 |         for m in self.modules():
375 |             classname = m.__class__.__name__
376 |             if classname.find('Conv') != -1:
377 |                 nn.init.xavier_uniform_(m.weight)
378 |                 if m.bias is not None:
379 |                     nn.init.constant_(m.bias, 0)
380 |             elif classname.find('Linear') != -1:
381 |                 nn.init.xavier_uniform_(m.weight)
382 |                 if m.bias is not None:
383 |                     nn.init.constant_(m.bias, 0)
384 |             elif classname.find('BatchNorm') != -1:
385 |                 nn.init.constant_(m.weight, 1)
386 |                 nn.init.constant_(m.bias, 0)
387 | 
388 |     def forward(self, x, y, z):
389 |         # x: [B, 64, 14, 14], y: [B, 64, 28, 28], z: [B, 64, 56, 56]
390 |         x5 = x
391 |         y4 = self.downsample2(y)
392 |         z3 = self.downsample4(z)
393 |         feat_t = torch.cat([x5, y4, z3], 1)  # [B, 192, 14, 14]
394 |         B, Ct, Ht, Wt = feat_t.shape
395 |         feat_t = feat_t.view(B, Ct, -1).transpose(1, 2)
396 |         feat_t = self.fc_192_384(feat_t)  # [B, 14*14, 384]
397 |         Tt = self.token_trans6(feat_t)
398 |         Tt = Tt + feat_t
399 |         mask_g = self.pre_1_16(Tt)
400 |         mask_g = mask_g.transpose(1, 2).reshape(B, 1, Ht, Wt)
401 | 
402 |         ##################### Level 5 ########################
403 |         B, Cx, Hx, Wx = x.shape
404 |         x_fea = x.view(B, Cx, -1).transpose(1, 2)  # [B, 14*14, Cx]
405 |         tx = torch.cat([Tt, x_fea], dim=2)
406 |         tx = self.fc_448_384(tx)
407 |         # x_fea = self.fc(x_fea)  # [B, 14*14, 384]
408 |         tx = self.token_trans5(tx)
409 |         tx = tx + self.fc(x_fea)
410 |         mask_x = self.pre_1_16(tx)
411 |         mask_x = mask_x.transpose(1, 2).reshape(B, 1, Hx, Wx)
412 | 
413 |         ##################### Level 4 ######################
414 |         B, Cy, Hy, Wy = y.shape
415 |         y_fea = y.view(B, Cy, -1).transpose(1, 2)  # [B, 28*28, Cy]
416 |         xy = self.decoder1(tx, y_fea)  # [B, 28*28, 64]
417 |         ty = self.fc(xy)  # [B, 28*28, 384]
418 |         ty = self.token_trans4(ty)
419 |         ty = ty + self.fc(y_fea)
420 |         mask_y = self.pre_1_8(ty)
421 |         mask_y = mask_y.transpose(1, 2).reshape(B, 1, Hy, Wy)
422 | 
423 |         #################### Level 3 ######################
424 |         B, Cz, Hz, Wz = z.shape
425 |         z_fea = z.view(B, Cz, -1).transpose(1, 2)  # [B, 56*56, Cz]
426 |         yz = self.decoder2(ty, z_fea)  # [B, 56*56, 64]
427 |         tz = self.fc(yz)
428 |         tz = self.token_trans3(tz)
429 |         tz = tz + self.fc(z_fea)
430 |         mask_z = self.pre_1_4(tz)
431 |         mask_z = mask_z.transpose(1, 2).reshape(B, 1, Hz, Wz)
432 |         return self.upsample4(mask_z), self.upsample8(mask_y), self.upsample16(mask_x), self.upsample16(mask_g)
433 | 
434 | 
435 | class ICNet(nn.Module):
436 |     def __init__(self, channel=64):
437 |         super(ICNet, self).__init__()
438 |         # Backbone
439 |         self.vgg = VGG16()
440 |         parser = argparse.ArgumentParser()  # note: reads --pretrained_model from sys.argv when ICNet is constructed
441 |         parser.add_argument('--pretrained_model', default='/hy-tmp/TCNet-main/80.7_T2T_ViT_t_14.pth.tar',
442 |                             type=str, help='load Pretrained model')
443 |         args = parser.parse_args()
444 |         self.rgb_backbone = T2t_vit_t_14(pretrained=True, args=args)
445 | 
446 | 
447 |         # Co-attention part
448 |         self.fc = nn.Linear(384, 512)
449 |         self.conv512_256 = nn.Conv2d(512, 256, 1)
450 |         self.conv512_128 = nn.Conv2d(512, 128, 1)
451 |         self.conv512_64 = nn.Conv2d(512, 64, 1)
452 |         self.conv256_64 = nn.Conv2d(256, 64, 1)
453 |         self.conv128_64 = nn.Conv2d(128, 64, 1)
454 |         self.conv64_1 = nn.Conv2d(64, 1, 1)
455 |         self.sa = MCM(512)
456 |         self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
457 |         self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
458 |         self.sig = nn.Sigmoid()
459 | 
460 |         # CCM
461 |         self.ccm5 = CCM()
462 |         self.ccm4 = CCM()
463 |         self.ccm3 = CCM()
464 |         # GCPD
465 |         self.decoder = GCPD(embed_dim=384, token_dim=64, img_size=224)
466 | 
467 |     def forward(self, image_group, is_training):
468 |         rgb_fea_1_16, rgb_fea_1_8, rgb_fea_1_4 = self.rgb_backbone(image_group)
469 |         rgb_fea_1_16 = self.fc(rgb_fea_1_16)
470 |         Bt, HW, Ct = rgb_fea_1_16.size()
471 |         t_value = rgb_fea_1_16.transpose(1, 2).reshape(Bt, Ct, int(np.sqrt(HW)), int(np.sqrt(HW)))
472 | 
473 |         # Extract features from the VGG16 backbone.
474 |         conv1_2 = self.vgg(image_group, 'conv1_1', 'conv1_2_mp')  # shape=[N, 64, 224, 224]
475 |         conv2_2 = self.vgg(conv1_2, 'conv1_2_mp', 'conv2_2_mp')  # shape=[N, 128, 112, 112]
476 |         conv3_3 = self.vgg(conv2_2, 'conv2_2_mp', 'conv3_3_mp')  # shape=[N, 256, 56, 56]
477 |         conv4_3 = self.vgg(conv3_3, 'conv3_3_mp', 'conv4_3_mp')  # shape=[N, 512, 28, 28]
478 |         conv5_3 = self.vgg(conv4_3, 'conv4_3_mp', 'conv5_3_mp')  # shape=[N, 512, 14, 14]
479 |         sa = self.sa(conv5_3, rgb_fea_1_16)  # shape=[N, 64, 14, 14]
480 | 
481 |         x5 = self.ccm5(self.conv512_64(conv5_3), self.conv512_64(t_value), sa)
482 |         x4 = self.ccm4(self.conv512_64(conv4_3), self.conv512_64(self.upsample2(t_value)), self.upsample2(sa))
483 |         x3 = self.ccm3(self.conv256_64(conv3_3), self.conv512_64(self.upsample4(t_value)), self.upsample4(sa))
484 | 
485 |         S_3_pred, S_4_pred, S_5_pred, S_g_pred = self.decoder(x5, x4, x3)
486 | 
487 |         # Return predicted co-saliency maps.
488 |         if is_training:
489 |             return S_3_pred, S_4_pred, S_5_pred, S_g_pred
490 |         else:
491 |             preds = S_3_pred
492 |             return preds
493 | 
--------------------------------------------------------------------------------
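A quick way to sanity-check the tensor shapes flowing through the two TCNet-specific fusion modules above (MCM and CCM) is to drive them with random inputs. The snippet below is a minimal smoke test, not part of the repository; it assumes the repo layout listed above (so network.py and its local imports t2t_vit.py, transformer_block.py and token_performer.py are importable) and that timm is installed. The group size N = 2 and the 14x14 spatial size are illustrative and match a 224x224 input to ICNet.

import torch
from network import MCM, CCM

if __name__ == '__main__':
    N = 2                                # number of images in the co-saliency group
    xc = torch.randn(N, 512, 14, 14)     # CNN feature (VGG16 conv5_3)
    xt = torch.randn(N, 14 * 14, 512)    # transformer tokens after the 384->512 projection

    # MCM: group co-attention between the CNN and transformer branches.
    sa = MCM(512)(xc, xt)
    print(sa.shape)                      # torch.Size([2, 64, 14, 14])

    # CCM: fuse the two branches under the shared co-attention map.
    x = torch.randn(N, 64, 14, 14)       # CNN branch after the 512->64 1x1 conv
    y = torch.randn(N, 64, 14, 14)       # transformer branch after the 512->64 1x1 conv
    out = CCM()(x, y, sa)
    print(out.shape)                     # torch.Size([2, 64, 14, 14])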