├── Patch_D.png
├── spatial-net.jpeg
├── MSRF_CLASSIFICATION.jpeg
├── README.md
├── dataloader.py
├── train.py
├── msrfc.py
└── patchnet.py

/Patch_D.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NoviceMAn-prog/SA-Net-MSRF-CNet-and-PatchNet-for-Writer-Identification/HEAD/Patch_D.png
--------------------------------------------------------------------------------
/spatial-net.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NoviceMAn-prog/SA-Net-MSRF-CNet-and-PatchNet-for-Writer-Identification/HEAD/spatial-net.jpeg
--------------------------------------------------------------------------------
/MSRF_CLASSIFICATION.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NoviceMAn-prog/SA-Net-MSRF-CNet-and-PatchNet-for-Writer-Identification/HEAD/MSRF_CLASSIFICATION.jpeg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Exploiting Multi-Scale Fusion, Spatial Attention and Patch Interaction Techniques for Text-Independent Writer Identification
This repository provides the code for our paper "Exploiting Multi-Scale Fusion, Spatial Attention and Patch Interaction Techniques for Text-Independent Writer Identification", accepted at the Asian Conference on Pattern Recognition 2021 ([arXiv version](https://arxiv.org/abs/2111.10605)).
## 2.) Overview
### 2.1.) Introduction
Text-independent writer identification is a challenging problem: it must distinguish between different handwriting styles to decide the author of a handwritten text.
Earlier work on writer identification relied on handcrafted features to reveal differences between writers. More recently, with the advent of convolutional neural networks,
deep learning-based methods have evolved. In this paper, three deep learning techniques - a spatial attention mechanism, multi-scale feature fusion and a patch-based CNN -
are proposed to effectively capture the differences between each writer's handwriting. Our methods are based on the hypothesis that handwritten text images have specific spatial
regions that are more characteristic of a writer's style, that multi-scale features propagate characteristic features of individual writers, and that patch-based features give more
general and robust representations that help to discriminate handwriting from different writers. The proposed methods outperform various state-of-the-art methods at word- and
page-level writer identification on the CVL, Firemaker and CERUG-EN datasets and give comparable performance on the IAM dataset.

## 2.2.) Spatial Attention Unit in SA-Net
![](spatial-net.jpeg)

## 2.3.) MSRF-Classification Network Architecture
![](MSRF_CLASSIFICATION.jpeg)

## 2.4.) PatchNet Architecture
![](Patch_D.png)

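The three architectures above are implemented in `msrfc.py` (MSRF-CNet as `MSF`, SA-Net as `VGGnet_spatial`) and `patchnet.py` (PatchNet as `dfrag`). As a quick orientation, the minimal sketch below instantiates each model on a dummy grayscale word image of the 64x128 size used in `train.py`; the writer count of 105 is an arbitrary placeholder, and it assumes the repo's dependencies (PyTorch, einops, timm) are installed. Note that both modules also run a small shape test at import time, so expect a couple of printed tensor shapes.

```python
import torch
from msrfc import MSF, VGGnet_spatial
from patchnet import dfrag

x = torch.randn(2, 1, 64, 128)  # batch of grayscale word images, as produced by dataloader.py

# MSRF-CNet: returns per-scale logits plus their sum
logits_list, combined = MSF(1, num_classes=105)(x)

# SA-Net: a VGG backbone with a spatial attention unit after every stage
sa_logits = VGGnet_spatial(1, num_classes=105)(x)

# PatchNet: a global stream plus five overlapping patch streams with exchange blocks
patch_logits_list, patch_combined = dfrag(1, num_classes=105)(x)
```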
## 3.) Training and Testing
## 3.1) Data Preparation
The code for downloading the CERUG-EN and Firemaker datasets and preparing them for training and testing is embedded in `train.py`. The training and testing splits for the IAM dataset are provided in IAM-train.txt and IAM-test.txt.

## 3.2) Training
The MSRF-CNet and SA-Net architectures are defined in `msrfc.py`, and the PatchNet architecture is in `patchnet.py`. To train, change the dataset as required in `train.py`; the testing code is also in `train.py`. Run the script as:
`python train.py`
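For example, to train on a different dataset, the relevant switches live in `train.py` (this mirrors the script's own `__main__` block; it assumes `pthflops`, which `train.py` imports, is installed):

```python
from train import DeepWriter_Train

# Pick the dataset here; CERUG-EN and Firemaker are downloaded automatically,
# while IAM and CVL must be obtained separately (see the warning printed by train.py).
mod = DeepWriter_Train(dataset='CERUG-EN', mode='vertical')

# Inside DeepWriter_Train.__init__, swap self.model between MSF (MSRF-CNet),
# VGGnet_spatial (SA-Net) and dfrag (PatchNet), as noted in the comments there.
mod.train_loops(0, 50)   # train epochs 0..49, checkpointing and testing each epoch
```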
## 4.) Citation
Please cite our paper if you find the work useful:

```
@article{srivastava2021exploiting,
  title={Exploiting Multi-Scale Fusion, Spatial Attention and Patch Interaction Techniques for Text-Independent Writer Identification},
  author={Srivastava, Abhishek and Chanda, Sukalpa and Pal, Umapada},
  journal={arXiv preprint arXiv:2111.10605},
  year={2021}
}
```
## 5.) FAQ
Please feel free to contact me if you need any advice or guidance in using this work ([E-mail](abhisheksrivastava2397@gmail.com)).
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
import os
import pickle
import random

import numpy as np
from PIL import Image
import torch
import torch.utils.data as data
from torchvision.transforms import Compose, ToTensor

class DatasetFromFolder(data.Dataset):
    def __init__(self,dataset,foldername,labelfolder,imgtype='png',scale_size=(64,128),
                 is_training=True):
        super(DatasetFromFolder,self).__init__()

        self.is_training = is_training
        self.imgtype = imgtype
        self.scale_size = scale_size
        self.folder = foldername
        self.dataset = dataset

        # CERUG-EN filenames use '_' as the writer-id separator, the others use '-'
        self.cerug = (self.dataset == 'CERUG-EN')

        self.labelidx_name = labelfolder + dataset + 'writer_index_table.pickle'
        print(self.labelidx_name)

        self.imglist = self._get_image_list(self.folder)
        self.idlist = self._get_all_identity()
        self.idx_tab = self._convert_identity2index(self.labelidx_name)
        self.num_writer = len(self.idx_tab)

        print('-'*10)
        print('loading dataset %s with images: %d'%(dataset,len(self.imglist)))
        print('number of writers is: %d'%len(self.idx_tab))
        print('-'*10)

    # map each writer identity to a class index for the network
    def _convert_identity2index(self,savename):
        if os.path.exists(savename):
            with open(savename,'rb') as fp:
                identity_idx = pickle.load(fp)
        else:
            identity_idx = {}
            for idx,ids in enumerate(self.idlist):
                identity_idx[ids] = idx
            with open(savename,'wb') as fp:
                pickle.dump(identity_idx,fp)

        return identity_idx

    # collect all writer identities found in the image list
    def _get_all_identity(self):
        writer_list = []
        for img in self.imglist:
            writerId = self._get_identity(img)
            writer_list.append(writerId)
        writer_list = list(set(writer_list))
        return writer_list

    # e.g. IAM filename 027-a02-046-05-04.png -> writer id '027'
    def _get_identity(self,fname):
        if self.cerug:
            return fname.split('_')[0]
        else:
            return fname.split('-')[0]

    # list all images of the requested type in the folder
    def _get_image_list(self,folder):
        flist = os.listdir(folder)
        imglist = []
        for img in flist:
            if img.endswith(self.imgtype):
                imglist.append(img)
        return imglist

    def transform(self):
        return Compose([ToTensor(),])

    def resize(self,image):
        # scale the image to fit inside scale_size while keeping its aspect ratio
        h,w = image.shape[:2]
        ratio_h = float(self.scale_size[0])/float(h)
        ratio_w = float(self.scale_size[1])/float(w)

        if ratio_h < ratio_w:
            ratio = ratio_h
            hfirst = False
        else:
            ratio = ratio_w
            hfirst = True

        nh = int(ratio * h)
        nw = int(ratio * w)

        # bilinear resize via PIL (PIL's resize takes (width, height))
        imre = np.array(Image.fromarray(image).resize((nw,nh),Image.BILINEAR))

        # invert: ink becomes high-valued on a zero background
        imre = 255 - imre
        ch,cw = imre.shape[:2]
        new_img = np.zeros(self.scale_size)
        if self.is_training:
            # random placement inside the canvas as a simple augmentation
            dy = random.randint(0,int(self.scale_size[0]-ch))
            dx = random.randint(0,int(self.scale_size[1]-cw))
        else:
            # centred placement for deterministic evaluation
            dy = int((self.scale_size[0]-ch)/2.0)
            dx = int((self.scale_size[1]-cw)/2.0)

        imre = imre.astype('float')
        new_img[dy:dy+ch,dx:dx+cw] = imre

        return new_img,hfirst

    def __getitem__(self,index):
        imgfile = self.imglist[index]
        writer = self.idx_tab[self._get_identity(imgfile)]

        # read as 8-bit grayscale
        image = np.array(Image.open(self.folder + imgfile).convert('L'))
        image,hfirst = self.resize(image)
        image = image / 255.0

        image = self.transform()(image)
        writer = torch.from_numpy(np.array(writer))

        return image,writer,imgfile

    def __len__(self):
        return len(self.imglist)
--------------------------------------------------------------------------------
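A minimal sketch of using `DatasetFromFolder` directly (the paths mirror how `train.py` constructs them; the on-disk layout `CERUG-EN/train/` with per-word images is an assumption taken from that script):

```python
from torch.utils.data import DataLoader
import dataloader as dset

# Expects e.g. CERUG-EN/train/ with word images named <writerId>_<...>.png;
# the writer-index pickle is created on first use if it does not exist yet.
train_set = dset.DatasetFromFolder(dataset='CERUG-EN',
                                   foldername='CERUG-EN/train/',
                                   labelfolder='CERUG-EN',
                                   imgtype='png',
                                   scale_size=(64, 128),
                                   is_training=True)

loader = DataLoader(train_set, batch_size=16, shuffle=True)
image, writer, fname = next(iter(loader))
print(image.shape)   # torch.Size([16, 1, 64, 128]): inverted grayscale on a padded canvas
```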
/train.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler

import dataloader as dset
import numpy as np
import os
from msrfc import *
from msrfc import VGGnet_spatial # SA-Net
from patchnet import * # PatchNet
from pthflops import count_ops

class LabelSomCE(nn.Module):
    # Label-smoothing cross entropy:
    # loss = (1 - smoothing) * NLL(target) + smoothing * mean NLL over all classes
    def __init__(self):
        super().__init__()

    def forward(self,x,target,smoothing=0.1):
        confidence = 1.0 - smoothing
        logprobs = F.log_softmax(x,dim=-1)
        nll_loss = -logprobs.gather(dim=-1,index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = confidence * nll_loss + smoothing * smooth_loss

        return loss.mean()

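# Illustrative check of the smoothing formula (not part of the original script):
# with logits [[2.0, 0.0, 0.0]], target [0] and smoothing=0.1, log-softmax gives
# roughly [-0.24, -2.24, -2.24]; the target term is 0.24, the uniform term is
# (0.24 + 2.24 + 2.24) / 3 ~ 1.57, so the loss is 0.9*0.24 + 0.1*1.57 ~ 0.37:
# >>> LabelSomCE()(torch.tensor([[2.0, 0.0, 0.0]]), torch.tensor([0]))
# tensor(0.3729)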
def download(folder,thetarfile):
    import urllib.request
    import tarfile
    ftpstream = urllib.request.urlopen(thetarfile)
    thetarfile = tarfile.open(fileobj=ftpstream, mode="r|gz")
    thetarfile.extractall(folder)
    thetarfile.close()

def download_cerug(folder):
    thetarfile = "https://www.ai.rug.nl/~sheng/writerset/CERUG-EN-train-images.tar.gz"
    download(folder,thetarfile)
    thetarfile = "https://www.ai.rug.nl/~sheng/writerset/CERUG-EN-test-images.tar.gz"
    download(folder,thetarfile)

def download_firemaker(folder):
    thetarfile = "https://www.ai.rug.nl/~sheng/writerset/Firemaker-train-images.tar.gz"
    download(folder,thetarfile)
    thetarfile = "https://www.ai.rug.nl/~sheng/writerset/Firemaker-test-images.tar.gz"
    download(folder,thetarfile)

class DeepWriter_Train:
    def __init__(self,dataset='Firemaker',imgtype='png',mode='vertical'):

        self.dataset = dataset
        self.folder = dataset
        if self.dataset == 'IAM':
            self.folder = 'data'
        if not os.path.exists(self.folder):
            if dataset == 'CERUG-EN':
                download_cerug(dataset)
            elif dataset == 'Firemaker':
                download_firemaker(dataset)
            else:
                print('****** Warning: the dataset %s does not exist! ******'%dataset)
                print('Please go to the following website to check how to download the dataset:')
                print('https://www.ai.rug.nl/~sheng/writeridataset.html')
                print('*'*20)
                raise ValueError('Dataset: %s does not exist!'%dataset)

        self.labelfolder = self.folder
        self.train_folder = self.folder+'/train/'
        self.test_folder = self.folder+'/test/'

        self.imgtype = imgtype
        self.mode = mode
        self.device = 'cuda' # change to 'cpu' if no GPU is available
        self.scale_size = (64,128)

        if self.device == 'cuda':
            torch.backends.cudnn.benchmark = True

        if self.dataset == 'CVL':
            self.imgtype = 'tif'

        self.model_dir = 'model'
        if not os.path.exists(self.model_dir):
            os.mkdir(self.model_dir)

        basedir = 'MSRF_firemaker_'+self.dataset+'_model_'+self.mode+'_aug_16'
        self.logfile = basedir + '.log'
        self.modelfile = basedir
        self.batch_size = 16

        train_set = dset.DatasetFromFolder(dataset=self.dataset,
                        labelfolder=self.labelfolder,
                        foldername=self.train_folder,
                        imgtype=self.imgtype,
                        scale_size=self.scale_size,
                        is_training=True)

        self.training_data_loader = DataLoader(dataset=train_set, num_workers=0,
                        batch_size=self.batch_size, shuffle=True)

        test_set = dset.DatasetFromFolder(dataset=self.dataset,
                        labelfolder=self.labelfolder,
                        foldername=self.test_folder, imgtype=self.imgtype,
                        scale_size=self.scale_size,
                        is_training=False)

        self.testing_data_loader = DataLoader(dataset=test_set, num_workers=0,
                        batch_size=self.batch_size, shuffle=False)

        num_class = train_set.num_writer
        # Pick the architecture to train:
        #self.model = dfrag(1,num_classes=train_set.num_writer).to(self.device)   # PatchNet
        #self.model = VGGnet_spatial(1,train_set.num_writer).to(self.device)      # SA-Net
        self.model = MSF(1,train_set.num_writer).to(self.device)                  # MSRF-CNet
        pytorch_total_params = sum(p.numel() for p in self.model.parameters())
        print('Number of parameters is ',pytorch_total_params)
        self.criterion = LabelSomCE()
        self.optimizer = optim.Adam(self.model.parameters(),lr=0.0001,weight_decay=1e-4)
        self.scheduler = lr_scheduler.StepLR(self.optimizer,step_size=10,gamma=0.5)
        # accumulator for page-level voting: one row of summed logits per writer
        self.page = np.zeros((train_set.num_writer,train_set.num_writer))
        #print('CALCULATING FLOPS')
        #inp = torch.randn(1, 1, 64,128)
        #print(count_ops(self.model, inp))
        #print('#'*10,'DONE','#'*10)

    def train(self,epoch):
        self.model.train()
        losstotal = []

        for iteration,batch in enumerate(self.training_data_loader,1):
            inputs = batch[0].to(self.device).float()
            target = batch[1].type(torch.long).to(self.device)

            # For SA-Net, which returns a single logits tensor, use instead:
            #self.optimizer.zero_grad()
            #train_loss = self.criterion(self.model(inputs),target)

            self.optimizer.zero_grad()
            train_loss = 0
            logits_list,combined_logits = self.model(inputs)
            for logs in logits_list:
                train_loss = train_loss + self.criterion(logs,target)

            losstotal.append(train_loss.item())
            train_loss.backward()
            self.optimizer.step()

        with open(self.logfile,'a') as fp:
            fp.write('Training epoch %d avg loss is: %.6f\n'%(epoch,np.mean(losstotal)))
        print('Training epoch:',epoch,' avg loss is:',np.mean(losstotal))

    def test(self,epoch,during_train=True):
        self.model.eval()

        if not during_train:
            self.load_model(epoch)

        top1 = 0
        top5 = 0
        ntotal = 0
        dummy_writer = -1
        for iteration,batch in enumerate(self.testing_data_loader,1):
            inputs = batch[0].to(self.device).float()
            target = batch[1].to(self.device).long()
            logits_list,logits = self.model(inputs)
            #logits = self.model(inputs) # use instead when testing SA-Net

            # accumulate word-level logits per writer for the page-level decision;
            # this relies on the unshuffled test loader delivering all images of
            # one writer consecutively
            for n in range(logits.shape[0]):
                with torch.no_grad():
                    if dummy_writer == -1:
                        dummy_writer = target[0]
                    if target[n] != dummy_writer:
                        dummy_writer = target[n]
                    dummy_pred = logits[n].cpu().numpy()
                    self.page[dummy_writer] += dummy_pred

            res = self.accuracy(logits,target,topk=(1,5))
            top1 += res[0]
            top5 += res[1]

            ntotal += inputs.size(0)

        # a page is correct when the summed logits of its writer peak at that writer
        page_acc = 0
        for i in range(self.page.shape[0]):
            writer_page = np.argmax(self.page[i])
            if writer_page == i:
                page_acc += 1

        top1 /= float(ntotal)
        top5 /= float(ntotal)
        page_acc /= float(self.page.shape[0])
        print('Testing on epoch: %d has accuracy: top1: %.2f top5: %.2f'%(epoch,top1*100,top5*100))
        print('Testing pages on epoch: %d has page accuracy: top1: %.2f'%(epoch,page_acc*100))
        with open(self.logfile,'a') as fp:
            fp.write('Testing epoch %d accuracy is: top1: %.2f top5: %.2f\n'%(epoch,top1*100,top5*100))
            fp.write('Testing pages on epoch: %d has page accuracy: top1: %.2f\n'%(epoch,page_acc*100))

    def check_exists(self,epoch):
        model_out_path = self.model_dir + '/' + self.modelfile + '-model_epoch_{}.pth'.format(epoch)
        return os.path.exists(model_out_path)

    def checkpoint(self,epoch):
        model_out_path = self.model_dir + '/' + self.modelfile + '-model_epoch_{}.pth'.format(epoch)
        torch.save(self.model.state_dict(),model_out_path)

    def load_model(self,epoch):
        model_out_path = self.model_dir + '/' + self.modelfile + '-model_epoch_{}.pth'.format(epoch)
        print(model_out_path)
        self.model.load_state_dict(torch.load(model_out_path))
        print('Load model successful')

    def train_loops(self,start_epoch,num_epoch):
        if start_epoch > 0:
            self.load_model(start_epoch-1)

        for epoch in range(start_epoch,num_epoch):
            self.train(epoch)
            self.checkpoint(epoch)
            self.test(epoch)
            self.scheduler.step()

    def accuracy(self,output,target,topk=(1,)):
        with torch.no_grad():
            maxk = max(topk)
            _,pred = output.topk(maxk,1,True,True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = []
            for k in topk:
                # reshape (not view): correct[:k] is non-contiguous for k > 1
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res.append(correct_k.data.cpu().numpy())

            return res

if __name__ == '__main__':

    modelist = ['vertical','horizontal']
    mode = modelist[0]
    print('Results on IAM by MSRF-CNet')
    mod = DeepWriter_Train(dataset='IAM',mode=mode) # change the dataset to Firemaker or CERUG-EN
    mod.train_loops(0,50)
--------------------------------------------------------------------------------
/msrfc.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

def cast_tuple(val, num):
    return val if isinstance(val, tuple) else (val,) * num

def conv_output_size(image_size, kernel_size, stride, padding = 0):
    return int(((image_size[0] - kernel_size + (2 * padding)) / stride) + 1),int(((image_size[1] - kernel_size + (2 * padding)) / stride) + 1)

class VGGnet(nn.Module):
    # four-stage VGG-style backbone; forward returns the input plus the feature
    # map of every stage so that MSF can fuse multiple scales

    def __init__(self, input_channel):
        super().__init__()

        layers = [64,128,256,512]

        self.conv1 = self._conv(input_channel,layers[0])
        self.maxp1 = nn.MaxPool2d(2,stride=2)
        self.conv2 = self._conv(layers[0],layers[1])
        self.maxp2 = nn.MaxPool2d(2,stride=2)
        self.conv3 = self._conv(layers[1],layers[2])
        self.maxp3 = nn.MaxPool2d(2,stride=2)
        self.conv4 = self._conv(layers[2],layers[3])
        self.maxp4 = nn.MaxPool2d(2,stride=2)

    def _conv(self,inplance,outplance,nlayers=2):
        conv = []
        for n in range(nlayers):
            conv.append(nn.Conv2d(inplance,outplance,kernel_size=3,
                    stride=1,padding=1,bias=False))
            conv.append(nn.BatchNorm2d(outplance))
            conv.append(nn.ReLU(inplace=True))
            inplance = outplance

        conv = nn.Sequential(*conv)

        return conv

    def forward(self, x):
        xlist = [x]
        x = self.conv1(x)
        xlist.append(x)
        x = self.maxp1(x)
        x = self.conv2(x)
        xlist.append(x)
        x = self.maxp2(x)
x = self.conv3(x) 56 | xlist.append(x) 57 | x = self.maxp3(x) 58 | x = self.conv4(x) 59 | xlist.append(x) 60 | return xlist 61 | class VGGnet(nn.Module): 62 | 63 | def __init__(self, input_channel): 64 | super().__init__() 65 | 66 | layers=[64,128,256,512] 67 | 68 | self.conv1 = self._conv(input_channel,layers[0]) 69 | self.maxp1 = nn.MaxPool2d(2,stride=2) 70 | self.conv2 = self._conv(layers[0],layers[1]) 71 | self.maxp2 = nn.MaxPool2d(2,stride=2) 72 | self.conv3 = self._conv(layers[1],layers[2]) 73 | self.maxp3 = nn.MaxPool2d(2,stride=2) 74 | self.conv4 = self._conv(layers[2],layers[3]) 75 | self.maxp4 = nn.MaxPool2d(2,stride=2) 76 | 77 | 78 | def _conv(self,inplance,outplance,nlayers=2): 79 | conv = [] 80 | for n in range(nlayers): 81 | conv.append(nn.Conv2d(inplance,outplance,kernel_size=3, 82 | stride=1,padding=1,bias=False)) 83 | conv.append(nn.BatchNorm2d(outplance)) 84 | conv.append(nn.ReLU(inplace=True)) 85 | inplance = outplance 86 | 87 | conv = nn.Sequential(*conv) 88 | 89 | return conv 90 | 91 | def forward(self, x): 92 | xlist=[x] 93 | x = self.conv1(x) 94 | xlist.append(x) 95 | x = self.maxp1(x) 96 | x = self.conv2(x) 97 | xlist.append(x) 98 | x = self.maxp2(x) 99 | x = self.conv3(x) 100 | xlist.append(x) 101 | x = self.maxp3(x) 102 | x = self.conv4(x) 103 | xlist.append(x) 104 | return xlist 105 | class exchange(nn.Module): 106 | def __init__(self,scale1,scale2,k_1,k_2): 107 | super().__init__() 108 | self.layers1 = [scale1,scale1+k_1+k_1,scale1+2*k_1+k_1,scale1+3*k_1+k_1,scale1+4*k_1+k_1] 109 | self.layers2 = [scale2,scale2+k_2+k_1,scale2+2*k_2+k_1,scale2+3*k_2+k_1,scale2+4*k_2+k_1] 110 | self.x1 = nn.Sequential(nn.Conv2d(scale1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 111 | self.y1 = nn.Sequential(nn.Conv2d(scale2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 112 | self.x1c = nn.Sequential(nn.Conv2d(scale1,k_1,kernel_size=4,stride=2,padding=1,bias=False),nn.ReLU(inplace=True)) 113 | self.y1t = nn.Sequential(nn.ConvTranspose2d(scale2,k_1,kernel_size=4,stride=2,padding=1,bias=False),nn.ReLU(inplace=True)) 114 | 115 | self.x2_input = nn.Sequential(nn.Conv2d(self.layers1[1],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 116 | self.y2_input = nn.Sequential(nn.Conv2d(self.layers2[1],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 117 | 118 | 119 | self.x2 = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 120 | self.y2 = nn.Sequential(nn.Conv2d(k_2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 121 | self.x2c = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=4,stride=2,padding=1,bias=False),nn.ReLU(inplace=True)) 122 | self.y2t = nn.Sequential(nn.ConvTranspose2d(k_2,k_1,kernel_size=4,stride=2,padding=1,bias=False),nn.ReLU(inplace=True)) 123 | 124 | self.x3_input = nn.Sequential(nn.Conv2d(self.layers1[2],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 125 | self.y3_input = nn.Sequential(nn.Conv2d(self.layers2[2],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 126 | 127 | 128 | self.x3 = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 129 | self.y3 = nn.Sequential(nn.Conv2d(k_2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 130 | self.x3c = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=4,stride=2,padding=1,bias=False),nn.ReLU(inplace=True)) 131 | self.y3t = 
nn.Sequential(nn.ConvTranspose2d(k_2,k_1,kernel_size=4,stride=2,padding=1,bias=False),nn.ReLU(inplace=True)) 132 | 133 | self.x4_input = nn.Sequential(nn.Conv2d(self.layers1[3],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 134 | self.y4_input = nn.Sequential(nn.Conv2d(self.layers2[3],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 135 | 136 | self.x4 = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 137 | self.y4 = nn.Sequential(nn.Conv2d(k_2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 138 | self.x4c = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=4,stride=2,padding=1,bias=False),nn.ReLU(inplace=True)) 139 | self.y4t = nn.Sequential(nn.ConvTranspose2d(k_2,k_1,kernel_size=4,stride=2,padding=1,bias=False),nn.ReLU(inplace=True)) 140 | 141 | self.x5_input = nn.Sequential(nn.Conv2d(self.layers1[4],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 142 | self.y5_input = nn.Sequential(nn.Conv2d(self.layers2[4],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 143 | 144 | self.x5 = nn.Sequential(nn.Conv2d(self.layers1[4],scale1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 145 | self.y5 = nn.Sequential(nn.Conv2d(self.layers2[4],scale2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 146 | 147 | 148 | 149 | 150 | def _forward(self,ten1,ten2): 151 | x1 = self.x1(ten1) 152 | y1 = self.y1(ten2) 153 | x1c = self.x1c(ten1) 154 | y1t = self.y1t(ten2) 155 | 156 | x2_input = torch.cat([ten1,x1,y1t],1) 157 | y2_input = torch.cat([ten2,y1,x1c],1) 158 | x2_input = self.x2_input(x2_input) 159 | y2_input = self.y2_input(y2_input) 160 | 161 | x2 = self.x2(x2_input) 162 | y2 = self.y2(y2_input) 163 | x2c = self.x2c(x2_input) 164 | y2t = self.y2t(y2_input) 165 | 166 | x3_input = torch.cat([ten1,x1,x2,y2t],1) 167 | y3_input = torch.cat([ten2,y1,y2,x2c],1) 168 | x3_input = self.x3_input(x3_input) 169 | y3_input = self.y3_input(y3_input) 170 | 171 | x3 = self.x3(x3_input) 172 | y3 = self.y3(y3_input) 173 | x3c = self.x3c(x3_input) 174 | y3t = self.y3t(y3_input) 175 | 176 | x4_input = torch.cat([ten1,x1,x2,x3,y3t],1) 177 | y4_input = torch.cat([ten2,y1,y2,y3,x3c],1) 178 | x4_input = self.x4_input(x4_input) 179 | y4_input = self.y4_input(y4_input) 180 | 181 | x4 = self.x4(x4_input) 182 | y4 = self.y4(y4_input) 183 | x4c = self.x4c(x4_input) 184 | y4t = self.y4t(y4_input) 185 | 186 | x5_input = torch.cat([ten1,x1,x2,x3,x4,y4t],1) 187 | y5_input = torch.cat([ten2,y1,y2,y3,y4,x4c],1) 188 | #x5_input = self.x5_input(x5_input) 189 | #y5_input = self.y5_input(y5_input) 190 | x5 = self.x5(x5_input) 191 | y5 = self.y5(y5_input) 192 | 193 | return 0.4*x5+ten1, 0.4*y5+ten2 194 | 195 | 196 | 197 | 198 | 199 | def forward(self,ten): 200 | return self._forward(ten[0],ten[1]) 201 | 202 | 203 | class MSF(nn.Module): 204 | def __init__(self,inplace,num_classes): 205 | super().__init__() 206 | self.net = VGGnet(inplace) 207 | self.avg = nn.AdaptiveAvgPool2d(1) 208 | self.classifier = nn.Linear(512,num_classes) 209 | #self.exchange12_one = exchange(64,3216,32) 210 | #self.exchange34_one = exchange(128,256,32,64) 211 | 212 | #self.exchange23_one = exchange(64,128,32,32) 213 | 214 | #self.exchange12_two = exchange(32,64,16,32) 215 | #self.exchange34_two = exchange(128,256,32,64) 216 | self.exchange12_one = exchange(64,128,16,32) 217 | self.att1_one = SpatialAttention(64,64) 218 | self.att2_one = SpatialAttention(128,128) 219 | 
self.att3_one = SpatialAttention(256,256) 220 | self.att4_one = SpatialAttention(512,512) 221 | self.exchange34_one = exchange(256,512,32,64) 222 | self.exchange23_one = exchange(128,256,32,32) 223 | 224 | self.exchange12_two = exchange(64,128,16,32) 225 | self.exchange34_two = exchange(256,512,32,128) 226 | 227 | def forward(self,x): 228 | xlist = self.net(x) 229 | n11,n12,n13,n14 = xlist[1],xlist[2],xlist[3],xlist[4] 230 | #print(n11.shape,n12.shape,n13.shape,n14.shape) 231 | n11 = self.att1_one(n11) 232 | n12 = self.att2_one(n12) 233 | n13 = self.att3_one(n13) 234 | n14 = self.att4_one(n14) 235 | n21,n22 = self.exchange12_one((n11,n12)) 236 | n23,n24 = self.exchange34_one((n13,n14)) 237 | n22,n23 = self.exchange23_one((n22,n23)) 238 | 239 | n31,n32 = self.exchange12_two((n21,n22)) 240 | n33,n34 = self.exchange34_two((n23,n24)) 241 | li = [n14,n24,n34] 242 | logits_list = [] 243 | for l in li: 244 | r = torch.flatten(self.avg(l),1) 245 | c = self.classifier(r) 246 | logits_list.append(c) 247 | combined_logits = 0 248 | for r in logits_list: 249 | combined_logits += r 250 | return logits_list,combined_logits 251 | 252 | class SE_Block(nn.Module): 253 | "credits: https://github.com/moskomule/senet.pytorch/blob/master/senet/se_module.py#L4" 254 | def __init__(self, c, r=16): 255 | super().__init__() 256 | self.squeeze = nn.AdaptiveAvgPool2d(1) 257 | self.excitation = nn.Sequential( 258 | nn.Linear(c, c // r, bias=False), 259 | nn.ReLU(inplace=True), 260 | nn.Linear(c // r, c, bias=False), 261 | nn.Sigmoid() 262 | ) 263 | 264 | def forward(self, x): 265 | bs, c, _, _ = x.shape 266 | y = self.squeeze(x).view(bs, c) 267 | y = self.excitation(y).view(bs, c, 1, 1) 268 | return x * y.expand_as(x) 269 | 270 | def _conv_spa(inplance,outplance,nlayers=2): 271 | conv = [] 272 | for n in range(nlayers): 273 | conv.append(nn.Conv2d(inplance,outplance,kernel_size=3, 274 | stride=1,padding=1,bias=False)) 275 | conv.append(nn.BatchNorm2d(outplance)) 276 | conv.append(nn.ReLU(inplace=True)) 277 | inplance = outplance 278 | 279 | conv = nn.Sequential(*conv) 280 | 281 | return conv 282 | 283 | class SpatialAttention(nn.Module): 284 | def __init__(self,in_channel,out_channel, kernel_size=3): 285 | super(SpatialAttention, self).__init__() 286 | 287 | self.conv1 = _conv_spa(in_channel,out_channel) 288 | self.conv2 = nn.Conv2d(out_channel, 1, kernel_size=1,bias=False) 289 | 290 | self.conv_main = _conv_spa(in_channel,out_channel) 291 | self.sigmoid = nn.Sigmoid() 292 | self.sande = SE_Block(out_channel,16) 293 | self.convcombine = nn.Conv2d(2*out_channel,out_channel,kernel_size=1,stride=1,padding=1,bias=False) 294 | self.bncombine = nn.BatchNorm2d(out_channel) 295 | self.relucombine = nn.ReLU(inplace=True) 296 | def forward(self, x): 297 | sp = self.conv1(x) 298 | sp = self.conv2(sp) 299 | act = self.sigmoid(sp) 300 | x_spatial = x*act 301 | #x_se = self.sande(x) 302 | #x = torch.cat([x_spatial, x_se], dim=1) 303 | #x = self.convcombine(x) 304 | x = self.bncombine(x_spatial) 305 | return self.relucombine(x) 306 | 307 | 308 | 309 | class VGGnet_spatial(nn.Module): 310 | 311 | def __init__(self, input_channel,num_classes): 312 | super().__init__() 313 | 314 | layers=[64,128,256,512] 315 | 316 | self.conv1 = self._conv(input_channel,layers[0]) 317 | self.att1 = SpatialAttention(layers[0],layers[0]) 318 | self.maxp1 = nn.MaxPool2d(2,stride=2) 319 | 320 | self.conv2 = self._conv(layers[0],layers[1]) 321 | self.att2 = SpatialAttention(layers[1],layers[1]) 322 | self.maxp2 = nn.MaxPool2d(2,stride=2) 323 | 324 | 
self.conv3 = self._conv(layers[1],layers[2]) 325 | self.att3 = SpatialAttention(layers[2],layers[2]) 326 | self.maxp3 = nn.MaxPool2d(2,stride=2) 327 | 328 | self.conv4 = self._conv(layers[2],layers[3]) 329 | self.att4 = SpatialAttention(layers[3],layers[3]) 330 | self.maxp4 = nn.MaxPool2d(2,stride=2) 331 | self.avg = nn.AdaptiveAvgPool2d(1) 332 | self.classifier = nn.Linear(512,num_classes) 333 | 334 | 335 | def _conv(self,inplance,outplance,nlayers=2): 336 | conv = [] 337 | for n in range(nlayers): 338 | conv.append(nn.Conv2d(inplance,outplance,kernel_size=3, 339 | stride=1,padding=1,bias=False)) 340 | conv.append(nn.BatchNorm2d(outplance)) 341 | conv.append(nn.ReLU(inplace=True)) 342 | inplance = outplance 343 | 344 | conv = nn.Sequential(*conv) 345 | 346 | return conv 347 | 348 | def forward(self, x): 349 | xlist=[x] 350 | x = self.conv1(x) 351 | x = self.att1(x) 352 | x = self.maxp1(x) 353 | 354 | x = self.conv2(x) 355 | x = self.att2(x) 356 | x = self.maxp2(x) 357 | 358 | x = self.conv3(x) 359 | x = self.att3(x) 360 | x = self.maxp3(x) 361 | 362 | x = self.conv4(x) 363 | x = self.att4(x) 364 | r = torch.flatten(self.avg(x),1) 365 | c = self.classifier(r) 366 | return c 367 | 368 | img = torch.randn(1, 1, 64,128) 369 | v = VGGnet_spatial(input_channel=1,num_classes=657) 370 | preds = v(img) # (1, 1000) 371 | print(preds.shape) 372 | -------------------------------------------------------------------------------- /patchnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import timm 4 | import types 5 | import math 6 | import torch.nn.functional as F 7 | from math import sqrt 8 | from einops import rearrange, repeat 9 | from einops.layers.torch import Rearrange 10 | def cast_tuple(val, num): 11 | return val if isinstance(val, tuple) else (val,) * num 12 | def conv_output_size(image_size, kernel_size, stride, padding = 0): 13 | return int(((image_size - kernel_size + (2 * padding)) / stride) + 1) 14 | 15 | 16 | class PreNorm(nn.Module): 17 | def __init__(self, dim, fn): 18 | super().__init__() 19 | self.norm = nn.LayerNorm(dim) 20 | self.fn = fn 21 | def forward(self, x, **kwargs): 22 | return self.fn(self.norm(x), **kwargs) 23 | class FeedForward(nn.Module): 24 | def __init__(self, dim, hidden_dim, dropout = 0.): 25 | super().__init__() 26 | self.net = nn.Sequential( 27 | nn.Linear(dim, hidden_dim), 28 | nn.GELU(), 29 | nn.Dropout(dropout), 30 | nn.Linear(hidden_dim, dim), 31 | nn.Dropout(dropout) 32 | ) 33 | def forward(self, x): 34 | return self.net(x) 35 | 36 | class Attention(nn.Module): 37 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 38 | super().__init__() 39 | inner_dim = dim_head * heads 40 | project_out = not (heads == 1 and dim_head == dim) 41 | 42 | self.heads = heads 43 | self.scale = dim_head ** -0.5 44 | 45 | self.attend = nn.Softmax(dim = -1) 46 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 47 | 48 | self.to_out = nn.Sequential( 49 | nn.Linear(inner_dim, dim), 50 | nn.Dropout(dropout) 51 | ) if project_out else nn.Identity() 52 | 53 | def forward(self, x): 54 | b, n, _, h = *x.shape, self.heads 55 | qkv = self.to_qkv(x).chunk(3, dim = -1) 56 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 57 | 58 | dots = torch.einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 59 | 60 | attn = self.attend(dots) 61 | 62 | out = torch.einsum('b h i j, b h j d -> b h i d', attn, v) 63 | out = rearrange(out, 'b h n d -> b n (h d)') 64 | 
return self.to_out(out) 65 | 66 | 67 | class Transformer(nn.Module): 68 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 69 | super().__init__() 70 | self.layers = nn.ModuleList([]) 71 | for _ in range(depth): 72 | self.layers.append(nn.ModuleList([ 73 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 74 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 75 | ])) 76 | def forward(self, x): 77 | for attn, ff in self.layers: 78 | x = attn(x) + x 79 | x = ff(x) + x 80 | return x 81 | class DepthWiseConv2d(nn.Module): 82 | def __init__(self, dim_in, dim_out, kernel_size, padding, stride, bias = True): 83 | super().__init__() 84 | self.net = nn.Sequential( 85 | nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias), 86 | nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias) 87 | ) 88 | def forward(self, x): 89 | return self.net(x) 90 | 91 | class Pool(nn.Module): 92 | def __init__(self, dim): 93 | super().__init__() 94 | self.downsample = DepthWiseConv2d(dim, dim * 2, kernel_size = 3, stride = 2, padding = 1) 95 | self.cls_ff = nn.Linear(dim, dim * 2) 96 | 97 | def forward(self, x): 98 | cls_token, tokens = x[:, :1], x[:, 1:] 99 | 100 | cls_token = self.cls_ff(cls_token) 101 | 102 | tokens = rearrange(tokens, 'b (h w) c -> b c h w', h = int(sqrt(tokens.shape[1]))) 103 | tokens = self.downsample(tokens) 104 | tokens = rearrange(tokens, 'b c h w -> b (h w) c') 105 | 106 | return torch.cat((cls_token, tokens), dim = 1) 107 | class PiT(nn.Module): 108 | def __init__( 109 | self, 110 | *, 111 | image_size, 112 | patch_size, 113 | num_classes, 114 | dim, 115 | depth, 116 | heads, 117 | mlp_dim, 118 | dim_head = 64, 119 | dropout = 0., 120 | emb_dropout = 0. 121 | ): 122 | super().__init__() 123 | assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.' 
124 | assert isinstance(depth, tuple), 'depth must be a tuple of integers, specifying the number of blocks before each downsizing' 125 | heads = cast_tuple(heads, len(depth)) 126 | self.pool='mean' 127 | self.to_latent = nn.Identity() 128 | patch_dim = patch_size ** 2 129 | 130 | self.to_patch_embedding = nn.Sequential( 131 | nn.Unfold(kernel_size = patch_size, stride = patch_size // 2), 132 | Rearrange('b c n -> b n c'), 133 | nn.Linear(patch_dim, dim) 134 | ) 135 | 136 | output_size = conv_output_size(image_size, patch_size, patch_size // 2) 137 | num_patches = output_size ** 2 138 | 139 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 140 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 141 | self.dropout = nn.Dropout(emb_dropout) 142 | 143 | layers = [] 144 | for ind, (layer_depth, layer_heads) in enumerate(zip(depth, heads)): 145 | not_last = ind < (len(depth) - 1) 146 | 147 | layers.append(Transformer(dim, layer_depth, layer_heads, dim_head, mlp_dim, dropout)) 148 | 149 | if not_last: 150 | layers.append(Pool(dim)) 151 | dim *= 2 152 | 153 | self.layers = nn.Sequential( 154 | *layers, 155 | nn.LayerNorm(dim), 156 | nn.Linear(dim, num_classes)) 157 | 158 | def forward(self, img): 159 | #print(img.shape) 160 | x = self.to_patch_embedding(img) 161 | b, n, _ = x.shape 162 | 163 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 164 | x = torch.cat((cls_tokens, x), dim=1) 165 | x += self.pos_embedding 166 | x = self.dropout(x) 167 | for l in self.layers[0:len(self.layers)-1]: 168 | x = l(x) 169 | x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 170 | x = self.to_latent(x) 171 | x = self.layers[-1](x) 172 | return x 173 | v = PiT( 174 | image_size = 128, 175 | patch_size = 16, 176 | dim = 128, 177 | num_classes = 105, 178 | depth = (3, 3, 3), # list of depths, indicating the number of rounds of each stage before a downsample 179 | heads = 8, 180 | mlp_dim = 512, 181 | dropout = 0.1, 182 | emb_dropout = 0.1 183 | ) 184 | 185 | # forward pass now returns predictions and the attention maps 186 | 187 | img = torch.randn(1, 1, 128,128) 188 | 189 | preds = v(img) # (1, 1000) 190 | print(preds.shape) 191 | 192 | 193 | 194 | class exchange(nn.Module): 195 | def __init__(self,scale1,scale2,k_1,k_2): 196 | super().__init__() 197 | self.layers1 = [scale1,scale1+k_1+k_1,scale1+2*k_1+k_1,scale1+3*k_1+k_1,scale1+4*k_1+k_1] 198 | self.layers2 = [scale2,scale2+k_2+k_1,scale2+2*k_2+k_1,scale2+3*k_2+k_1,scale2+4*k_2+k_1] 199 | self.x1 = nn.Sequential(nn.Conv2d(scale1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 200 | self.y1 = nn.Sequential(nn.Conv2d(scale2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 201 | self.x1c = nn.Sequential(nn.Conv2d(scale1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 202 | self.y1t = nn.Sequential(nn.ConvTranspose2d(scale2,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 203 | 204 | self.x2_input = nn.Sequential(nn.Conv2d(self.layers1[1],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 205 | self.y2_input = nn.Sequential(nn.Conv2d(self.layers2[1],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 206 | 207 | 208 | self.x2 = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 209 | self.y2 = nn.Sequential(nn.Conv2d(k_2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 210 | self.x2c = 
nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 211 | self.y2t = nn.Sequential(nn.ConvTranspose2d(k_2,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 212 | 213 | self.x3_input = nn.Sequential(nn.Conv2d(self.layers1[2],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 214 | self.y3_input = nn.Sequential(nn.Conv2d(self.layers2[2],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 215 | 216 | 217 | self.x3 = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 218 | self.y3 = nn.Sequential(nn.Conv2d(k_2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 219 | self.x3c = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 220 | self.y3t = nn.Sequential(nn.ConvTranspose2d(k_2,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 221 | 222 | self.x4_input = nn.Sequential(nn.Conv2d(self.layers1[3],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 223 | self.y4_input = nn.Sequential(nn.Conv2d(self.layers2[3],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 224 | 225 | self.x4 = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 226 | self.y4 = nn.Sequential(nn.Conv2d(k_2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 227 | self.x4c = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 228 | self.y4t = nn.Sequential(nn.ConvTranspose2d(k_2,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 229 | 230 | self.x5_input = nn.Sequential(nn.Conv2d(self.layers1[4],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 231 | self.y5_input = nn.Sequential(nn.Conv2d(self.layers2[4],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 232 | 233 | self.x5 = nn.Sequential(nn.Conv2d(self.layers1[4],scale1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 234 | self.y5 = nn.Sequential(nn.Conv2d(self.layers2[4],scale2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 235 | 236 | 237 | 238 | 239 | def _forward(self,ten1,ten2): 240 | x1 = self.x1(ten1) 241 | y1 = self.y1(ten2) 242 | x1c = self.x1c(ten1) 243 | y1t = self.y1t(ten2) 244 | 245 | x2_input = torch.cat([ten1,x1,y1t],1) 246 | y2_input = torch.cat([ten2,y1,x1c],1) 247 | x2_input = self.x2_input(x2_input) 248 | y2_input = self.y2_input(y2_input) 249 | 250 | x2 = self.x2(x2_input) 251 | y2 = self.y2(y2_input) 252 | x2c = self.x2c(x2_input) 253 | y2t = self.y2t(y2_input) 254 | 255 | x3_input = torch.cat([ten1,x1,x2,y2t],1) 256 | y3_input = torch.cat([ten2,y1,y2,x2c],1) 257 | x3_input = self.x3_input(x3_input) 258 | y3_input = self.y3_input(y3_input) 259 | 260 | x3 = self.x3(x3_input) 261 | y3 = self.y3(y3_input) 262 | x3c = self.x3c(x3_input) 263 | y3t = self.y3t(y3_input) 264 | 265 | x4_input = torch.cat([ten1,x1,x2,x3,y3t],1) 266 | y4_input = torch.cat([ten2,y1,y2,y3,x3c],1) 267 | x4_input = self.x4_input(x4_input) 268 | y4_input = self.y4_input(y4_input) 269 | 270 | x4 = self.x4(x4_input) 271 | y4 = self.y4(y4_input) 272 | x4c = self.x4c(x4_input) 273 | y4t = self.y4t(y4_input) 274 | 275 | x5_input = torch.cat([ten1,x1,x2,x3,x4,y4t],1) 276 | y5_input = torch.cat([ten2,y1,y2,y3,y4,x4c],1) 277 | #x5_input = self.x5_input(x5_input) 278 | #y5_input = 
self.y5_input(y5_input) 279 | x5 = self.x5(x5_input) 280 | y5 = self.y5(y5_input) 281 | 282 | return 0.4*x5+ten1, 0.4*y5+ten2 283 | 284 | 285 | 286 | 287 | 288 | def forward(self,ten): 289 | return self._forward(ten[0],ten[1]) 290 | 291 | class VGGnet(nn.Module): 292 | 293 | def __init__(self, input_channel): 294 | super().__init__() 295 | 296 | layers=[64,128,256,512] 297 | 298 | self.conv1 = self._conv(input_channel,layers[0]) 299 | self.maxp1 = nn.MaxPool2d(2,stride=2) 300 | self.conv2 = self._conv(layers[0],layers[1]) 301 | self.maxp2 = nn.MaxPool2d(2,stride=2) 302 | self.conv3 = self._conv(layers[1],layers[2]) 303 | self.maxp3 = nn.MaxPool2d(2,stride=2) 304 | self.conv4 = self._conv(layers[2],layers[3]) 305 | self.maxp4 = nn.MaxPool2d(2,stride=2) 306 | 307 | 308 | def _conv(self,inplance,outplance,nlayers=2): 309 | conv = [] 310 | for n in range(nlayers): 311 | conv.append(nn.Conv2d(inplance,outplance,kernel_size=3, 312 | stride=1,padding=1,bias=False)) 313 | conv.append(nn.BatchNorm2d(outplance)) 314 | conv.append(nn.ReLU(inplace=True)) 315 | inplance = outplance 316 | 317 | conv = nn.Sequential(*conv) 318 | 319 | return conv 320 | 321 | def forward(self, x): 322 | xlist=[x] 323 | x = self.conv1(x) 324 | xlist.append(x) 325 | x = self.maxp1(x) 326 | x = self.conv2(x) 327 | xlist.append(x) 328 | x = self.maxp2(x) 329 | x = self.conv3(x) 330 | xlist.append(x) 331 | x = self.maxp3(x) 332 | x = self.conv4(x) 333 | xlist.append(x) 334 | return x 335 | 336 | class dfrag_old(nn.Module): 337 | def __init__(self,inplace,num_classes): 338 | super().__init__() 339 | self.net = VGGnet(inplace) 340 | self.avg = nn.AdaptiveAvgPool2d(1) 341 | 342 | self.conv1_stream1 = self._conv(1,64) 343 | self.conv1_stream2 = self._conv(1,64) 344 | self.conv1_stream3 = self._conv(1,64) 345 | self.conv1_stream4 = self._conv(1,64) 346 | self.conv1_stream5 = self._conv(1,64) 347 | self.maxp1 = nn.MaxPool2d(2,stride=(2,2)) 348 | 349 | self.conv2_stream1 = self._conv(64,128) 350 | self.conv2_stream2 = self._conv(64,128) 351 | self.conv2_stream3 = self._conv(64,128) 352 | self.conv2_stream4 = self._conv(64,128) 353 | self.conv2_stream5 = self._conv(64,128) 354 | self.maxp2 = nn.MaxPool2d(2,stride=2) 355 | 356 | self.classifier = nn.Linear(512,num_classes) 357 | self.exchange12_one = exchange(64,64,16,16) 358 | self.exchange23_one = exchange(64,64,16,16) 359 | self.exchange34_one = exchange(64,64,16,16) 360 | self.exchange45_one = exchange(64,64,16,16) 361 | 362 | self.exchange12_2 = exchange(128,128,32,32) 363 | self.exchange23_2 = exchange(128,128,32,32) 364 | self.exchange34_2 = exchange(128,128,32,32) 365 | self.exchange45_2 = exchange(128,128,32,32) 366 | self.classifier_global = nn.Linear(512,num_classes) 367 | self.classifier_patch = nn.Linear(128,num_classes) 368 | 369 | 370 | def _conv(self,inplance,outplance,nlayers=2): 371 | conv = [] 372 | for n in range(nlayers): 373 | conv.append(nn.Conv2d(inplance,outplance,kernel_size=3, 374 | stride=1,padding=1,bias=False)) 375 | conv.append(nn.BatchNorm2d(outplance)) 376 | conv.append(nn.ReLU(inplace=True)) 377 | inplance = outplance 378 | 379 | conv = nn.Sequential(*conv) 380 | 381 | return conv 382 | 383 | def forward(self,x): 384 | x_global = self.net(x) 385 | step = 16 386 | 387 | #print(xlist[0].shape) 388 | xpatch = [] 389 | # input image 390 | reslist = [] 391 | for n in range(0,65,step): 392 | xpatch.append(x[:,:,:,n:n+64]) 393 | 394 | x1 = self.conv1_stream1(xpatch[0]) 395 | x1 = self.maxp1(x1) 396 | x2 = self.conv1_stream2(xpatch[1]) 397 | x2 = 
self.maxp1(x2) 398 | x3 = self.conv1_stream3(xpatch[2]) 399 | x3= self.maxp1(x3) 400 | x4 = self.conv1_stream4(xpatch[3]) 401 | x4 = self.maxp1(x4) 402 | x5 = self.conv1_stream5(xpatch[4]) 403 | x5 = self.maxp1(x5) 404 | 405 | x1,x2 = self.exchange12_one((x1,x2)) 406 | x2,x3 = self.exchange23_one((x2,x3)) 407 | x3,x4 = self.exchange34_one((x3,x4)) 408 | x4,x5 = self.exchange45_one((x4,x5)) 409 | 410 | x1 = self.conv2_stream1(x1) 411 | x1 = self.maxp2(x1) 412 | x2 = self.conv2_stream2(x2) 413 | x2 = self.maxp2(x2) 414 | x3 = self.conv2_stream3(x3) 415 | x3= self.maxp2(x3) 416 | x4 = self.conv2_stream4(x4) 417 | x4 = self.maxp2(x4) 418 | x5 = self.conv2_stream5(x5) 419 | x5 = self.maxp2(x5) 420 | 421 | x1,x2 = self.exchange12_2((x1,x2)) 422 | x2,x3 = self.exchange23_2((x2,x3)) 423 | x3,x4 = self.exchange34_2((x3,x4)) 424 | x4,x5 = self.exchange45_2((x4,x5)) 425 | 426 | 427 | li = [x1,x2,x3,x4,x5] 428 | global_pred = torch.flatten(self.avg(x_global),1) 429 | global_logs = self.classifier_global(global_pred) 430 | 431 | logits_list = [global_logs] 432 | for l in li: 433 | r = torch.flatten(self.avg(l),1) 434 | c = self.classifier_patch(r) 435 | logits_list.append(c) 436 | combined_logits = 0 437 | for r in logits_list: 438 | combined_logits += r 439 | return logits_list,combined_logits 440 | 441 | 442 | 443 | class exchange(nn.Module): 444 | def __init__(self,scale1,scale2,k_1,k_2): 445 | super().__init__() 446 | self.layers1 = [scale1,scale1+k_1+k_1,scale1+2*k_1+k_1,scale1+3*k_1+k_1,scale1+4*k_1+k_1] 447 | self.layers2 = [scale2,scale2+k_2+k_1,scale2+2*k_2+k_1,scale2+3*k_2+k_1,scale2+4*k_2+k_1] 448 | self.x1 = nn.Sequential(nn.Conv2d(scale1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 449 | self.y1 = nn.Sequential(nn.Conv2d(scale2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 450 | self.x1c = nn.Sequential(nn.Conv2d(scale1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 451 | self.y1t = nn.Sequential(nn.ConvTranspose2d(scale2,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 452 | 453 | self.x2_input = nn.Sequential(nn.Conv2d(self.layers1[1],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 454 | self.y2_input = nn.Sequential(nn.Conv2d(self.layers2[1],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 455 | 456 | 457 | self.x2 = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 458 | self.y2 = nn.Sequential(nn.Conv2d(k_2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 459 | self.x2c = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 460 | self.y2t = nn.Sequential(nn.ConvTranspose2d(k_2,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 461 | 462 | self.x3_input = nn.Sequential(nn.Conv2d(self.layers1[2],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 463 | self.y3_input = nn.Sequential(nn.Conv2d(self.layers2[2],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 464 | 465 | 466 | self.x3 = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 467 | self.y3 = nn.Sequential(nn.Conv2d(k_2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 468 | self.x3c = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 469 | self.y3t = 
nn.Sequential(nn.ConvTranspose2d(k_2,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 470 | 471 | self.x4_input = nn.Sequential(nn.Conv2d(self.layers1[3],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 472 | self.y4_input = nn.Sequential(nn.Conv2d(self.layers2[3],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 473 | 474 | self.x4 = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 475 | self.y4 = nn.Sequential(nn.Conv2d(k_2,k_2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 476 | self.x4c = nn.Sequential(nn.Conv2d(k_1,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 477 | self.y4t = nn.Sequential(nn.ConvTranspose2d(k_2,k_1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 478 | 479 | self.x5_input = nn.Sequential(nn.Conv2d(self.layers1[4],k_1,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 480 | self.y5_input = nn.Sequential(nn.Conv2d(self.layers2[4],k_2,kernel_size=1,stride=1,padding=0,bias=False),nn.ReLU(inplace=True)) 481 | 482 | self.x5 = nn.Sequential(nn.Conv2d(self.layers1[4],scale1,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 483 | self.y5 = nn.Sequential(nn.Conv2d(self.layers2[4],scale2,kernel_size=3,stride=1,padding=1,bias=False),nn.ReLU(inplace=True)) 484 | 485 | 486 | 487 | 488 | def _forward(self,ten1,ten2): 489 | x1 = self.x1(ten1) 490 | y1 = self.y1(ten2) 491 | x1c = self.x1c(ten1) 492 | y1t = self.y1t(ten2) 493 | 494 | x2_input = torch.cat([ten1,x1,y1t],1) 495 | y2_input = torch.cat([ten2,y1,x1c],1) 496 | x2_input = self.x2_input(x2_input) 497 | y2_input = self.y2_input(y2_input) 498 | 499 | x2 = self.x2(x2_input) 500 | y2 = self.y2(y2_input) 501 | x2c = self.x2c(x2_input) 502 | y2t = self.y2t(y2_input) 503 | 504 | x3_input = torch.cat([ten1,x1,x2,y2t],1) 505 | y3_input = torch.cat([ten2,y1,y2,x2c],1) 506 | x3_input = self.x3_input(x3_input) 507 | y3_input = self.y3_input(y3_input) 508 | 509 | x3 = self.x3(x3_input) 510 | y3 = self.y3(y3_input) 511 | x3c = self.x3c(x3_input) 512 | y3t = self.y3t(y3_input) 513 | 514 | x4_input = torch.cat([ten1,x1,x2,x3,y3t],1) 515 | y4_input = torch.cat([ten2,y1,y2,y3,x3c],1) 516 | x4_input = self.x4_input(x4_input) 517 | y4_input = self.y4_input(y4_input) 518 | 519 | x4 = self.x4(x4_input) 520 | y4 = self.y4(y4_input) 521 | x4c = self.x4c(x4_input) 522 | y4t = self.y4t(y4_input) 523 | 524 | x5_input = torch.cat([ten1,x1,x2,x3,x4,y4t],1) 525 | y5_input = torch.cat([ten2,y1,y2,y3,y4,x4c],1) 526 | #x5_input = self.x5_input(x5_input) 527 | #y5_input = self.y5_input(y5_input) 528 | x5 = self.x5(x5_input) 529 | y5 = self.y5(y5_input) 530 | 531 | return 0.4*x5+ten1, 0.4*y5+ten2 532 | 533 | 534 | 535 | 536 | 537 | def forward(self,ten): 538 | return self._forward(ten[0],ten[1]) 539 | 540 | class VGGnet(nn.Module): 541 | 542 | def __init__(self, input_channel): 543 | super().__init__() 544 | 545 | layers=[64,128,256,512] 546 | 547 | self.conv1 = self._conv(input_channel,layers[0]) 548 | self.maxp1 = nn.MaxPool2d(2,stride=2) 549 | self.conv2 = self._conv(layers[0],layers[1]) 550 | self.maxp2 = nn.MaxPool2d(2,stride=2) 551 | self.conv3 = self._conv(layers[1],layers[2]) 552 | self.maxp3 = nn.MaxPool2d(2,stride=2) 553 | self.conv4 = self._conv(layers[2],layers[3]) 554 | self.maxp4 = nn.MaxPool2d(2,stride=2) 555 | 556 | 557 | def _conv(self,inplance,outplance,nlayers=2): 558 | conv = [] 559 | for n in range(nlayers): 560 | 
conv.append(nn.Conv2d(inplance,outplance,kernel_size=3, 561 | stride=1,padding=1,bias=False)) 562 | conv.append(nn.BatchNorm2d(outplance)) 563 | conv.append(nn.ReLU(inplace=True)) 564 | inplance = outplance 565 | 566 | conv = nn.Sequential(*conv) 567 | 568 | return conv 569 | 570 | def forward(self, x): 571 | xlist=[x] 572 | x = self.conv1(x) 573 | xlist.append(x) 574 | x = self.maxp1(x) 575 | x = self.conv2(x) 576 | xlist.append(x) 577 | x = self.maxp2(x) 578 | x = self.conv3(x) 579 | xlist.append(x) 580 | x = self.maxp3(x) 581 | x = self.conv4(x) 582 | xlist.append(x) 583 | return x 584 | 585 | class dfrag(nn.Module): 586 | def __init__(self,inplace,num_classes): 587 | super().__init__() 588 | self.net = VGGnet(inplace) 589 | self.avg = nn.AdaptiveAvgPool2d(1) 590 | 591 | self.conv1_stream1 = self._conv(1,64) 592 | self.conv1_stream2 = self._conv(1,64) 593 | self.conv1_stream3 = self._conv(1,64) 594 | self.conv1_stream4 = self._conv(1,64) 595 | self.conv1_stream5 = self._conv(1,64) 596 | self.maxp1 = nn.MaxPool2d(2,stride=(2,2)) 597 | 598 | self.conv2_stream1 = self._conv(64,128) 599 | self.conv2_stream2 = self._conv(64,128) 600 | self.conv2_stream3 = self._conv(64,128) 601 | self.conv2_stream4 = self._conv(64,128) 602 | self.conv2_stream5 = self._conv(64,128) 603 | 604 | self.conv3_stream1 = self._conv(128,256) 605 | self.conv3_stream2 = self._conv(128,256) 606 | self.conv3_stream3 = self._conv(128,256) 607 | self.conv3_stream4 = self._conv(128,256) 608 | self.conv3_stream5 = self._conv(128,256) 609 | self.conv4_stream1 = self._conv(256,256) 610 | self.conv4_stream2 = self._conv(256,256) 611 | self.conv4_stream3 = self._conv(256,256) 612 | self.conv4_stream4 = self._conv(256,256) 613 | self.conv4_stream5 = self._conv(256,256) 614 | self.maxp4 = nn.MaxPool2d(2,stride=2) 615 | self.maxp2 = nn.MaxPool2d(2,stride=2) 616 | self.maxp3 = nn.MaxPool2d(2,stride=2) 617 | 618 | self.classifier = nn.Linear(512,num_classes) 619 | self.exchange12_one = exchange(64,64,16,16) 620 | self.exchange23_one = exchange(64,64,16,16) 621 | self.exchange34_one = exchange(64,64,16,16) 622 | self.exchange45_one = exchange(64,64,16,16) 623 | 624 | self.exchange12_2 = exchange(128,128,32,32) 625 | self.exchange23_2 = exchange(128,128,32,32) 626 | self.exchange34_2 = exchange(128,128,32,32) 627 | self.exchange45_2 = exchange(128,128,32,32) 628 | 629 | self.exchange12_3 = exchange(256,256,32,32) 630 | self.exchange23_3 = exchange(256,256,32,32) 631 | self.exchange34_3 = exchange(256,256,32,32) 632 | self.exchange45_3 = exchange(256,256,32,32) 633 | 634 | 635 | self.exchange12_4 = exchange(256,256,16,16) 636 | self.exchange23_4 = exchange(256,256,16,16) 637 | self.exchange34_4 = exchange(256,256,16,16) 638 | self.exchange45_4 = exchange(256,256,16,16) 639 | 640 | 641 | 642 | 643 | self.classifier_global = nn.Linear(512,num_classes) 644 | self.classifier_patch = nn.Linear(256,num_classes) 645 | 646 | def _conv(self,inplance,outplance,nlayers=2): 647 | conv = [] 648 | for n in range(nlayers): 649 | conv.append(nn.Conv2d(inplance,outplance,kernel_size=3, 650 | stride=1,padding=1,bias=False)) 651 | conv.append(nn.BatchNorm2d(outplance)) 652 | conv.append(nn.ReLU(inplace=True)) 653 | inplance = outplance 654 | 655 | conv = nn.Sequential(*conv) 656 | 657 | return conv 658 | 659 | def forward(self,x): 660 | x_global = self.net(x) 661 | step = 16 662 | 663 | #print(xlist[0].shape) 664 | xpatch = [] 665 | # input image 666 | reslist = [] 667 | for n in range(0,65,step): 668 | #print(x[:,:,:,n:n+64].shape) 669 | 
xpatch.append(x[:,:,:,n:n+64]) 670 | 671 | x1 = self.conv1_stream1(xpatch[0]) 672 | x1 = self.maxp1(x1) 673 | x2 = self.conv1_stream2(xpatch[1]) 674 | x2 = self.maxp1(x2) 675 | x3 = self.conv1_stream3(xpatch[2]) 676 | x3= self.maxp1(x3) 677 | x4 = self.conv1_stream4(xpatch[3]) 678 | x4 = self.maxp1(x4) 679 | x5 = self.conv1_stream5(xpatch[4]) 680 | x5 = self.maxp1(x5) 681 | 682 | x1,x2 = self.exchange12_one((x1,x2)) 683 | x2,x3 = self.exchange23_one((x2,x3)) 684 | x3,x4 = self.exchange34_one((x3,x4)) 685 | x4,x5 = self.exchange45_one((x4,x5)) 686 | 687 | x1 = self.conv2_stream1(x1) 688 | x1 = self.maxp2(x1) 689 | x2 = self.conv2_stream2(x2) 690 | x2 = self.maxp2(x2) 691 | x3 = self.conv2_stream3(x3) 692 | x3= self.maxp2(x3) 693 | x4 = self.conv2_stream4(x4) 694 | x4 = self.maxp2(x4) 695 | x5 = self.conv2_stream5(x5) 696 | x5 = self.maxp2(x5) 697 | 698 | x1,x2 = self.exchange12_2((x1,x2)) 699 | x2,x3 = self.exchange23_2((x2,x3)) 700 | x3,x4 = self.exchange34_2((x3,x4)) 701 | x4,x5 = self.exchange45_2((x4,x5)) 702 | 703 | 704 | x1 = self.conv3_stream1(x1) 705 | x1 = self.maxp3(x1) 706 | x2 = self.conv3_stream2(x2) 707 | x2 = self.maxp3(x2) 708 | x3 = self.conv3_stream3(x3) 709 | x3= self.maxp3(x3) 710 | x4 = self.conv3_stream4(x4) 711 | x4 = self.maxp3(x4) 712 | x5 = self.conv3_stream5(x5) 713 | x5 = self.maxp3(x5) 714 | 715 | x1,x2 = self.exchange12_3((x1,x2)) 716 | x2,x3 = self.exchange23_3((x2,x3)) 717 | x3,x4 = self.exchange34_3((x3,x4)) 718 | x4,x5 = self.exchange45_3((x4,x5)) 719 | 720 | x1 = self.conv4_stream1(x1) 721 | x1 = self.maxp4(x1) 722 | x2 = self.conv4_stream2(x2) 723 | x2 = self.maxp4(x2) 724 | x3 = self.conv4_stream3(x3) 725 | x3= self.maxp4(x3) 726 | x4 = self.conv4_stream4(x4) 727 | x4 = self.maxp4(x4) 728 | x5 = self.conv4_stream5(x5) 729 | x5 = self.maxp4(x5) 730 | 731 | x1,x2 = self.exchange12_4((x1,x2)) 732 | x2,x3 = self.exchange23_4((x2,x3)) 733 | x3,x4 = self.exchange34_4((x3,x4)) 734 | x4,x5 = self.exchange45_4((x4,x5)) 735 | 736 | li = [x1,x2,x3,x4,x5] 737 | global_pred = torch.flatten(self.avg(x_global),1) 738 | global_logs = self.classifier_global(global_pred) 739 | 740 | logits_list = [global_logs] 741 | for l in li: 742 | r = torch.flatten(self.avg(l),1) 743 | c = self.classifier_patch(r) 744 | logits_list.append(c) 745 | combined_logits =0 746 | for r in logits_list: 747 | combined_logits += r 748 | return logits_list,combined_logits 749 | 750 | 751 | --------------------------------------------------------------------------------
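For completeness, a minimal smoke test of the PatchNet model defined above (the 64x128 input size and the writer count of 105 are assumptions matching the defaults in `train.py`). It also makes the patch-slicing arithmetic explicit: `dfrag.forward` cuts the 128-pixel-wide input into five overlapping 64-wide patches at offsets 0, 16, 32, 48 and 64.

```python
import torch
from patchnet import dfrag

x = torch.randn(2, 1, 64, 128)                      # batch of grayscale word images
offsets = list(range(0, 65, 16))                    # same loop as in dfrag.forward
print(offsets)                                      # [0, 16, 32, 48, 64]
print([x[:, :, :, n:n+64].shape[-1] for n in offsets])  # five 64-wide patches

logits_list, combined = dfrag(1, num_classes=105)(x)
print(len(logits_list), combined.shape)             # 6 heads (1 global + 5 patch) -> torch.Size([2, 105])
```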