├── README.md
├── lib
│   ├── __init__.py
│   ├── dataloader.py
│   ├── dataset.py
│   ├── init.py
│   ├── test.py
│   └── train.py
├── models
│   ├── __init__.py
│   ├── capsule_layer.py
│   ├── capsule_net.py
│   ├── nn_.py
│   ├── shufflenet.py
│   └── unet.py
└── my_train.py

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# Capsules for Object Segmentation

A barebones CUDA-enabled PyTorch implementation of the SegCaps architecture from the paper "Capsules for Object Segmentation" by [Rodney LaLonde and Ulas Bagci](https://github.com/lalonderodney/SegCaps).

## Condensed Abstract

> Convolutional neural networks (CNNs) have shown remarkable results over the last several years for a wide range of computer vision tasks. A new architecture recently introduced by Sabour et al., referred to as capsule networks with dynamic routing, has shown great initial results for digit recognition and small image classification. Our work expands the use of capsule networks to the task of object segmentation for the first time in the literature. We extend the idea of convolutional capsules with locally-connected routing and propose the concept of deconvolutional capsules. Further, we extend the masked reconstruction to reconstruct the positive input class. The proposed convolutional-deconvolutional capsule network, called SegCaps, shows strong results for the task of object segmentation with substantial decrease in parameter space. As an example application, we applied the proposed SegCaps to segment pathological lungs from low dose CT scans and compared its accuracy and efficiency with other U-Net-based architectures. SegCaps is able to handle large image sizes (512 x 512) as opposed to baseline capsules (typically less than 32 x 32). The proposed SegCaps reduced the number of parameters of U-Net architecture by 95.4% while still providing a better segmentation accuracy.

Paper written by Rodney LaLonde and Ulas Bagci. For more information, please check out the paper [here](https://arxiv.org/abs/1804.04241).

## Requirements

* Python 3
* PyTorch
* TorchVision
* TorchNet
* Visdom

## Usage

**Step 1** Adjust the number of training epochs, batch sizes, etc. inside `my_train.py`.

**Step 2** Start training. You can choose your own dataset.

```console
$ python my_train.py
```
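The requirements above are all available from PyPI (package names assumed to be the standard distributions):

```console
$ pip install torch torchvision torchnet visdom
```

Instead of editing the file in Step 1, the training hyperparameters can also be overridden on the command line. The flag names below come from `my_train.py`; the values are only illustrative:

```console
$ python my_train.py --model segcaps-train --batch_size_train 20 --lr 0.01 --nepoch 50 --gpu 0
```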
## TODO

- [ ] Test accuracy on the dataset

## Credits

Primarily referenced these two TensorFlow and Keras implementations:

1. [Official Keras implementation by Rodney LaLonde and Ulas Bagci](https://github.com/lalonderodney/SegCaps)
2. [TensorFlow implementation by @iwyoo](https://github.com/iwyoo/tf-SegCaps)

## Contact/Support

email: 13935771565@163.com

--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/lib/dataloader.py:
--------------------------------------------------------------------------------

import torch.utils.data as data
import lib.dataset as dataset
from torchvision import transforms
import torch
import visdom
import random
import numpy as np
import torch.nn.functional as F


class ToTensor(object):
    """Convert a ``numpy.ndarray`` to a tensor.

    Unlike ``torchvision.transforms.ToTensor``, this performs no rescaling or
    axis reordering; it simply wraps ``torch.from_numpy``. The upstream
    transforms are expected to have produced the final float layout already.
    """

    def __call__(self, pic):
        """
        Args:
            pic (numpy.ndarray): Array to be converted to a tensor.

        Returns:
            Tensor: Converted array.
        """
        return torch.from_numpy(pic)

    def __repr__(self):
        return self.__class__.__name__ + '()'


class getcolor():
    """Cycle through 251 pre-sampled RGB triples, shuffling the channel order
    on every draw."""

    def __init__(self):
        color = [i for i in range(5, 256)]
        self.rgb_r = random.sample(color, 251)
        self.rgb_g = random.sample(color, 251)
        self.rgb_b = random.sample(color, 251)
        self.i = 0

    def get(self):
        i = self.i
        color = [self.rgb_r[i % 251], self.rgb_g[i % 251], self.rgb_b[i % 251]]
        idx = random.sample([k for k in range(0, 3)], 3)
        color = [color[idx[0]], color[idx[1]], color[idx[2]]]
        self.i = (i + 1) % 251
        return color


def get_data(batch_size, data_name, data_root='./my_ai/'):
    data_loader = data.DataLoader(
        dataset.Dataset(
            path=data_root,
            transform_data=transforms.Compose([
                # transforms.RandomHorizontalFlip(),
                # channel_change(),
                color_change(),
                ToTensor(),
                tensor_pad(28)
            ]),
            transform_labels=transforms.Compose([
                # transforms.RandomHorizontalFlip(),
                ToTensor(),
                tensor_pad(28)
            ]),
            data_name=data_name
        ),
        batch_size=batch_size,
        shuffle=True,
        num_workers=1
    )
    return data_loader


def test():
    vis = visdom.Visdom()
    vis.close(env='test')
    # data_name is required; 'my_ai_000' matches the Dataset default.
    data_loader = get_data(10, data_name='my_ai_000')
    a = input('in')
    for batch_index, (data, target) in enumerate(data_loader):
        print(batch_index)
        print(data.shape)
        print(target.shape)
        vis.image(
            data[0, 0, :, :],
            env='test',
            win='image',
            opts=dict(title='target')
        )


class rgb_channel(object):
    def __call__(self, x):
        x = x.reshape(1, -1)
        rgb = torch.zeros(256, x.shape[0]).scatter_(0, x, 1)
        return rgb


class channel_change(object):
    """Randomly recolor the five body-part classes (background stays black),
    then expand the image into 768 binary channels: one per intensity value
    (0-255) for each of the R, G and B planes."""

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray): H x W x C image to be recolored and expanded.

        Returns:
            numpy.ndarray: 768 x H x W float32 array, scaled by 0.5.
        """
        color = getcolor()
        # Per-class masks, identified by the fixed label colors.
        flag_leg = img == (92, 143, 0)
        flag_body = img == (200, 128, 0)
        flag_neck = img == (130, 162, 0)
        flag_arm = img == (138, 73, 0)
        flag_head = img == (0, 236, 163)
        flag_back = img == (0, 0, 0)
        flag_leg = flag_leg[:, :, 0] * flag_leg[:, :, 1] * flag_leg[:, :, 2]
        flag_body = flag_body[:, :, 0] * flag_body[:, :, 1] * flag_body[:, :, 2]
        flag_neck = flag_neck[:, :, 0] * flag_neck[:, :, 1] * flag_neck[:, :, 2]
        flag_arm = flag_arm[:, :, 0] * flag_arm[:, :, 1] * flag_arm[:, :, 2]
        flag_head = flag_head[:, :, 0] * flag_head[:, :, 1] * flag_head[:, :, 2]
        flag_back = flag_back[:, :, 0] * flag_back[:, :, 1] * flag_back[:, :, 2]
        # Recolor the whole image first (unlabeled pixels), then each class,
        # and finally force the background back to black.
        img[:, :, :] = color.get()
        img[flag_head] = color.get()
        img[flag_neck] = color.get()
        img[flag_body] = color.get()
        img[flag_leg] = color.get()
        img[flag_arm] = color.get()
        img[flag_back] = [0, 0, 0]
        rgb_r = []
        rgb_g = []
        rgb_b = []
        for i in range(256):
            rgb_r.append(img[:, :, 0] == i)
            rgb_g.append(img[:, :, 1] == i)
            rgb_b.append(img[:, :, 2] == i)
        rgb_r, rgb_g, rgb_b = np.array(rgb_r), np.array(rgb_g), np.array(rgb_b)
        # print(rgb_r.shape)
        rgb = np.concatenate((rgb_r, rgb_g, rgb_b), 0)
        # print(rgb.shape)
        data = rgb.astype(np.float32)
        return data * 0.5

    def __repr__(self):
        return self.__class__.__name__ + '()'


class color_change(object):
    """Randomly recolor the five body-part classes (background stays black),
    then convert the H x W x C uint8 image to a C x H x W float32 array in
    the range [0.0, 1.0]."""

    def __call__(self, img):
        """
        Args:
            img (numpy.ndarray): H x W x C image to be recolored.

        Returns:
            numpy.ndarray: C x H x W float32 array.
        """
        color = getcolor()
        flag_leg = img == (92, 143, 0)
        flag_body = img == (200, 128, 0)
        flag_neck = img == (130, 162, 0)
        flag_arm = img == (138, 73, 0)
        flag_head = img == (0, 236, 163)
        flag_back = img == (0, 0, 0)
        flag_leg = flag_leg[:, :, 0] * flag_leg[:, :, 1] * flag_leg[:, :, 2]
        flag_body = flag_body[:, :, 0] * flag_body[:, :, 1] * flag_body[:, :, 2]
        flag_neck = flag_neck[:, :, 0] * flag_neck[:, :, 1] * flag_neck[:, :, 2]
        flag_arm = flag_arm[:, :, 0] * flag_arm[:, :, 1] * flag_arm[:, :, 2]
        flag_head = flag_head[:, :, 0] * flag_head[:, :, 1] * flag_head[:, :, 2]
        flag_back = flag_back[:, :, 0] * flag_back[:, :, 1] * flag_back[:, :, 2]
        img[:, :, :] = color.get()
        img[flag_head] = color.get()
        img[flag_neck] = color.get()
        img[flag_body] = color.get()
        img[flag_leg] = color.get()
        img[flag_arm] = color.get()
        img[flag_back] = [0, 0, 0]
        img = img.transpose(2, 0, 1)
        return img.astype(np.float32) / 255

    def __repr__(self):
        return self.__class__.__name__ + '()'


class tensor_pad(object):
    """Zero-pad a tensor by ``padding`` pixels on each side of its last two
    (spatial) dimensions."""

    def __init__(self, padding):
        self.padding = padding

    def __call__(self, img):
        """
        Args:
            img (Tensor): Tensor to be padded.

        Returns:
            Tensor: Padded tensor.
        """
        img = F.pad(img, (self.padding, self.padding, self.padding, self.padding))
        return img

    def __repr__(self):
        return self.__class__.__name__ + '()'

# test()
--------------------------------------------------------------------------------
/lib/dataset.py:
--------------------------------------------------------------------------------

import torch.utils.data as data
import os
from skimage import io
import numpy as np
import visdom


class Dataset(data.Dataset):
    def __init__(self, path, transform_data=None, transform_labels=None, data_name='my_ai_000'):
        self.transform_data, self.transform_labels = transform_data, transform_labels
        imgs = []
        files = os.listdir(path)
        for file in files:
            if data_name in file:
                file_path = os.path.join(path, file)
                img = io.imread(file_path)
                imgs.append(img)
        imgs = np.array(imgs)
        self.labels = self.init_label(imgs)
        self.data = imgs

    def init_label(self, imgs):
        # Build an integer label map from the fixed class colors:
        # 0 background, 1 head, 2 neck, 3 arm, 4 body, 5 leg.
        flag_leg = imgs == (92, 143, 0)
        flag_body = imgs == (200, 128, 0)
        flag_neck = imgs == (130, 162, 0)
        flag_arm = imgs == (138, 73, 0)
        flag_head = imgs == (0, 236, 163)
        flag_back = imgs == (0, 0, 0)
        flag_leg = flag_leg[:, :, :, 0] * flag_leg[:, :, :, 1] * flag_leg[:, :, :, 2]
        flag_body = flag_body[:, :, :, 0] * flag_body[:, :, :, 1] * flag_body[:, :, :, 2]
        flag_neck = flag_neck[:, :, :, 0] * flag_neck[:, :, :, 1] * flag_neck[:, :, :, 2]
        flag_arm = flag_arm[:, :, :, 0] * flag_arm[:, :, :, 1] * flag_arm[:, :, :, 2]
        flag_head = flag_head[:, :, :, 0] * flag_head[:, :, :, 1] * flag_head[:, :, :, 2]
        flag_back = flag_back[:, :, :, 0] * flag_back[:, :, :, 1] * flag_back[:, :, :, 2]
        flag = np.zeros((imgs.shape[0], imgs.shape[1], imgs.shape[2]), np.int64)
        flag[flag_head] = 1
        flag[flag_neck] = 2
        flag[flag_arm] = 3
        flag[flag_body] = 4
        flag[flag_leg] = 5
        return flag

    def __getitem__(self, index):
        img, target = self.data[index], self.labels[index]
        if self.transform_data:
            img = self.transform_data(img)
        if self.transform_labels:
            target = self.transform_labels(target)
        return img, target

    def __len__(self):
        return len(self.data)


def test(path):
    a = Dataset(path)
    b, c = a[0]
    vis = visdom.Visdom()
    vis.close(env='test')
    vis.image(
        c.reshape(1, 200, 200).astype(np.float32) / 5,
        env='test',
        win='get'
    )
    for i in range(768):
        vis.image(
            b[i],
            env='test',
            win='tes',
            opts=dict(title='predict')
        )
    print(b.shape)
    print(c.shape)
    print(len(a))

# test('./my_ai/')
--------------------------------------------------------------------------------
/lib/init.py:
--------------------------------------------------------------------------------

import os
import torch
from models import *
from lib.dataloader import get_data


def init(args):
    print("=================FLAGS==================")
    for k, v in args.__dict__.items():
        print('{}: {}'.format(k, v))
    print("========================================")
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
    print('{}:{}'.format('cuda', torch.cuda.is_available()))
    args.cuda = torch.cuda.is_available()  # check whether CUDA is usable
    torch.manual_seed(args.seed)  # seed the RNG for reproducibility
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    train_loader = get_data(args.batch_size_train, data_name=args.data_name, data_root='../data_test')
    # NOTE: models/__init__.py only exports UNet and SegCaps; the xception,
    # densenet, shufflenet and deformnet branches need their classes imported
    # there before they can be selected.
    if 'xception' in args.model:
        model = Xception()
    elif 'densenet' in args.model:
        model = DenseNet101()
    elif 'shufflenet' in args.model:
        model = ShuffleNet(10, g=8, scale_factor=1)
    elif 'deformnet' in args.model:
        model = DeformConvNet()
    elif 'unet' in args.model:
        model = UNet(6)
        if args.pretrain:
            model.load_state_dict(torch.load(args.load_params_name))
    elif 'segcaps' in args.model:
        model = SegCaps()
    model.cuda()

    decreasing_lr = list(map(int, args.dlr.split(',')))
    return train_loader, model, decreasing_lr
--------------------------------------------------------------------------------
/lib/test.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn

# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(2, 2)
y = torch.randn(N, D_out)

# Use the nn package to define our model and loss function.
model = torch.nn.Sequential(
    torch.nn.Linear(D_in, H),
    torch.nn.ReLU(),
    torch.nn.Linear(H, D_out),
)


class cheng_layer(nn.Module):
    def __init__(self):
        super(cheng_layer, self).__init__()
        self.a = nn.Parameter(torch.ones(2, 2) * 2)

    def forward(self, x):
        y = x * self.a
        return y


model = cheng_layer()  # replaces the Sequential model above for this scratch test
loss_fn = torch.nn.MSELoss(reduction='sum')

# Use the optim package to define an Optimizer that will update the weights of
# the model for us. Here we will use Adam; the optim package contains many other
# optimization algorithms. The first argument to the Adam constructor tells the
# optimizer which Tensors it should update.
learning_rate = 1
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
for k, v in model.named_parameters():
    print(k)
print(type(model.parameters()))
for t in range(500):
    # Forward pass: compute predicted y by passing x to the model.
    y_pred = model(x)

    # Compute and print loss.
    print(y_pred)
    print(y)
    loss = y_pred.sum()
    print(t, loss.item())
    a = input('')
    # Before the backward pass, use the optimizer object to zero all of the
    # gradients for the variables it will update (which are the learnable
    # weights of the model). This is because by default, gradients are
    # accumulated in buffers (i.e., not overwritten) whenever .backward()
    # is called. Check out docs of torch.autograd.backward for more details.
    optimizer.zero_grad()

    # Backward pass: compute gradient of the loss with respect to model
    # parameters
    loss.backward()

    # Calling the step function on an Optimizer makes an update to its
    # parameters
    optimizer.step()
--------------------------------------------------------------------------------
/lib/train.py:
--------------------------------------------------------------------------------

import os
import time
import sys

import torch.optim as optim
from torchnet.logger import VisdomLogger, VisdomPlotLogger
import torch.nn as nn
import torch.nn.functional as F
import torch
import visdom
import numpy as np


class AverageMeter(object):
    """
    Computes and stores the average and current value
    Copied from: https://github.com/pytorch/examples/blob/master/imagenet/main.py
    """

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def compute_loss(output, target):
    # Margin loss on the capsule output lengths (m+ = 0.9, m- = 0.1,
    # lambda = 0.5), averaged over all pixels.
    class_loss = (target * F.relu(0.9 - output) + 0.5 * (1 - target) * F.relu(output - 0.1)).mean()
    return class_loss


def compute_acc(predict, target, vis):
    # Despite its name this returns the pixel-wise *error* rate (it is printed
    # under 'Error' below). Predictions in (0.3, 0.7) are left unthresholded
    # and therefore always count as mismatches.
    predict[predict >= 0.7] = 1
    predict[predict <= 0.3] = 0
    vis.image(predict[0].cpu().float().numpy(), env='show', win='predict', opts=dict(title='predict'))
    predict = predict != target
    acc = torch.sum(predict).float() / torch.numel(target.data)
    return acc


def train_epoch(model, loader, optimizer, epoch, n_epochs):
    batch_time = AverageMeter()
    losses = AverageMeter()
    accs = AverageMeter()
    model.train()
    end = time.time()
    vis = visdom.Visdom()
    for batch_index, (data, target) in enumerate(loader):
        data = data.cuda()
        target[target >= 1] = 1  # collapse the part classes to a binary foreground mask
        target = target.float()
        target = target.cuda()
        output = model(data)

        loss = compute_loss(output, target)

        batch_size = target.size(0)
        losses.update(loss.data, batch_size)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        vis.image(data[0].cpu().numpy(), env='show', win='img', opts=dict(title='img'))
        vis.image(target[0].cpu().float().numpy(), env='show', win='target', opts=dict(title='target'))
        acc = compute_acc(output.detach(), target, vis)
        accs.update(acc)
        batch_time.update(time.time() - end)
        end = time.time()
        res = '\t'.join([
            'Epoch: [%d/%d]' % (epoch + 1, n_epochs),
            'Batch: [%d/%d]' % (batch_index, len(loader)),
            'Time %.3f (%.3f)' % (batch_time.val, batch_time.avg),
            'Loss %.4f (%.4f)' % (losses.val, losses.avg),
            'Error %.4f (%.4f)' % (accs.val, accs.avg),
        ])
        print(res)
    return batch_time.avg, losses.avg, accs.avg


def test_epoch(model, loader, epoch, n_epochs):
    batch_time = AverageMeter()
    losses = AverageMeter()
    accs = AverageMeter()

    # Model on eval mode
    model.eval()
    vis = visdom.Visdom()
    with torch.no_grad():
        end = time.time()
        for batch_index, (data, target) in enumerate(loader):
            data = data.cuda()
            target[target >= 1] = 1
            target = target.float()
            target = target.cuda()
            output = model(data)
            loss = compute_loss(output, target)

            batch_size = target.size(0)
            losses.update(loss.data, batch_size)
            vis.image(data[0].cpu().numpy(), env='show', win='img', opts=dict(title='img'))
            vis.image(target[0].cpu().float().numpy(), env='show', win='target', opts=dict(title='target'))
            acc = compute_acc(output, target, vis)
            accs.update(acc)
            batch_time.update(time.time() - end)
            end = time.time()
            res = '\t'.join([
                'Test',
                'Epoch: [%d/%d]' % (epoch + 1, n_epochs),
                'Batch: [%d/%d]' % (batch_index, len(loader)),
                'Time %.3f (%.3f)' % (batch_time.val, batch_time.avg),
                'Loss %.4f (%.4f)' % (losses.val, losses.avg),
                'Error %.4f (%.4f)' % (accs.val, accs.avg),
            ])
            print(res)
    return batch_time.avg, losses.avg, accs.avg


def train(args, model, train_loader, decreasing_lr, wd=0.0001, momentum=0.9):
    # wd and momentum are currently unused; Adam is configured below with its
    # default hyperparameters.
    if args.seed is not None:
        torch.manual_seed(args.seed)

    vis = visdom.Visdom()
    vis.close(env=args.model)
    test_acc_logger = VisdomPlotLogger('line', env=args.model, opts={'title': 'Test Accuracy'})
    test_loss_logger = VisdomPlotLogger('line', env=args.model, opts={'title': 'Test Loss'})
    train_acc_logger = VisdomPlotLogger('line', env=args.model, opts={'title': 'Train Accuracy'})
    train_loss_logger = VisdomPlotLogger('line', env=args.model, opts={'title': 'Train Loss'})
    lr_logger = VisdomPlotLogger('line', env=args.model, opts={'title': 'Learning Rate'})

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=decreasing_lr, gamma=0.1)
    best_train_loss = float('inf')
    for epoch in range(args.nepoch):
        _, train_loss, train_acc = train_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            epoch=epoch,
            n_epochs=args.nepoch,
        )
        _, test_loss, test_acc = test_epoch(
            loader=train_loader,
            model=model,
            epoch=epoch,
            n_epochs=args.nepoch,
        )
        scheduler.step()  # step the schedule once per epoch, after the optimizer updates
        if best_train_loss > train_loss:
            best_train_loss = train_loss
            print('best_loss: ' + str(best_train_loss))
            torch.save(model.state_dict(), args.params_name)
        print(train_loss)
        train_loss_logger.log(epoch, float(train_loss))
        train_acc_logger.log(epoch, 1 - float(train_acc))  # compute_acc returns an error rate
        test_acc_logger.log(epoch, 1 - float(test_acc))
        test_loss_logger.log(epoch, float(test_loss))
        lr_logger.log(epoch, optimizer.param_groups[0]['lr'])
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------

from models.unet import UNet
from .capsule_net import SegCaps
--------------------------------------------------------------------------------
/models/capsule_layer.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
import models.nn_ as nn_
import torch.optim as optim


class CapsuleLayer(nn.Module):
    def __init__(self, t_0, z_0, op, k, s, t_1, z_1, routing):
        super().__init__()
        self.t_1 = t_1
        self.z_1 = z_1
        self.op = op
        self.k = k
        self.s = s
        self.routing = routing
        self.convs = nn.ModuleList()
        self.t_0 = t_0
        for _ in range(t_0):
            if self.op == 'conv':
                # padding=2 gives 'same' output for the k=5 kernels used here
                self.convs.append(nn.Conv2d(z_0, self.t_1 * self.z_1, self.k, self.s, padding=2, bias=False))
            else:
                # padding/output_padding assume k=5, s=2 (exact 2x upsampling)
                self.convs.append(nn.ConvTranspose2d(z_0, self.t_1 * self.z_1, self.k, self.s, padding=2, output_padding=1))

    def forward(self, u):  # input [N, CAPS, C, H, W]
        if u.shape[1] != self.t_0:
            raise ValueError("Input capsule count does not match t_0")
        op = self.op
        k = self.k
        s = self.s
        t_1 = self.t_1
        z_1 = self.z_1
        routing = self.routing
        N = u.shape[0]
        H_1 = u.shape[3]
        W_1 = u.shape[4]
        t_0 = self.t_0

        u_t_list = [u_t.squeeze(1) for u_t in u.split(1, 1)]  # split out the individual capsule types

        u_hat_t_list = []

        for i, u_t in zip(range(self.t_0), u_t_list):  # u_t: [N, C, H, W]
            if op == "conv":
                u_hat_t = self.convs[i](u_t)  # convolutional capsule
            elif op == "deconv":
                u_hat_t = self.convs[i](u_t)  # u_hat_t: [N, t_1*z_1, H, W]
            else:
                raise ValueError("Wrong type of operation for capsule")
            H_1 = u_hat_t.shape[2]
            W_1 = u_hat_t.shape[3]
            u_hat_t = u_hat_t.reshape(N, t_1, z_1, H_1, W_1).transpose_(1, 3).transpose_(2, 4)
            u_hat_t_list.append(u_hat_t)  # [N, H_1, W_1, t_1, z_1]
        v = self.update_routing(u_hat_t_list, k, N, H_1, W_1, t_0, t_1, routing)
        return v

    def update_routing(self, u_hat_t_list, k, N, H_1, W_1, t_0, t_1, routing):
        one_kernel = torch.ones(1, t_1, k, k).cuda()  # fixed kernel, not learned
        b = torch.zeros(N, H_1, W_1, t_0, t_1).cuda()  # routing logits, not learned
        b_t_list = [b_t.squeeze(3) for b_t in b.split(1, 3)]
        u_hat_t_list_sg = []
        for u_hat_t in u_hat_t_list:
            u_hat_t_sg = u_hat_t.detach()
            u_hat_t_list_sg.append(u_hat_t_sg)

        for d in range(routing):
            # Gradients only flow through the predictions on the last iteration.
            if d < routing - 1:
                u_hat_t_list_ = u_hat_t_list_sg
            else:
                u_hat_t_list_ = u_hat_t_list

            r_t_mul_u_hat_t_list = []
            for b_t, u_hat_t in zip(b_t_list, u_hat_t_list_):
                # routing softmax over a local window (N, H_1, W_1, t_1)
                b_t.transpose_(1, 3).transpose_(2, 3)  # [N, t_1, H_1, W_1]
                b_t_max = torch.nn.functional.max_pool2d(b_t, k, 1, padding=2)  # padding=2 assumes k=5
                b_t_max = b_t_max.max(1, True)[0]
                c_t = torch.exp(b_t - b_t_max)
                sum_c_t = nn_.conv2d_same(c_t, one_kernel, stride=(1, 1))  # [..., 1]
                r_t = c_t / sum_c_t  # [N, t_1, H_1, W_1]
                r_t = r_t.transpose(1, 3).transpose(1, 2)  # [N, H_1, W_1, t_1]
                r_t = r_t.unsqueeze(4)  # [N, H_1, W_1, t_1, 1]
                r_t_mul_u_hat_t_list.append(r_t * u_hat_t)  # [N, H_1, W_1, t_1, z_1]
            p = sum(r_t_mul_u_hat_t_list)  # [N, H_1, W_1, t_1, z_1]
            v = squash(p)
            if d < routing - 1:
                b_t_list_ = []
                for b_t, u_hat_t in zip(b_t_list, u_hat_t_list_):
                    # b_t : [N, t_1, H_1, W_1]
                    # u_hat_t : [N, H_1, W_1, t_1, z_1]
                    # v : [N, H_1, W_1, t_1, z_1]
                    b_t.transpose_(1, 3).transpose_(2, 1)  # back to [N, H_1, W_1, t_1]
                    b_t_list_.append(b_t + (u_hat_t * v).sum(4))
                b_t_list = b_t_list_  # carry the updated logits into the next iteration
        v.transpose_(1, 3).transpose_(2, 4)
        return v

    def squash(self, p):
        p_norm_sq = (p * p).sum(-1, True)
        p_norm = (p_norm_sq + 1e-9).sqrt()
        v = p_norm_sq / (1. + p_norm_sq) * p / p_norm
        return v


def update_routing(u_hat_t_list, k, N, H_1, W_1, t_0, t_1, routing):
    one_kernel = torch.ones(1, t_1, k, k).cuda()  # fixed kernel, not learned
    b = torch.zeros(N, H_1, W_1, t_0, t_1).cuda()  # routing logits, not learned
    b_t_list = [b_t.squeeze(3) for b_t in b.split(1, 3)]
    u_hat_t_list_sg = []
    for u_hat_t in u_hat_t_list:
        u_hat_t_sg = u_hat_t.clone()
        u_hat_t_sg.detach_()
        u_hat_t_list_sg.append(u_hat_t_sg)

    for d in range(routing):
        if d < routing - 1:
            u_hat_t_list_ = u_hat_t_list_sg
        else:
            u_hat_t_list_ = u_hat_t_list

        r_t_mul_u_hat_t_list = []
        for b_t, u_hat_t in zip(b_t_list, u_hat_t_list_):
            # routing softmax (N, H_1, W_1, t_1)
            b_t.transpose_(1, 3).transpose_(2, 3)
            b_t_max = nn_.max_pool2d_same(b_t, k, 1)
            b_t_max = b_t_max.max(1, True)[0]
            c_t = torch.exp(b_t - b_t_max)
            sum_c_t = nn_.conv2d_same(c_t, one_kernel, stride=(1, 1))  # [..., 1]
            r_t = c_t / sum_c_t  # [N, t_1, H_1, W_1]
            r_t = r_t.transpose(1, 3).transpose(1, 2)  # [N, H_1, W_1, t_1]
            r_t = r_t.unsqueeze(4)  # [N, H_1, W_1, t_1, 1]
            r_t_mul_u_hat_t_list.append(r_t * u_hat_t)  # [N, H_1, W_1, t_1, z_1]

        p = sum(r_t_mul_u_hat_t_list)  # [N, H_1, W_1, t_1, z_1]
        v = squash(p)
        if d < routing - 1:
            b_t_list_ = []
            for b_t, u_hat_t in zip(b_t_list, u_hat_t_list_):
                # b_t : [N, t_1, H_1, W_1]
                # u_hat_t : [N, H_1, W_1, t_1, z_1]
                # v : [N, H_1, W_1, t_1, z_1]
                b_t = b_t.transpose(1, 3).transpose(1, 2)  # [N, H_1, W_1, t_1]
                b_t_list_.append(b_t + (u_hat_t * v).sum(4))
            b_t_list = b_t_list_
    v.transpose_(1, 3).transpose_(2, 4)
    # print(v.grad)
    return v


def squash(p):
    p_norm_sq = (p * p).sum(-1, True)
    p_norm = (p_norm_sq + 1e-9).sqrt()
    v = p_norm_sq / (1. + p_norm_sq) * p / p_norm
    return v


def test():
    m = CapsuleLayer(1, 16, "conv", k=5, s=1, t_1=2, z_1=16, routing=1)
    # m = cheng_layer()
    m = m.cuda()
    b = input('s')
    a = torch.randn(10, 1, 16, int(b), int(b))
    # a = torch.randn(2, 2)
    a = a.cuda()
    optimizer = optim.Adam(m.parameters(), lr=1)
    for k, v in m.named_parameters():
        print(k)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20], gamma=0.1)
    b = m(a)
    c = b.mean()
    for k in m.parameters():
        print(k)
    print(b.shape)
    print(c)
    c.backward()
    optimizer.step()
    b = m(a)
    c = b.mean()
    print(c)
    print(a.grad)
    print(b.shape)
    # print(a[1, :, :, 1, 1])

# test()


def test1():
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    a = []

    b = torch.ones([1, 10, 10, 2, 3]).cuda()
    print(b)
    a.append(b)
    c = update_routing(a, 2, 1, 10, 10, 1, 2, 3)
    print(c.cpu().numpy())

# test1()
--------------------------------------------------------------------------------
/models/capsule_net.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn
from models.capsule_layer import CapsuleLayer
import models.nn_


class SegCaps(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_1 = nn.Sequential(
            nn.Conv2d(3, 16, 5, 1, padding=2, bias=False),
        )
        self.step_1 = nn.Sequential(  # 1/2
            CapsuleLayer(1, 16, "conv", k=5, s=2, t_1=2, z_1=16, routing=1),
            CapsuleLayer(2, 16, "conv", k=5, s=1, t_1=4, z_1=16, routing=3),
        )
        self.step_2 = nn.Sequential(  # 1/4
            CapsuleLayer(4, 16, "conv", k=5, s=2, t_1=4, z_1=32, routing=3),
            CapsuleLayer(4, 32, "conv", k=5, s=1, t_1=8, z_1=32, routing=3)
        )
        self.step_3 = nn.Sequential(  # 1/8
            CapsuleLayer(8, 32, "conv", k=5, s=2, t_1=8, z_1=64, routing=3),
            CapsuleLayer(8, 64, "conv", k=5, s=1, t_1=8, z_1=32, routing=3)
        )
        self.step_4 = CapsuleLayer(8, 32, "deconv", k=5, s=2, t_1=8, z_1=32, routing=3)
        self.step_5 = CapsuleLayer(16, 32, "conv", k=5, s=1, t_1=4, z_1=32, routing=3)
        self.step_6 = CapsuleLayer(4, 32, "deconv", k=5, s=2, t_1=4, z_1=16, routing=3)
        self.step_7 = CapsuleLayer(8, 16, "conv", k=5, s=1, t_1=4, z_1=16, routing=3)
        self.step_8 = CapsuleLayer(4, 16, "deconv", k=5, s=2, t_1=2, z_1=16, routing=3)
        self.step_10 = CapsuleLayer(3, 16, "conv", k=5, s=1, t_1=1, z_1=16, routing=3)
        self.conv_2 = nn.Sequential(
            nn.Conv2d(16, 1, 5, 1, padding=2),
        )

    def forward(self, x):
        x = self.conv_1(x)
        x.unsqueeze_(1)

        skip_1 = x  # [N,1,16,H,W]
        x = self.step_1(x)

        skip_2 = x  # [N,4,16,H/2,W/2]
        x = self.step_2(x)

        skip_3 = x  # [N,8,32,H/4,W/4]
        x = self.step_3(x)  # [N,8,32,H/8,W/8]

        x = self.step_4(x)  # [N,8,32,H/4,W/4]
        x = torch.cat((x, skip_3), 1)  # [N,16,32,H/4,W/4]
        x = self.step_5(x)  # [N,4,32,H/4,W/4]

        x = self.step_6(x)  # [N,4,16,H/2,W/2]
        x = torch.cat((x, skip_2), 1)  # [N,8,16,H/2,W/2]
        x = self.step_7(x)  # [N,4,16,H/2,W/2]
        x = self.step_8(x)  # [N,2,16,H,W]

        x = torch.cat((x, skip_1), 1)
        x = self.step_10(x)
        x.squeeze_(1)
        v_lens = self.compute_vector_length(x)
        v_lens = v_lens.squeeze(1)
        return v_lens

    def compute_vector_length(self, x):
        # Length of each capsule vector along the channel dimension; the
        # epsilon keeps sqrt differentiable at zero.
        out = (x.pow(2)).sum(1, True) + 1e-9
        out = out.sqrt()
        return out


def test():
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = '6'
    model = SegCaps()
    model = model.cuda()
    print(model)
    c = input('s')
    a = torch.ones(1, 3, 256, 256)
    a = a.cuda()
    b = model(a)
    print(b)
    c = b.sum()
    print(c)
    c.backward()
    for k, v in model.named_parameters():
        a = input('s')
        print(v.grad, k)
    # from tensorboardX import SummaryWriter
    # with SummaryWriter(comment='LeNet') as w:
    #     w.add_graph(model, a)
    print(b.shape)
    print(b)

# test()


def compute_vector_length(x):
    out = (x.pow(2)).sum(1, True) + 1e-9
    out.sqrt_()
    return out
--------------------------------------------------------------------------------
/models/nn_.py:
--------------------------------------------------------------------------------

import torch.utils.data
from torch.nn import functional as F

import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.functional import pad
from torch.nn.modules import Module
from torch.nn import ConvTranspose2d
from torch.nn.modules.utils import _single, _pair, _triple
import torch.nn as nn


class _ConvNd(Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNd, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if transposed:
            self.weight = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        return s.format(**self.__dict__)


class Conv2d(_ConvNd):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)

    # forward changed to use TensorFlow-style 'same' padding
    def forward(self, input):
        return conv2d_same(input, self.weight, self.bias, self.stride,
                           self.dilation, self.groups)


class ConvTranspose2d(nn.ConvTranspose2d):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=False, dilation=1):
        super(ConvTranspose2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias, dilation)

    def forward(self, input):
        # Choose padding/output_padding so the output is exactly
        # input_size * stride ('same' deconvolution).
        input_size = input.size(2)
        output_size = input_size * self.stride[0]
        pad_l, pad_r = get_same(input_size, self.kernel_size[0], self.stride[0], dilation=1)
        # print(pad_l, pad_r)
        self.padding = max(pad_l, pad_r)
        input_size = (input_size - 1) * self.stride[0] + self.kernel_size[0] - 2 * self.padding
        # print(input_size)
        output_padding = output_size - input_size
        # print(output_padding)
        return F.conv_transpose2d(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation)


def conv2d_same(input, weight, bias=None, stride=[1, 1], dilation=(1, 1), groups=1):
    # TensorFlow-style 'same' convolution: rows and columns are padded
    # independently, with an extra pixel on the bottom/right when the total
    # padding is odd.
    input_rows = input.size(2)
    filter_rows = weight.size(2)
    out_rows = (input_rows + stride[0] - 1) // stride[0]
    padding_rows = max(0, (out_rows - 1) * stride[0] +
                       (filter_rows - 1) * dilation[0] + 1 - input_rows)
    rows_odd = (padding_rows % 2 != 0)
    input_cols = input.size(3)
    filter_cols = weight.size(3)
    out_cols = (input_cols + stride[1] - 1) // stride[1]
    padding_cols = max(0, (out_cols - 1) * stride[1] +
                       (filter_cols - 1) * dilation[1] + 1 - input_cols)
    cols_odd = (padding_cols % 2 != 0)
    if rows_odd or cols_odd:
        input = pad(input, [0, int(cols_odd), 0, int(rows_odd)])
    return F.conv2d(input, weight, bias, stride,
                    padding=(padding_rows // 2, padding_cols // 2),
                    dilation=dilation, groups=groups)


def max_pool2d_same(input, kernel_size, stride=1, dilation=1, ceil_mode=False, return_indices=False):
    # 'same' max pooling, mirroring the row/column handling of conv2d_same.
    input_rows = input.size(2)
    out_rows = (input_rows + stride - 1) // stride
    padding_rows = max(0, (out_rows - 1) * stride +
                       (kernel_size - 1) * dilation + 1 - input_rows)
    rows_odd = (padding_rows % 2 != 0)
    input_cols = input.size(3)
    out_cols = (input_cols + stride - 1) // stride
    padding_cols = max(0, (out_cols - 1) * stride +
                       (kernel_size - 1) * dilation + 1 - input_cols)
    cols_odd = (padding_cols % 2 != 0)
    if rows_odd or cols_odd:
        input = pad(input, [0, int(cols_odd), 0, int(rows_odd)])
    return F.max_pool2d(input, kernel_size=kernel_size, stride=stride,
                        padding=(padding_rows // 2, padding_cols // 2),
                        dilation=dilation, ceil_mode=ceil_mode,
                        return_indices=return_indices)


def get_same(size, kernel, stride, dilation):
    out_size = (size + stride - 1) // stride
    padding = max(0, (out_size - 1) * stride +
                  (kernel - 1) * dilation + 1 - size)
    size_odd = (padding % 2 != 0)
    pad_l = padding // 2
    pad_r = padding // 2
    if size_odd:
        pad_l += 1
    return pad_l, pad_r
--------------------------------------------------------------------------------
/models/shufflenet.py:
--------------------------------------------------------------------------------

# encoding: utf-8
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: models/blocks.py (idenUnit, poolUnit) is referenced here but is not
# included in this repository snapshot.
from models.blocks import idenUnit, poolUnit


class ShuffleNet(nn.Module):
    def __init__(self, output_size, scale_factor=1, g=8):
        super(ShuffleNet, self).__init__()
        self.g = g
        # output channels of stage 2 for each group count g, at scale 1
        self.cs = {1: 144, 2: 200, 3: 240, 4: 272, 8: 384}

        # compute output channels for stages
        c2 = self.cs[self.g]
        c2 = int(scale_factor * c2)
        c3, c4 = 2 * c2, 4 * c2

        # first conv layer & last fc layer
        self.conv1 = nn.Conv2d(3, 24, kernel_size=3, padding=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(24)

        self.fc = nn.Linear(c4, output_size)

        # build stages
        self.stage2 = self.build_stage(24, c2, repeat_time=3, first_group=False, downsample=False)
        self.stage3 = self.build_stage(c2, c3, repeat_time=7)
        self.stage4 = self.build_stage(c3, c4, repeat_time=3)

        # weights init
        self.weights_init()

    def build_stage(self, input_channel, output_channel, repeat_time, first_group=True, downsample=True):
        stage = [poolUnit(input_channel, output_channel, self.g, first_group=first_group, downsample=downsample)]
        for i in range(repeat_time):
            stage.append(idenUnit(output_channel, self.g))
        return nn.Sequential(*stage)

    def weights_init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def forward(self, inputs):
        # first conv layer
        x = F.relu(self.bn1(self.conv1(inputs)))
        # x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        # assert x.shape[1:] == torch.Size([24, 56, 56])

        # bottlenecks
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        # print(x.shape)

        # global pooling and fc (in place of conv 1x1 in paper)
        x = F.adaptive_avg_pool2d(x, 1)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)

        return x


def test():
    # 'count' (measure_model) is an external profiling helper that is not
    # part of this repository.
    from count import measure_model
    import numpy as np

    x = np.random.randn(10, 3, 32, 32).astype(np.float32)
    x = torch.from_numpy(x)

    net = ShuffleNet(10, g=1, scale_factor=0.5)
    f, c = measure_model(net, 32, 32)
    print("model size %.4f M, ops %.4f M" % (c / 1e6, f / 1e6))
--------------------------------------------------------------------------------
/models/unet.py:
--------------------------------------------------------------------------------

import torch
import torch.nn as nn


def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True)
    )


class UNet(nn.Module):
    """U-Net-style encoder/decoder. Note that this variant has no pooling or
    upsampling: every stage runs at full resolution, so the skip connections
    can be concatenated directly."""

    def __init__(self, n_class):
        super().__init__()

        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = conv1
        conv2 = self.dconv_down2(x)
        x = conv2
        conv3 = self.dconv_down3(x)
        x = conv3
        x = self.dconv_down4(x)

        x = torch.cat([x, conv3], dim=1)
        x = self.dconv_up3(x)

        x = torch.cat([x, conv2], dim=1)
        x = self.dconv_up2(x)

        x = torch.cat([x, conv1], dim=1)
        x = self.dconv_up1(x)

        out = self.conv_last(x)
        return out
--------------------------------------------------------------------------------
/my_train.py:
--------------------------------------------------------------------------------

import argparse
from lib.init import init
from lib.train import train


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Little Ma's train")
    parser.add_argument('--batch_size_train', type=int, default=20, help='input batch size for training (default: 20)')
    parser.add_argument('--batch_size_test', type=int, default=60, help='input batch size for testing (default: 60)')
    parser.add_argument('--lr', type=float, default=0.01, help='learning rate (default: 0.01)')
    parser.add_argument('--gpu', default='3', help='index of gpus to use')
    parser.add_argument('--dlr', default='10,25', help='decreasing strategy')
    parser.add_argument('--model', default='segcaps-train', help='which model (default: segcaps-train)')
    parser.add_argument('--data_root', default='./', help='data_root (default: ./)')
    parser.add_argument('--nepoch', type=int, default=50, help='epochs (default: 50)')
    parser.add_argument('--seed', type=int, default=10, help='seed (default: 10)')
    parser.add_argument('--pretrain', type=int, default=0, help='pretrain (default: 0)')
    parser.add_argument('--data_name', default='train', help='data_name (default: train)')
    parser.add_argument('--params_name', default='segcaps.pkl', help='params_name (default: segcaps.pkl)')
    parser.add_argument('--load_params_name', default='segcaps.pkl', help='params_name (default: segcaps.pkl)')
    args = parser.parse_args()
    train_loader, model, decreasing_lr = init(args)
    train(args, model, train_loader, decreasing_lr, wd=0.0001, momentum=0.9)
    print('Done!')
--------------------------------------------------------------------------------