├── BSD500
│   ├── 101027.jpg
│   ├── 108004.jpg
│   └── 196062.jpg
├── BBC
│   ├── ref
│   │   └── frame426.jpg
│   └── test
│       ├── frame426.jpg
│       ├── frame431.jpg
│       ├── frame437.jpg
│       ├── frame443.jpg
│       ├── frame449.jpg
│       ├── frame455.jpg
│       ├── frame461.jpg
│       ├── frame467.jpg
│       ├── frame473.jpg
│       ├── frame479.jpg
│       └── frame485.jpg
├── PASCAL_VOC_2012
│   ├── 2007_001774.jpg
│   ├── 2007_005915.jpg
│   ├── 2007_008670.jpg
│   ├── 2008_001439.jpg
│   ├── 2008_003709.jpg
│   ├── 2009_000421.jpg
│   ├── 2007_001774_scribble.png
│   ├── 2007_005915_scribble.png
│   ├── 2007_008670_scribble.png
│   ├── 2008_001439_scribble.png
│   ├── 2008_003709_scribble.png
│   └── 2009_000421_scribble.png
├── LICENSE
├── README.md
├── demo.py
└── demo_ref.py

/BSD500/101027.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BSD500/101027.jpg
--------------------------------------------------------------------------------
/BSD500/108004.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BSD500/108004.jpg
--------------------------------------------------------------------------------
/BSD500/196062.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BSD500/196062.jpg
--------------------------------------------------------------------------------
/BBC/ref/frame426.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/ref/frame426.jpg
--------------------------------------------------------------------------------
/BBC/test/frame426.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/test/frame426.jpg
--------------------------------------------------------------------------------
/BBC/test/frame431.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/test/frame431.jpg
--------------------------------------------------------------------------------
/BBC/test/frame437.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/test/frame437.jpg
--------------------------------------------------------------------------------
/BBC/test/frame443.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/test/frame443.jpg
--------------------------------------------------------------------------------
/BBC/test/frame449.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/test/frame449.jpg
--------------------------------------------------------------------------------
/BBC/test/frame455.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/test/frame455.jpg
--------------------------------------------------------------------------------
/BBC/test/frame461.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/test/frame461.jpg
--------------------------------------------------------------------------------
/BBC/test/frame467.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/test/frame467.jpg
--------------------------------------------------------------------------------
/BBC/test/frame473.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/test/frame473.jpg
--------------------------------------------------------------------------------
/BBC/test/frame479.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/test/frame479.jpg
--------------------------------------------------------------------------------
/BBC/test/frame485.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/BBC/test/frame485.jpg
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2007_001774.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2007_001774.jpg
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2007_005915.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2007_005915.jpg
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2007_008670.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2007_008670.jpg
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2008_001439.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2008_001439.jpg
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2008_003709.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2008_003709.jpg
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2009_000421.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2009_000421.jpg
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2007_001774_scribble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2007_001774_scribble.png
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2007_005915_scribble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2007_005915_scribble.png
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2007_008670_scribble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2007_008670_scribble.png
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2008_001439_scribble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2008_001439_scribble.png
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2008_003709_scribble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2008_003709_scribble.png
--------------------------------------------------------------------------------
/PASCAL_VOC_2012/2009_000421_scribble.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kanezaki/pytorch-unsupervised-segmentation-tip/HEAD/PASCAL_VOC_2012/2009_000421_scribble.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Asako Kanezaki and Wonjik Kim

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Unsupervised Learning of Image Segmentation Based on Differentiable Feature Clustering

This PyTorch code generates segmentation labels of an input image.

![Unsupervised Image Segmentation with Scribbles](https://kanezaki.github.io/media/unsupervised_image_segmentation_with_scribbles.png)

Wonjik Kim\*, Asako Kanezaki\*, and Masayuki Tanaka.
**Unsupervised Learning of Image Segmentation Based on Differentiable Feature Clustering.**
*IEEE Transactions on Image Processing*, accepted, 2020.
([arXiv](https://arxiv.org/abs/2007.09990))

\*W. Kim and A. Kanezaki contributed equally to this work.

## What is new?

This is an extension of our [previous work](https://github.com/kanezaki/pytorch-unsupervised-segmentation).

- Better performance with a spatial continuity loss
- Option of using scribbles as user input
- Option of using reference image(s)

## Requirements

PyTorch, OpenCV (cv2), tqdm

## Getting started

### Vanilla

    $ python demo.py --input ./BSD500/101027.jpg

### Vanilla + scribbles

    $ python demo.py --input ./PASCAL_VOC_2012/2007_001774.jpg --scribble
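The scribble mask read by `demo.py` is a single-channel PNG stored next to the input image as `<input basename>_scribble.png`. Pixels with value 255 are treated as unscribbled, and every other value is used directly as a class index; `--minLabels` is set to the number of distinct scribble values. Below is a minimal sketch of creating such a mask; the file names and line coordinates are hypothetical, for illustration only:

    import cv2
    import numpy as np

    im = cv2.imread('./my_image.jpg')                  # hypothetical input image
    mask = np.full(im.shape[:2], 255, dtype=np.uint8)  # 255 = unscribbled
    cv2.line(mask, (30, 40), (120, 40), color=0, thickness=3)   # scribble for class 0
    cv2.line(mask, (60, 150), (60, 220), color=1, thickness=3)  # scribble for class 1
    cv2.imwrite('./my_image_scribble.png', mask)       # name expected by demo.py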
### Vanilla + reference image(s)

    $ python demo_ref.py --input ./BBC/

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
#from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import cv2
import sys
import numpy as np
import torch.nn.init
import random

use_cuda = torch.cuda.is_available()

parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation')
parser.add_argument('--scribble', action='store_true', default=False,
                    help='use scribbles')
parser.add_argument('--nChannel', metavar='N', default=100, type=int,
                    help='number of channels')
parser.add_argument('--maxIter', metavar='T', default=1000, type=int,
                    help='number of maximum iterations')
parser.add_argument('--minLabels', metavar='minL', default=3, type=int,
                    help='minimum number of labels')
parser.add_argument('--lr', metavar='LR', default=0.1, type=float,
                    help='learning rate')
parser.add_argument('--nConv', metavar='M', default=2, type=int,
                    help='number of convolutional layers')
parser.add_argument('--visualize', metavar='1 or 0', default=1, type=int,
                    help='visualization flag')
parser.add_argument('--input', metavar='FILENAME',
                    help='input image file name', required=True)
parser.add_argument('--stepsize_sim', metavar='SIM', default=1, type=float,
                    help='step size for similarity loss', required=False)
parser.add_argument('--stepsize_con', metavar='CON', default=1, type=float,
                    help='step size for continuity loss')
parser.add_argument('--stepsize_scr', metavar='SCR', default=0.5, type=float,
                    help='step size for scribble loss')
args = parser.parse_args()

# CNN model
class MyNet(nn.Module):
    def __init__(self, input_dim):
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, args.nChannel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(args.nChannel)
        self.conv2 = nn.ModuleList()
        self.bn2 = nn.ModuleList()
        for i in range(args.nConv-1):
            self.conv2.append(nn.Conv2d(args.nChannel, args.nChannel, kernel_size=3, stride=1, padding=1))
            self.bn2.append(nn.BatchNorm2d(args.nChannel))
        self.conv3 = nn.Conv2d(args.nChannel, args.nChannel, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(args.nChannel)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn1(x)
        for i in range(args.nConv-1):
            x = self.conv2[i](x)
            x = F.relu(x)
            x = self.bn2[i](x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x

# load image
im = cv2.imread(args.input)
data = torch.from_numpy(np.array([im.transpose((2, 0, 1)).astype('float32')/255.]))
if use_cuda:
    data = data.cuda()
data = Variable(data)

# load scribble
if args.scribble:
    mask = cv2.imread(args.input.replace('.'+args.input.split('.')[-1], '_scribble.png'), -1)
    mask = mask.reshape(-1)
    mask_inds = np.unique(mask)
    mask_inds = np.delete(mask_inds, np.argwhere(mask_inds == 255))
    inds_sim = torch.from_numpy(np.where(mask == 255)[0])  # unscribbled pixels
    inds_scr = torch.from_numpy(np.where(mask != 255)[0])  # scribbled pixels
    target_scr = torch.from_numpy(mask.astype(np.int64))   # np.int was removed from NumPy; use a concrete dtype
    if use_cuda:
        inds_sim = inds_sim.cuda()
        inds_scr = inds_scr.cuda()
        target_scr = target_scr.cuda()
    target_scr = Variable(target_scr)
    # set minLabels
    args.minLabels = len(mask_inds)

# train
model = MyNet(data.size(1))
if use_cuda:
    model.cuda()
model.train()

# similarity loss definition
loss_fn = torch.nn.CrossEntropyLoss()

# scribble loss definition
loss_fn_scr = torch.nn.CrossEntropyLoss()

# continuity loss definition (size_average is deprecated; use reduction='mean')
loss_hpy = torch.nn.L1Loss(reduction='mean')
loss_hpz = torch.nn.L1Loss(reduction='mean')

HPy_target = torch.zeros(im.shape[0]-1, im.shape[1], args.nChannel)
HPz_target = torch.zeros(im.shape[0], im.shape[1]-1, args.nChannel)
if use_cuda:
    HPy_target = HPy_target.cuda()
    HPz_target = HPz_target.cuda()

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
label_colours = np.random.randint(255, size=(args.nChannel, 3))  # one random colour per possible label

for batch_idx in range(args.maxIter):
    # forwarding
    optimizer.zero_grad()
    output = model(data)[0]
    output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel)

    outputHP = output.reshape((im.shape[0], im.shape[1], args.nChannel))
    # differences between vertically / horizontally adjacent responses,
    # pulled towards zero by the continuity loss
    HPy = outputHP[1:, :, :] - outputHP[0:-1, :, :]
    HPz = outputHP[:, 1:, :] - outputHP[:, 0:-1, :]
    lhpy = loss_hpy(HPy, HPy_target)
    lhpz = loss_hpz(HPz, HPz_target)

    # pseudo-targets: per-pixel argmax of the current responses
    ignore, target = torch.max(output, 1)
    im_target = target.data.cpu().numpy()
    nLabels = len(np.unique(im_target))
    if args.visualize:
        im_target_rgb = np.array([label_colours[c % args.nChannel] for c in im_target])
        im_target_rgb = im_target_rgb.reshape(im.shape).astype(np.uint8)
        cv2.imshow("output", im_target_rgb)
        cv2.waitKey(10)
    # loss
    if args.scribble:
        loss = args.stepsize_sim * loss_fn(output[inds_sim], target[inds_sim]) \
             + args.stepsize_scr * loss_fn_scr(output[inds_scr], target_scr[inds_scr]) \
             + args.stepsize_con * (lhpy + lhpz)
    else:
        loss = args.stepsize_sim * loss_fn(output, target) + args.stepsize_con * (lhpy + lhpz)

    loss.backward()
    optimizer.step()

    print(batch_idx, '/', args.maxIter, '|', ' label num :', nLabels, ' | loss :', loss.item())

    if nLabels <= args.minLabels:
        print("nLabels", nLabels, "reached minLabels", args.minLabels, ".")
        break

# save output image
if not args.visualize:
    output = model(data)[0]
    output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel)
    ignore, target = torch.max(output, 1)
    im_target = target.data.cpu().numpy()
im_target_rgb = np.array([label_colours[c % args.nChannel] for c in im_target])
im_target_rgb = im_target_rgb.reshape(im.shape).astype(np.uint8)
cv2.imwrite("output.png", im_target_rgb)

--------------------------------------------------------------------------------
/demo_ref.py:
--------------------------------------------------------------------------------
#from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import cv2
import sys
import os
import numpy as np
import torch.nn.init
import random
import glob
import datetime
import tqdm

use_cuda = torch.cuda.is_available()

parser = argparse.ArgumentParser(description='PyTorch Unsupervised Segmentation')
parser.add_argument('--nChannel', metavar='N', default=100, type=int,
                    help='number of channels')
parser.add_argument('--maxIter', metavar='T', default=1, type=int,
                    help='number of maximum iterations')
parser.add_argument('--maxUpdate', metavar='T', default=1000, type=int,
                    help='number of maximum update count')
parser.add_argument('--minLabels', metavar='minL', default=3, type=int,
                    help='minimum number of labels')
parser.add_argument('--batch_size', metavar='bsz', default=1, type=int,
                    help='batch size')
parser.add_argument('--lr', metavar='LR', default=0.1, type=float,
                    help='learning rate')
parser.add_argument('--nConv', metavar='M', default=2, type=int,
                    help='number of convolutional layers')
parser.add_argument('--visualize', metavar='1 or 0', default=1, type=int,
                    help='visualization flag')
parser.add_argument('--input', metavar='FOLDERNAME',
                    help='input image folder name', required=True)
parser.add_argument('--stepsize_sim', metavar='SIM', default=1, type=float,
                    help='step size for similarity loss', required=False)
parser.add_argument('--stepsize_con', metavar='CON', default=5, type=float,
                    help='step size for continuity loss')
args = parser.parse_args()
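# Overview: unlike demo.py, this script first trains the network on the
# reference image(s) in <input>/ref/ (each resized to 224x224 and updated
# args.maxUpdate times per outer iteration), saves the weights, and then
# segments the images in <input>/test/ at their original resolution without
# any further parameter updates.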
# CNN model
class MyNet(nn.Module):
    def __init__(self, input_dim):
        super(MyNet, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, args.nChannel, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(args.nChannel)
        self.conv2 = nn.ModuleList()
        self.bn2 = nn.ModuleList()
        for i in range(args.nConv-1):
            self.conv2.append(nn.Conv2d(args.nChannel, args.nChannel, kernel_size=3, stride=1, padding=1))
            self.bn2.append(nn.BatchNorm2d(args.nChannel))
        self.conv3 = nn.Conv2d(args.nChannel, args.nChannel, kernel_size=1, stride=1, padding=0)
        self.bn3 = nn.BatchNorm2d(args.nChannel)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn1(x)
        for i in range(args.nConv-1):
            x = self.conv2[i](x)
            x = F.relu(x)
            x = self.bn2[i](x)
        x = self.conv3(x)
        x = self.bn3(x)
        return x

img_list = sorted(glob.glob(args.input+'/ref/*'))
im = cv2.imread(img_list[0])

# train
model = MyNet(im.shape[2])
if use_cuda:
    model.cuda()
model.train()

# similarity loss definition
loss_fn = torch.nn.CrossEntropyLoss()

# continuity loss definition (size_average is deprecated; use reduction='mean')
loss_hpy = torch.nn.L1Loss(reduction='mean')
loss_hpz = torch.nn.L1Loss(reduction='mean')

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9)
label_colours = np.random.randint(255, size=(args.nChannel, 3))

for batch_idx in range(args.maxIter):
    print('Training started. '+str(datetime.datetime.now())+' '+str(batch_idx+1)+' / '+str(args.maxIter))
    for im_file in range(int(len(img_list)/args.batch_size)):
        for loop in tqdm.tqdm(range(args.maxUpdate)):
            im = []
            for batch_count in range(args.batch_size):
                # load image and resize it to the fixed training resolution
                resized_im = cv2.imread(img_list[args.batch_size*im_file + batch_count])
                resized_im = cv2.resize(resized_im, dsize=(224, 224))
                resized_im = resized_im.transpose((2, 0, 1)).astype('float32')/255.
                im.append(resized_im)

            data = torch.from_numpy(np.array(im))
            if use_cuda:
                data = data.cuda()
            data = Variable(data)

            HPy_target = torch.zeros(data.shape[0], resized_im.shape[1]-1, resized_im.shape[2], args.nChannel)
            HPz_target = torch.zeros(data.shape[0], resized_im.shape[1], resized_im.shape[2]-1, args.nChannel)
            if use_cuda:
                HPy_target = HPy_target.cuda()
                HPz_target = HPz_target.cuda()

            # forwarding
            optimizer.zero_grad()
            output = model(data)
            output = output.permute(0, 2, 3, 1).contiguous().view(data.shape[0], -1, args.nChannel)

            outputHP = output.reshape((data.shape[0], resized_im.shape[1], resized_im.shape[2], args.nChannel))

            HPy = outputHP[:, 1:, :, :] - outputHP[:, 0:-1, :, :]
            HPz = outputHP[:, :, 1:, :] - outputHP[:, :, 0:-1, :]
            lhpy = loss_hpy(HPy, HPy_target)
            lhpz = loss_hpz(HPz, HPz_target)

            output = output.reshape(output.shape[0] * output.shape[1], -1)
            # pseudo-targets: per-pixel argmax of the current responses
            ignore, target = torch.max(output, 1)

            loss = args.stepsize_sim * loss_fn(output, target) + args.stepsize_con * (lhpy + lhpz)
            loss.backward()
            optimizer.step()

torch.save(model.state_dict(), os.path.join(args.input, 'b'+str(args.batch_size)+'_itr'+str(args.maxIter)+'_layer'+str(args.nConv+1)+'.pth'))
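# Note: the checkpoint written above can be reloaded into a freshly built
# MyNet to skip retraining, e.g. model.load_state_dict(torch.load(ckpt_path)),
# where ckpt_path is the .pth file saved by torch.save above.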
label_colours = np.random.randint(255, size=(args.nChannel, 3))
test_img_list = sorted(glob.glob(args.input+'/test/*'))
if not os.path.exists(os.path.join(args.input, 'result/')):
    os.mkdir(os.path.join(args.input, 'result/'))
print('Testing '+str(len(test_img_list))+' images.')
for img_file in tqdm.tqdm(test_img_list):
    im = cv2.imread(img_file)
    data = torch.from_numpy(np.array([im.transpose((2, 0, 1)).astype('float32')/255.]))
    if use_cuda:
        data = data.cuda()
    data = Variable(data)
    output = model(data)[0]
    output = output.permute(1, 2, 0).contiguous().view(-1, args.nChannel)
    ignore, target = torch.max(output, 1)
    inds = target.data.cpu().numpy().reshape((im.shape[0], im.shape[1]))
    inds_rgb = np.array([label_colours[c % args.nChannel] for c in inds])
    inds_rgb = inds_rgb.reshape(im.shape).astype(np.uint8)
    cv2.imwrite(os.path.join(args.input, 'result/') + os.path.basename(img_file), inds_rgb)
--------------------------------------------------------------------------------