')
47 | print('')
48 | print(''.format(i))
49 | print('  '.format(image_ids[i], w,h))
50 | print(' ')
51 | print(' | ')
52 | print('')
53 | print(''.format(i))
54 | print('  '.format(image_ids[i], w,h))
55 | print(' ')
56 | print(' | ')
57 | print('')
58 | print(''.format(i))
59 | print('  '.format(image_ids[i], w,h))
60 | print(' ')
61 | print(' | ')
62 | print('
')
63 | print('')
64 |
--------------------------------------------------------------------------------
/script/gen_html_dssdd.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | modelid='default'
4 |
5 | print('')
6 | print('')
7 | print('')
8 | print('')
92 | w=200
93 | h=200
94 | e=0
95 | print('image | seed mask(sssdd) | integrated mask(second step) |
')
96 | for i in range(shown):
97 | print('')
98 | print('')
99 | print(''.format(i))
100 | print('  '.format(modelid, modelid, e,i, w,h))
101 | print(' ')
102 | print(' | ')
103 | print('')
104 | print(''.format(i))
105 | print('  '.format(modelid, modelid, e,i, w,h))
106 | print(' ')
107 | print(' | ')
108 | print('')
109 | print(''.format(i))
110 | print('  '.format(modelid, modelid, e,i, w,h))
111 | print(' ')
112 | print(' | ')
113 | print('
')
114 | print('')
115 | print('
first step |
')
116 | print('K1 | d_k | d_a | A1 | | integrated_mask1 |
')
117 | print('')
118 | print('')
119 | print('')
120 | print(''.format(i))
121 | print('  '.format(modelid, modelid, e,i, w,h))
122 | print(' ')
123 | print(' | ')
124 | print('')
125 | print(''.format(i))
126 | print('  '.format(modelid, modelid, e,i, w,h))
127 | print(' ')
128 | print(' | ')
129 | print('')
130 | print(''.format(i))
131 | print('  '.format(modelid, modelid, e,i, w,h))
132 | print(' ')
133 | print(' | ')
134 | print('')
135 | print(''.format(i))
136 | print('  '.format(modelid, modelid, e,i, w,h))
137 | print(' ')
138 | print(' | ')
139 | print(' ')
140 | print(' | ')
141 | print('')
142 | print(''.format(i))
143 | print('  '.format(modelid, modelid, e,i, w,h))
144 | print(' ')
145 | print(' | ')
146 | print('
')
147 | print('second step |
')
148 | print('K2 | d_k | d_a | A2 | | integrated_mask2 |
')
149 | print('')
150 | print('')
151 | print(''.format(i))
152 | print('  '.format(modelid, modelid, e,i, w,h))
153 | print(' ')
154 | print(' | ')
155 | print('')
156 | print(''.format(i))
157 | print('  '.format(modelid, modelid, e,i, w,h))
158 | print(' ')
159 | print(' | ')
160 | print('')
161 | print(''.format(i))
162 | print('  '.format(modelid, modelid, e,i, w,h))
163 | print(' ')
164 | print(' | ')
165 | print('')
166 | print(''.format(i))
167 | print('  '.format(modelid, modelid, e,i, w,h))
168 | print(' ')
169 | print(' | ')
170 | print('')
171 | print(' | ')
172 | print('')
173 | print(''.format(i))
174 | print('  '.format(modelid, modelid, e,i, w,h))
175 | print(' ')
176 | print(' | ')
177 | print('
')
178 |
179 | print('')
180 |
--------------------------------------------------------------------------------
/script/gen_html_val.py:
--------------------------------------------------------------------------------
1 | import os
2 | print('')
3 | print('')
4 | print('')
5 | print('')
6 | print('title')
7 | print('')
8 | print('')
9 | print('')
10 | print('')
39 | btn_txt=['<<<','<<','<','>','>>','>>>']
40 | for i in range(6):
41 | print(''.format(i))
42 | print('')
43 | print('
')
44 | print('')
45 | w=200
46 | h=200
47 | print('Image | Inference | Ground truth |
')
48 | for i in range(shown):
49 | print('')
50 | print('')
51 | print(''.format(i))
52 | print('  '.format(image_ids[i], w,h))
53 | print(' ')
54 | print(' | ')
55 | print('')
56 | print(''.format(i))
57 | print('  ')
59 | print(' | ')
60 | print('')
61 | print(''.format(i))
62 | print('  '.format(image_ids[i], w,h))
63 | print(' ')
64 | print(' | ')
65 | print('
')
66 | print('')
67 |
--------------------------------------------------------------------------------
/script/val.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | title
6 |
7 |
8 |
9 |
1645 |
1646 |
1647 |
1648 |
1649 |
1650 |
1651 |
1652 |
1653 |
1654 |
1655 |
1656 |
1657 |
1658 |
1659 |
1660 |
1661 |
1662 |
1663 |
1664 | Image | Inference | Ground truth |
1665 |
1666 |
1667 |
1668 | 
1669 |
1670 | |
1671 |
1672 |
1673 | 
1675 | |
1676 |
1677 |
1678 | 
1679 |
1680 | |
1681 |
1682 |
1683 |
1684 |
1685 | 
1686 |
1687 | |
1688 |
1689 |
1690 | 
1692 | |
1693 |
1694 |
1695 | 
1696 |
1697 | |
1698 |
1699 |
1700 |
1701 |
1702 | 
1703 |
1704 | |
1705 |
1706 |
1707 | 
1709 | |
1710 |
1711 |
1712 | 
1713 |
1714 | |
1715 |
1716 |
1717 |
1718 |
1719 | 
1720 |
1721 | |
1722 |
1723 |
1724 | 
1726 | |
1727 |
1728 |
1729 | 
1730 |
1731 | |
1732 |
1733 |
1734 |
1735 |
1736 | 
1737 |
1738 | |
1739 |
1740 |
1741 | 
1743 | |
1744 |
1745 |
1746 | 
1747 |
1748 | |
1749 |
1750 |
1751 |
--------------------------------------------------------------------------------
/ssdd_function.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | sigmoid = torch.nn.Sigmoid()
def compute_sig_mask_loss(logits, bin_mask):
    """Class-balanced binary cross-entropy between difference logits and a
    binary target mask.

    Positive (mask==1) and negative (mask==0) pixels are averaged
    separately and the per-class means are averaged again, so the loss is
    not dominated by the majority class.

    Args:
        logits: raw (pre-sigmoid) predictions of shape (N, 1, H, W).
        bin_mask: 0/1 target of shape (N, H, W).

    Returns:
        Scalar tensor: mean over the class terms that are actually present.
    """
    bin_mask = bin_mask.float()
    probs = torch.sigmoid(logits).squeeze(1)
    losses = []
    for value in (0, 1):
        loc = bin_mask == value
        # Guard against an absent class: BCE over an empty selection
        # yields NaN, which would poison the whole training loss.
        if loc.any():
            losses.append(F.binary_cross_entropy(probs[loc], bin_mask[loc]))
    if not losses:
        # Degenerate input: keep a grad-connected zero instead of NaN.
        return logits.sum() * 0.0
    return sum(losses) / len(losses)
18 |
def add_class_weights(pixel_weights, mask0, mask1, ignore_flags, bg_bias=0.00):
    """Bias the per-pixel selection weights in place, sample by sample.

    Background pixels of ``mask0`` lower the weight by ``bg_bias`` while
    background pixels of ``mask1`` raise it; for every class j >= 1 the
    weight is pushed the same way with strength ``ignore_flags[i, j]``.

    Returns the (mutated) ``pixel_weights`` tensor for convenience.
    """
    num_classes = ignore_flags.shape[1]
    for i in range(len(mask0)):
        weight = pixel_weights[i]
        # In-place ops on `weight` mutate pixel_weights[i] directly.
        weight -= (mask0[i] == 0).float() * bg_bias
        weight += (mask1[i] == 0).float() * bg_bias
        for cls in range(1, num_classes):
            strength = ignore_flags[i, cls] * 1.0
            weight -= (mask0[i] == cls).float() * strength
            weight += (mask1[i] == cls).float() * strength
    return pixel_weights
def get_dd_mask(dd0, dd1, mask0, mask1, ignore_flags, dd_bias=0.15, bg_bias=0.05):
    """Fuse two candidate masks using their difference-detection scores.

    Per pixel, ``mask1`` wins where the biased score
    sigmoid(dd0) - sigmoid(dd1) + dd_bias (after class re-weighting) is
    non-negative, and ``mask0`` wins elsewhere.

    Returns the tuple (dd0, dd1, ignore_flags, refine_mask).
    """
    score = sigmoid(dd0) - sigmoid(dd1) + dd_bias
    score = add_class_weights(score, mask0, mask1, ignore_flags, bg_bias=bg_bias)
    # Start from 255 (the conventional "ignore" label); every pixel is
    # overwritten by one of the two assignments below.
    refine_mask = Variable(torch.zeros_like(mask0)) + 255
    take1 = score.squeeze(1) >= 0
    take0 = take1 == 0
    refine_mask[take1] = mask1[take1]
    refine_mask[take0] = mask0[take0]
    return (dd0, dd1, ignore_flags, refine_mask)
def get_dd(dd, dd_head, mask):
    """Run difference-detection network ``dd`` on head features plus the
    binarized ``mask``.

    The binary mask is detached so no gradient flows back through the mask
    branch.
    """
    bin_mask = get_binarymask(mask)
    return dd((dd_head, bin_mask.detach()))
43 |
def get_ignore_flags(mask0, mask1, mlabel, th=0.5):
    """Mark classes whose region shrinks too much from ``mask0`` to ``mask1``.

    For every image-level label that is present (``mlabel[i][j] == 1``),
    the pixel count of class j+1 in ``mask1`` is compared with ``mask0``;
    if the survival rate drops below ``th`` the class is flagged.

    Args:
        mask0, mask1: integer mask tensors, indexable per sample.
        mlabel: per-sample multi-label vectors (no background entry).
        th: minimum area survival rate.

    Returns:
        numpy array of shape (len(mask0), 21) with 0/1 flags; index 0
        (background) is never set here.
    """
    flags = np.zeros((len(mask0), 21,))
    for i in range(len(mlabel)):
        for j in range(len(mlabel[0])):
            if mlabel[i][j] != 1:
                continue
            cls = j + 1  # shift past the background channel
            area0 = torch.sum(mask0[i] == cls).item()
            area1 = torch.sum(mask1[i] == cls).item()
            if area1 / max(area0, 1) < th:
                flags[i, cls] = 1
    return flags
55 |
def get_binarymask(masks, chn=21):
    """Expand an integer label mask into per-class binary channel masks.

    Args:
        masks: integer mask tensor of shape (N, H, W) with values in [0, chn).
        chn: number of classes/channels to emit.

    Returns:
        Float tensor of shape (N, chn, H, W) where channel c is 1 at pixels
        labelled c and 0 elsewhere.
    """
    N, H, W = masks.shape
    # Allocate on the input's device (previously hard-coded to .cuda(),
    # which broke CPU use; all visible callers pass GPU tensors, so this
    # is backward compatible).
    bin_masks = torch.zeros(N, chn, H, W, device=masks.device)
    for c in range(chn):
        # One vectorized scatter per class replaces the original per-sample
        # loop and its redundant `locn.sum()` re-reduction of a 0-d tensor.
        bin_masks[:, c][masks == c] = 1
    return bin_masks
69 |
def get_ddloss(dd, diff_mask, ignore_flags):
    """Average the difference-detection loss over usable samples.

    Samples with any flagged foreground class (``ignore_flags[k, 1:]``)
    are skipped entirely; the remaining per-sample losses are averaged.

    Returns:
        Scalar loss tensor (zero if every sample was skipped).
    """
    total = Variable(torch.FloatTensor([0]), requires_grad=True).cuda()
    used = 0
    for k in range(len(dd)):
        # Skip samples where a foreground class was flagged as ignored.
        if torch.sum(ignore_flags[k, 1:]).item() > 0:
            continue
        # Out-of-place accumulation: `+=` on a tensor that requires grad
        # is an in-place op and raises in modern PyTorch.
        total = total + compute_sig_mask_loss(dd[k:k + 1], diff_mask[k:k + 1])
        used += 1
    if used > 0:
        total = total / used
    return total
81 |
--------------------------------------------------------------------------------
/ssdd_val.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import math
3 | import os
4 | import random
5 | import re
6 | import cv2
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | import torch.utils.data
13 | from torch.autograd import Variable
14 | import imutils
15 | import utils
16 | from torchvision import transforms
17 | from torch.utils.data.dataloader import default_collate
18 | import time
19 | from PIL import Image
20 | from base_class import BaseModel, SegBaseModel, SSDDBaseModel, PascalDataset
21 | from network import SegmentationPsa, PredictDiff, PredictDiffHead
22 | import math
23 |
24 | ############################################################
25 | # dataset
26 | ############################################################
27 |
class SSDDValData(PascalDataset):
    """Validation dataset: yields a normalized CHW image tensor plus its
    dataset index (no labels are loaded at validation time)."""

    def __init__(self, dataset, config):
        super().__init__(dataset, config)
        # Geometric transform: only rescale to the network input shape.
        self.joint_transform_list = [
            imutils.Rescale(self.config.INP_SHAPE),
            None,
            None,
        ]
        # Photometric pipeline: PIL -> array -> normalize -> CHW.
        self.img_transform_list = [
            np.asarray,
            imutils.Normalize(mean=self.mean, std=self.std),
            imutils.HWC_to_CHW,
        ]

    def __getitem__(self, image_index):
        image_id = self.image_ids[image_index]
        image_path = self.config.VOC_ROOT + '/JPEGImages/' + image_id + '.jpg'
        pil_img = Image.open(image_path).convert("RGB")
        arr = self.img_label_resize([pil_img])[0]
        return torch.from_numpy(arr), image_index

    def __len__(self):
        return self.image_ids.shape[0]
53 |
54 |
55 | ############################################################
56 | # Model Class
57 | ############################################################
58 |
class SegModel(SegBaseModel):
    """Backbone encoder plus a single PSA segmentation head (inference)."""

    def __init__(self, config):
        super(SegModel, self).__init__(config)
        self.seg_main = SegmentationPsa(config, in_channel=4096,
                                        middle_channel=512, num_classes=21)

    def forward(self, inputs):
        # Only the deepest encoder feature map feeds the segmentation head;
        # the auxiliary head output is discarded at inference.
        [_, _, _, _, x5] = self.encoder(inputs)
        seg, _ = self.seg_main(x5)
        return seg
70 |
class Evaluator():
    """Runs multi-scale, flip-averaged segmentation inference over a
    validation set and writes raw and CRF-refined masks to disk."""

    def __init__(self, config, model):
        super(Evaluator, self).__init__()
        self.config = config
        self.model=model

    def eval_model(self, val_dataset):
        """Build the validation loader, switch to eval mode, and evaluate."""
        self.val_set = SSDDValData(val_dataset, self.config)
        val_generator = torch.utils.data.DataLoader(self.val_set, batch_size=self.config.BATCH, shuffle=False, num_workers=torch.cuda.device_count()*2, pin_memory=True)
        self.model.eval()
        self.eval(val_generator)

    def get_segmentation(self, img):
        """Multi-scale + horizontal-flip averaged class score map.

        Returns the (unnormalized) sum of softmax outputs over all scales
        and their mirrored counterparts; same spatial size as ``img``.
        """
        segs = self.get_ms_segout(img)
        # Horizontal flip by reversing the width dimension.
        fimg = img[:,:,:,torch.arange(img.shape[3]-1,-1,-1)]
        fsegs = self.get_ms_segout(fimg)
        seg_all = torch.zeros(1,segs[0].shape[1],segs[0].shape[2],segs[0].shape[3])
        for i in range(len(segs)):
            seg_all += segs[i]
        for i in range(len(segs)):
            # Flip the mirrored predictions back before accumulating.
            seg_all += fsegs[i][:,:,:,torch.arange(fsegs[i].shape[3]-1,-1,-1)]
        return seg_all

    def get_ms_segout(self, img):
        """Run the model at 5 scales; return per-scale softmax maps resized
        back to the input resolution, moved to CPU."""
        scales = [1/2, 3/4, 1, 5/4, 3/2]
        segs = []
        for i in range(len(scales)):
            scale=scales[i]
            simg = F.interpolate(img, (int(img.shape[2]*scale),int(img.shape[3]*scale)), mode='bilinear')
            seg = self.model(simg)
            seg = F.softmax(seg,dim=1)
            seg = F.interpolate(seg, (int(img.shape[2]),int(img.shape[3])), mode='bilinear')
            seg = seg.data.cpu()
            segs.append(seg)
            # Free GPU memory between scales.
            torch.cuda.empty_cache()
        return segs

    def eval(self, datagenerator):
        """Per image: run inference, CRF-refine, save PNG and TXT masks.

        Output files are named with a running 1-based counter ``cnt``, not
        the image id.
        """
        end = time.time()
        cnt=0
        for inputs in datagenerator:
            print(cnt)
            data_time = time.time()  # assigned but unused
            start=time.time()
            images, imgindex = inputs
            images = Variable(images).cuda()
            segs=[]  # assigned but unused
            with torch.no_grad():
                for i in range(len(images)):
                    # segmentation
                    seg=self.get_segmentation(images[i:i+1])
                    # crf: reload the original-resolution image from disk
                    image_id = self.val_set.image_ids[imgindex[i]]
                    impath=self.config.VOC_ROOT+'/JPEGImages/'
                    imn=impath+image_id+'.jpg'
                    img_org = np.asarray(Image.open(imn))
                    # Upsample scores to the original resolution before softmax.
                    seg=F.interpolate(seg,(img_org.shape[0],img_org.shape[1]),mode='bilinear')
                    prob=F.softmax(seg,dim=1)[0].data.cpu().numpy()
                    seg_mask = np.argmax(prob,0)
                    seg_crf_map = imutils.crf_inference(img_org, prob, labels=prob.shape[0], t=10)
                    seg_crf_mask = np.argmax(seg_crf_map,axis=0)
                    # save results
                    cnt+=1
                    saven = os.path.join(self.savedir, 'seg_val_'+self.saveid+'_'+str(cnt)+'.png')
                    utils.mask2png(saven, seg_mask)
                    saven = os.path.join(self.savedir, 'seg_val_'+self.saveid+'_'+str(cnt)+'.txt')
                    np.savetxt(saven, seg_mask)
                    saven = os.path.join(self.savedir, 'seg_val_crf_'+self.saveid+'_'+str(cnt)+'.png')
                    utils.mask2png(saven, seg_crf_mask)
                    saven = os.path.join(self.savedir, 'seg_val_crf_'+self.saveid+'_'+str(cnt)+'.txt')
                    np.savetxt(saven, seg_crf_mask)



    def set_log_dir(self, phase, saveid, model_path=None):
        """Set the output directory (validation/<saveid>) and create it."""
        self.phase = phase
        self.saveid = saveid
        self.savedir = 'validation/'+self.saveid
        print("save the results to "+self.savedir)
        if not os.path.exists(self.savedir):
            os.makedirs(self.savedir)
152 |
def val(config, weight_file=None):
    """Construct the validation segmentation model.

    NOTE(review): ``weight_file`` is unused here — weights are presumably
    loaded by the caller; confirm before relying on this entry point.
    """
    model = SegModel(config=config)
    return model
156 |
--------------------------------------------------------------------------------
/train_dssdd.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import re
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | import torch.utils.data
12 | from torch.autograd import Variable
13 | from torchvision import transforms
14 | import imutils
15 | import utils
16 | from base_class import BaseModel, SegBaseModel, SSDDBaseModel, PascalDataset
17 | import ssdd_function as ssddF
18 | import time
19 | from PIL import Image
20 | from network import SegmentationPsa, PredictDiff, PredictDiffHead
21 | import math
22 | import cv2
23 | cv2.setNumThreads(0)
24 |
25 | ############################################################
26 | # dataset
27 | ############################################################
28 |
class DssddData(PascalDataset):
    """Training dataset for the dual-step SSDD stage.

    Each item bundles the normalized image, the raw image (for CRF), the
    image-level class labels (with and without a background entry), and a
    seed mask obtained by fusing the PSA and PSA+CRF masks with
    precomputed difference-detection maps.
    """
    def __init__(self, dataset, config):
        super().__init__(dataset, config)
        self.label_dic = dataset.label_dic
        # Augmentations: horizontal flip, scale jitter in [512, 768], 448 crop.
        self.joint_transform_list=[
            None,
            imutils.RandomHorizontalFlip(),
            imutils.RandomResizeLong(512, 768),
            imutils.RandomCrop(448),
            None,
        ]
    def __getitem__(self, image_index):
        image_id = self.image_ids[image_index]
        impath = self.config.VOC_ROOT+'/JPEGImages/'
        imn = impath+image_id+'.jpg'
        img = Image.open(imn).convert("RGB")
        # Image-level multi-labels, without and with the background entry.
        gt_class_mlabel = torch.from_numpy(self.label_dic[image_id])
        gt_class_mlabel_bg = torch.from_numpy(np.concatenate(([1],self.label_dic[image_id])))
        # PSA affinity outputs (dict of per-class score maps) -> HWC array.
        psan = 'prepare_labels/results/out_aff/'+image_id+'.npy'
        psa=np.array(list(np.load(psan).item().values())).transpose(1,2,0)
        psan = 'prepare_labels/results/out_aff_crf/'+image_id+'.npy'
        psa_crf=np.load(psan).transpose(1,2,0)
        h=psa.shape[0]
        w=psa.shape[1]
        # Precomputed difference-detection maps, resized to the PSA size.
        saven = 'precompute/'+self.config.modelid+'/da_precompute_'+self.config.modelid+'_'+str(image_index)+'.npy'
        dd0=np.load(saven).transpose(1,2,0)
        dd0=np.reshape(cv2.resize(dd0,(w,h)),(h,w,1))
        saven = 'precompute/'+self.config.modelid+'/dk_precompute_'+self.config.modelid+'_'+str(image_index)+'.npy'
        dd1=np.load(saven).transpose(1,2,0)
        dd1=np.reshape(cv2.resize(dd1,(w,h)),(h,w,1))
        # resize inputs
        # NOTE(review): the jointly-transformed dp0/dp1 are never used —
        # dd0/dd1 below are resized from the *pre-transform* arrays, so
        # they are not flipped/cropped with the image. Confirm intent.
        img_norm, img_org, psa, psa_crf, dp0, dp1 = self.img_label_resize([img, np.array(img), psa, psa_crf, dd0, dd1])
        img_org = cv2.resize(img_org,self.config.OUT_SHAPE)
        dd0 = cv2.resize(dd0,self.config.OUT_SHAPE)
        dd1 = cv2.resize(dd1,self.config.OUT_SHAPE)
        psa = cv2.resize(psa,self.config.OUT_SHAPE)
        psa_crf = cv2.resize(psa_crf,self.config.OUT_SHAPE)
        # Restrict score maps to present labels, then take argmax masks.
        psa=self.get_prob_label(psa, gt_class_mlabel_bg).transpose(2,0,1)
        psa_crf=self.get_prob_label(psa_crf, gt_class_mlabel_bg).transpose(2,0,1)
        psa_mask = np.argmax(psa,0)
        psa_crf_mask = np.argmax(psa_crf,0)
        dd0 = torch.from_numpy(dd0).unsqueeze(0)
        dd1 = torch.from_numpy(dd1).unsqueeze(0)
        psa_mask = torch.from_numpy(psa_mask).unsqueeze(0)
        psa_crf_mask = torch.from_numpy(psa_crf_mask).unsqueeze(0)
        ignore_flags=torch.from_numpy(ssddF.get_ignore_flags(psa_mask, psa_crf_mask, [gt_class_mlabel])).float()
        # integration using sssdd module
        # the parameters are different from dssdd module
        (_, _, _, seed_mask) = ssddF.get_dd_mask(dd0, dd1, psa_mask, psa_crf_mask, ignore_flags, dd_bias=0.1, bg_bias=0.1)
        return img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg, seed_mask[0]
    def __len__(self):
        return self.image_ids.shape[0]
81 |
82 | ############################################################
83 | # Models
84 | ############################################################
class SegModel(SegBaseModel):
    """Two-headed segmentation network (main + sub PSA heads) with frozen
    BatchNorm parameters."""

    def __init__(self, config):
        super(SegModel, self).__init__(config)
        self.config = config
        head_kwargs = dict(num_classes=21, in_channel=4096,
                           middle_channel=512, scale=2)
        self.seg_main = SegmentationPsa(config, **head_kwargs)
        self.seg_sub = SegmentationPsa(config, **head_kwargs)

        def freeze_batchnorm(module):
            # Freeze every BatchNorm layer's parameters.
            if module.__class__.__name__.find('BatchNorm') != -1:
                for p in module.parameters():
                    p.requires_grad = False

        self.apply(freeze_batchnorm)

    def forward(self, inputs):
        x, img_org, gt_class_mlabel = inputs
        feats = self.encoder(x)
        x5 = feats[4]
        seg_outs_main = self.get_seg(self.seg_main, x5, gt_class_mlabel)
        seg_outs_sub = self.get_seg(self.seg_sub, x5, gt_class_mlabel)
        # CRF refinement is driven by the main head's raw output.
        _, seg_crf_mask = self.get_crf(img_org, seg_outs_main[0], gt_class_mlabel)
        return seg_outs_main, seg_outs_sub, seg_crf_mask, feats
105 |
class SSDDModel(SSDDBaseModel):
    """Two-step difference-detection module for the dssdd stage.

    Step 1 compares the main segmentation mask with its CRF refinement;
    step 2 compares the dataset seed mask with step 1's refined mask.
    """
    def __init__(self, config):
        super(SSDDModel, self).__init__(config)
        # dd_head*: fuse segmentation-head features with low-level features.
        self.dd_head0 = PredictDiffHead(config, in_channel=512, in_channel2=128)
        self.dd_head1 = PredictDiffHead(config, in_channel=512, in_channel2=128)
        self.dd0 = PredictDiff(config, in_channel=256, in_channel2=128)
        self.dd1 = PredictDiff(config, in_channel=256, in_channel2=128)
    def forward(self, inputs):
        (seg_outs_main, seg_outs_sub, seg_crf_mask, feats), seed_mask, gt_class_mlabel = inputs
        [x1,x2,x3,x4,x5] = feats
        # Pool low-level features to the head's spatial resolution.
        x1=F.avg_pool2d(x1, 2, 2)
        # first step
        seg_main, seg_prob_main, seg_mask_main, seg_head_main = seg_outs_main
        ignore_flags0=torch.from_numpy(ssddF.get_ignore_flags(seg_mask_main, seg_crf_mask, gt_class_mlabel)).cuda().float()
        # Inputs are detached: gradients must not reach the seg network here.
        dd_head0 = self.dd_head0((seg_head_main.detach(), x1.detach()))
        dd00 = ssddF.get_dd(self.dd0, dd_head0, seg_mask_main)
        dd01 = ssddF.get_dd(self.dd0, dd_head0, seg_crf_mask)
        dd_outs0 = ssddF.get_dd_mask(dd00, dd01, seg_mask_main, seg_crf_mask, ignore_flags0, dd_bias=0.4, bg_bias=0)
        # NOTE(review): this unpack rebinds dd01/dd10 to dd_outs0's first two
        # elements (which are dd00, dd01 per get_dd_mask); dd10 is overwritten
        # just below, so only refine_mask0 matters here — confirm the naming.
        (dd01, dd10, ignore_flags0, refine_mask0)=dd_outs0
        # second step
        seg_sub, seg_prob_sub, seg_mask_sub, seg_head_sub = seg_outs_sub
        dd_head1 = self.dd_head1((seg_head_sub.detach(), x1.detach()))
        dd10 = ssddF.get_dd(self.dd1, dd_head1, seed_mask)
        dd11 = ssddF.get_dd(self.dd1, dd_head1, refine_mask0)
        ignore_flags1 = torch.from_numpy(ssddF.get_ignore_flags(seed_mask, refine_mask0, gt_class_mlabel)).cuda().float()
        dd_outs1 = ssddF.get_dd_mask(dd10, dd11, seed_mask, refine_mask0, ignore_flags1, dd_bias=0.4, bg_bias=0)
        return dd_outs0, dd_outs1
133 |
134 | ############################################################
135 | # Trainer
136 | ############################################################
class Trainer():
    """Joint trainer for the segmentation network and the SSDD module.

    Parameters are split into three learning-rate groups by name regex
    (encoder at 1x, seg heads at 10x, DD branches at 10x on a separate
    optimizer); BatchNorm parameters (name containing 'bn') are excluded.
    """
    def __init__(self, config, model_dir, model):
        super(Trainer, self).__init__()
        self.config = config
        self.model_dir = model_dir
        self.epoch = 0
        # Regexes that partition parameters into learning-rate groups.
        self.layer_regex = {
            "lr1": r"(encoder.*)",
            "lr10": r"(seg_main.*)|(seg_sub.*)",
            "dd": r"(dd0.*)|(dd1.*)|(dd_head0.*)|(dd_head1.*)",
        }
        lr_1x = self.layer_regex["lr1"]
        lr_10x = self.layer_regex["lr10"]
        dd = self.layer_regex['dd']
        seg_model=model[0].cuda()
        ssdd_model=model[1].cuda()
        self.param_lr_1x = [param for name, param in seg_model.named_parameters() if bool(re.fullmatch(lr_1x, name)) and not 'bn' in name]
        self.param_lr_10x = [param for name, param in seg_model.named_parameters() if bool(re.fullmatch(lr_10x, name)) and not 'bn' in name]
        self.param_dd = [param for name, param in ssdd_model.named_parameters() if bool(re.fullmatch(dd, name)) and not 'bn' in name]
        lr=1e-3
        self.seg_model=nn.DataParallel(seg_model)
        self.ssdd_model=nn.DataParallel(ssdd_model)
    def train_model(self, train_dataset):
        """Top-level loop: build the loader, run epochs, checkpoint."""
        epochs=self.config.EPOCHS
        # Data generators
        self.train_set = DssddData(train_dataset, self.config)
        train_generator = torch.utils.data.DataLoader(self.train_set, batch_size=self.config.BATCH, shuffle=True, num_workers=8, pin_memory=True)
        self.config.LR_RAMPDOWN_EPOCHS=int(epochs*1.2)
        self.seg_model.train()
        self.ssdd_model.train()
        for epoch in range(0, epochs):
            print("Epoch {}/{}.".format(epoch,epochs))
            # Training
            self.train_epoch(train_generator, epoch)
            # Save model every second epoch (skipping epoch 0).
            if (epoch % 2 ==0) & (epoch>0):
                torch.save(self.seg_model.state_dict(), self.checkpoint_path_seg.format(epoch))
                torch.save(self.ssdd_model.state_dict(), self.checkpoint_path_ssdd.format(epoch))
            torch.cuda.empty_cache()
    def train_epoch(self, datagenerator, epoch):
        """Run one pass over the loader, stepping once per batch."""
        learning_rate=self.config.LEARNING_RATE  # assigned but unused here
        self.cnt=0
        self.steps = len(datagenerator)
        self.step=0
        self.epoch=epoch
        end=time.time()
        for inputs in datagenerator:
            self.train_step(inputs, end)
            end=time.time()
            self.step += 1
    def train_step(self, inputs, end):
        """One optimization step: forward both models, then step the seg
        and the DD optimizers independently."""
        start = time.time()
        # adjust learning rate
        lr=utils.adjust_learning_rate(self.config.LEARNING_RATE, self.epoch, self.config.LR_RAMPDOWN_EPOCHS, self.step, self.steps)
        # NOTE(review): the optimizers are re-created every step to apply
        # the new lr, so SGD momentum buffers reset each step — confirm
        # this is intended.
        self.optimizer = torch.optim.SGD([
            {'params': self.param_lr_1x,'lr': lr*1, 'weight_decay': self.config.WEIGHT_DECAY},
            {'params': self.param_lr_10x,'lr': lr*10, 'weight_decay': self.config.WEIGHT_DECAY},
        ], lr=lr, momentum=self.config.LEARNING_MOMENTUM, weight_decay= self.config.WEIGHT_DECAY)
        self.optimizer_dd = torch.optim.SGD([
            {'params': self.param_dd,'lr': lr*10, 'weight_decay': self.config.WEIGHT_DECAY},
        ], lr=lr, momentum=self.config.LEARNING_MOMENTUM, weight_decay= self.config.WEIGHT_DECAY)
        # input items
        img_norm, img_org, gt_class_mlabels, gt_class_mlabels_bg, seed_mask = inputs
        img_norm = Variable(img_norm).cuda().float()
        img_org = Variable(img_org).cuda().float()
        seed_mask = Variable(seed_mask).cuda().long()
        gt_class_mlabels = Variable(gt_class_mlabels).cuda().float()
        gt_class_mlabels_bg = Variable(gt_class_mlabels_bg).cuda().float()
        # forward
        seg_outs = self.seg_model((img_norm, img_org, gt_class_mlabels_bg))
        dd_outs = self.ssdd_model((seg_outs, seed_mask, gt_class_mlabels))
        # get loss
        loss_seg, loss_dd = self.compute_loss(seg_outs, dd_outs, inputs)
        forward_time=time.time()
        # backward: the two losses drive disjoint parameter groups.
        self.optimizer.zero_grad()
        loss_seg.backward()
        self.optimizer.step()
        forward_time=time.time()
        self.optimizer_dd.zero_grad()
        loss_dd.backward()
        self.optimizer_dd.step()
        forward_time=time.time()
        if (self.step%10==0):
            prefix="{}/{}/{}/{}".format(self.epoch, self.cnt, self.step + 1, self.steps)
            suffix="forward_time: {:.3f} data {:.3f} loss: {:.3f}".format(
                forward_time-start, (start-end),loss_seg.item())
            print('%s %s' % (prefix, suffix), end = '\n')

    def compute_loss(self, seg_outs, dd_outs, inputs):
        """Assemble the segmentation and difference-detection losses; every
        30th step dumps intermediate masks/score maps to the image log dir."""
        seg_outs_main, seg_outs_sub, seg_crf_mask, feats = seg_outs
        seg_main, seg_prob_main, seg_mask_main, _ = seg_outs_main
        seg_sub, seg_prob_sub, seg_mask_sub, _ = seg_outs_sub
        dd_outs0, dd_outs1 = dd_outs
        images, img_org, gt_class_mlabels, gt_class_mlabels_bg, seed_mask = inputs
        seed_mask = Variable(seed_mask).cuda().long()
        (dd00, dd01, ignore_flags0, refine_mask0) = dd_outs0
        (dd10, dd11, ignore_flags1, refine_mask1) = dd_outs1
        # compute losses
        # segmentation loss: main head trains on the step-2 refined mask,
        # sub head on a 50/50 mix of seed and refined masks.
        loss_seg_main = F.cross_entropy(seg_main, refine_mask1, ignore_index=255)
        loss_seg_sub = 0.5*F.cross_entropy(seg_sub, seed_mask, ignore_index=255) + 0.5*F.cross_entropy(seg_sub, refine_mask1, ignore_index=255)
        loss_seg = loss_seg_main + loss_seg_sub
        # difference detection loss
        seg_crf_diff = seg_mask_main != seg_crf_mask
        loss_dd00 = ssddF.get_ddloss(dd00, seg_crf_diff, ignore_flags0)
        loss_dd01 = ssddF.get_ddloss(dd01, seg_crf_diff, ignore_flags0)
        loss_dd10 = ssddF.compute_sig_mask_loss(dd10, seed_mask != seg_mask_sub)
        loss_dd11 = ssddF.compute_sig_mask_loss(dd11, refine_mask1 != seg_mask_sub)
        loss_dd = (loss_dd00 + loss_dd01 + loss_dd10 + loss_dd11)/4
        # save temporary outputs
        if (self.step%30==0):
            sid='_'+self.phase+'_'+self.saveid+'_'+str(self.epoch)+'_'+str(self.cnt)
            # RGB -> BGR for cv2.imwrite.
            img_org=img_org.data.cpu().numpy()[...,::-1]
            saven = self.log_dir_img + '/i'+sid+'.jpg'
            cv2.imwrite(saven,img_org[0])
            saven = self.log_dir_img + '/D1'+sid+'.png'
            mask_png = utils.mask2png(saven, refine_mask0[0].squeeze().data.cpu().numpy())
            saven = self.log_dir_img + '/K1'+sid+'.png'
            mask_png = utils.mask2png(saven, seg_mask_main[0].data.cpu().numpy().astype(np.float32))
            saven = self.log_dir_img + '/A1'+sid+'.png'
            mask_png = utils.mask2png(saven, seg_crf_mask[0].squeeze().data.cpu().numpy())

            saven = self.log_dir_img + '/dk1'+sid+'.png'
            tmp=F.sigmoid(dd00)[0].squeeze().data.cpu().numpy()
            cv2.imwrite(saven,tmp*255)
            saven = self.log_dir_img + '/da1'+sid+'.png'
            tmp=F.sigmoid(dd01)[0].squeeze().data.cpu().numpy()
            cv2.imwrite(saven,tmp*255)

            saven = self.log_dir_img + '/D2'+sid+'.png'
            mask_png = utils.mask2png(saven, refine_mask1[0].squeeze().data.cpu().numpy())
            saven = self.log_dir_img + '/K2'+sid+'.png'
            mask_png = utils.mask2png(saven, seed_mask[0].data.cpu().numpy().astype(np.float32))
            #saven = self.log_dir_img + '/A2'+sid+'.png'
            #mask_png = utils.mask2png(saven, refine_mask[0].squeeze().data.cpu().numpy())

            saven = self.log_dir_img + '/dk2'+sid+'.png'
            tmp=F.sigmoid(dd10)[0].squeeze().data.cpu().numpy()
            cv2.imwrite(saven,tmp*255)
            saven = self.log_dir_img + '/da2'+sid+'.png'
            tmp=F.sigmoid(dd11)[0].squeeze().data.cpu().numpy()
            cv2.imwrite(saven,tmp*255)
        self.cnt += 1
        return loss_seg, loss_dd

    def set_log_dir(self, phase, saveid, model_path=None):
        """Prepare log/checkpoint directories and checkpoint name templates."""
        self.epoch = 0
        self.phase = phase
        self.saveid = saveid
        self.log_dir = os.path.join(self.model_dir, "{}_{}".format(phase, saveid))
        self.log_dir_model = self.log_dir +'/'+ 'models'
        if not os.path.exists(self.log_dir_model):
            os.makedirs(self.log_dir_model)
        self.log_dir_img = self.log_dir +'/'+ 'imgs'
        if not os.path.exists(self.log_dir_img):
            os.makedirs(self.log_dir_img)
        # "*epoch*" placeholder becomes "{:04d}" for later .format(epoch).
        self.checkpoint_path_seg = os.path.join(self.log_dir_model, "seg_*epoch*.pth".format())
        self.checkpoint_path_seg = self.checkpoint_path_seg.replace("*epoch*", "{:04d}")
        self.checkpoint_path_ssdd = os.path.join(self.log_dir_model, "ssdd_*epoch*.pth".format())
        self.checkpoint_path_ssdd = self.checkpoint_path_ssdd.replace("*epoch*", "{:04d}")
298 |
def models(config, weight_file=None):
    """Build the (segmentation, SSDD) model pair.

    The segmentation encoder is initialized from ResNet-38 weights in
    ``weight_file``; the SSDD module starts from scratch.
    """
    segmentation = SegModel(config=config)
    segmentation.initialize_weights()
    segmentation.load_resnet38_weights(weight_file)
    diff_detector = SSDDModel(config=config)
    diff_detector.initialize_weights()
    return (segmentation, diff_detector)
306 |
--------------------------------------------------------------------------------
/train_sssdd.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import random
4 | import re
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.optim as optim
11 | import torch.utils.data
12 | from torch.autograd import Variable
13 | from torchvision import transforms
14 | import imutils
15 | import utils
16 | from base_class import BaseModel, SegBaseModel, SSDDBaseModel, PascalDataset
17 | import ssdd_function as ssddF
18 | import time
19 | from PIL import Image
20 | from network import SegmentationPsa, PredictDiff, PredictDiffHead
21 | import math
22 | import cv2
23 | cv2.setNumThreads(0)
24 |
25 | ############################################################
26 | # dataset
27 | ############################################################
28 |
class SssddData(PascalDataset):
    """Training dataset for the single-step SSDD stage.

    Returns the normalized image, raw image (for CRF), image-level class
    labels (with/without background), and the PSA / PSA+CRF argmax masks.
    """
    def __init__(self, dataset, config):
        super().__init__(dataset, config)
        self.label_dic = dataset.label_dic
        # Augmentations: horizontal flip, scale jitter in [448, 512], 448 crop.
        self.joint_transform_list=[
            None,
            imutils.RandomHorizontalFlip(),
            imutils.RandomResizeLong(448, 512),
            imutils.RandomCrop(448),
            None,
        ]
    def __getitem__(self, image_index):
        image_id = self.image_ids[image_index]
        impath = self.config.VOC_ROOT+'/JPEGImages/'
        imn = impath+image_id+'.jpg'
        img = Image.open(imn).convert("RGB")
        # Image-level multi-labels, without and with the background entry.
        gt_class_mlabel = torch.from_numpy(self.label_dic[image_id])
        gt_class_mlabel_bg = torch.from_numpy(np.concatenate(([1],self.label_dic[image_id])))
        # PSA affinity outputs (dict of per-class score maps) -> HWC array.
        psan = 'prepare_labels/results/out_aff/'+image_id+'.npy'
        psa=np.array(list(np.load(psan).item().values())).transpose(1,2,0)
        psan = 'prepare_labels/results/out_aff_crf/'+image_id+'.npy'
        psa_crf=np.load(psan).transpose(1,2,0)
        h=psa.shape[0]  # h/w are computed but unused in this stage
        w=psa.shape[1]
        # resize inputs
        img_norm, img_org, psa, psa_crf = self.img_label_resize([img, np.array(img), psa, psa_crf])
        img_org = cv2.resize(img_org,self.config.OUT_SHAPE)
        psa = cv2.resize(psa,self.config.OUT_SHAPE)
        psa_crf = cv2.resize(psa_crf,self.config.OUT_SHAPE)
        # Restrict score maps to present labels, then take argmax masks.
        psa=self.get_prob_label(psa, gt_class_mlabel_bg).transpose(2,0,1)
        psa_crf=self.get_prob_label(psa_crf, gt_class_mlabel_bg).transpose(2,0,1)
        psa_mask = np.argmax(psa,0)
        psa_crf_mask = np.argmax(psa_crf,0)
        return img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg, psa_mask, psa_crf_mask
    def __len__(self):
        return self.image_ids.shape[0]
65 |
66 | ############################################################
67 | # Models
68 | ############################################################
class SegModel(SegBaseModel):
    """Single-head PSA segmentation network with frozen BatchNorm."""

    def __init__(self, config):
        super(SegModel, self).__init__(config)
        self.config = config
        self.seg_main = SegmentationPsa(config, num_classes=21, in_channel=4096,
                                        middle_channel=512, scale=2)

        def freeze_batchnorm(module):
            # Freeze every BatchNorm layer's parameters.
            if module.__class__.__name__.find('BatchNorm') != -1:
                for p in module.parameters():
                    p.requires_grad = False

        self.apply(freeze_batchnorm)

    def forward(self, inputs):
        x, img_org, gt_class_mlabel = inputs
        feats = self.encoder(x)
        # Only the deepest feature map drives the segmentation head.
        seg_outs_main = self.get_seg(self.seg_main, feats[4], gt_class_mlabel)
        return seg_outs_main, feats
86 |
class SSDDModel(SSDDBaseModel):
    """Single-step difference-detection module for the sssdd stage."""

    def __init__(self, config):
        super(SSDDModel, self).__init__(config)
        self.dd_head0 = PredictDiffHead(config, in_channel=512, in_channel2=128)
        self.dd0 = PredictDiff(config, in_channel=256, in_channel2=128)

    def forward(self, inputs):
        (seg_outs_main, feats), psa_mask, psa_crf_mask, gt_class_mlabel = inputs
        # Low-level features, pooled to the head's spatial resolution.
        low_feat = F.avg_pool2d(feats[0], 2, 2)
        # first step
        _, _, _, seg_head_main = seg_outs_main
        ignore_flags0 = torch.from_numpy(
            ssddF.get_ignore_flags(psa_mask, psa_crf_mask, gt_class_mlabel)
        ).cuda().float()
        # Detached inputs: gradients must not reach the seg network here.
        dd_head0 = self.dd_head0((seg_head_main.detach(), low_feat.detach()))
        dd00 = ssddF.get_dd(self.dd0, dd_head0, psa_mask)
        dd01 = ssddF.get_dd(self.dd0, dd_head0, psa_crf_mask)
        return ssddF.get_dd_mask(dd00, dd01, psa_mask, psa_crf_mask,
                                 ignore_flags0, dd_bias=0.4, bg_bias=0)
104 |
105 | ############################################################
106 | # Trainer
107 | ############################################################
class Trainer():
    """First-step SSDD trainer.

    Jointly optimizes the segmentation network (against the PSA seed mask)
    and the difference-detection network (against the disagreement between
    the raw and CRF-refined seeds) with two separate SGD optimizers.
    Checkpoints are written every second epoch and debug images every 30
    steps.
    """

    def __init__(self, config, model_dir, model):
        """`model` is the (seg_model, ssdd_model) pair from `models()`."""
        super(Trainer, self).__init__()
        self.config = config
        self.model_dir = model_dir
        self.epoch = 0
        # Parameter groups by name: encoder at 1x lr, segmentation head at
        # 10x, difference-detection layers in their own optimizer (10x).
        self.layer_regex = {
            "lr1": r"(encoder.*)",
            "lr10": r"(seg_main.*)",
            "dd": r"(dd0.*)|(dd_head0.*)",
        }
        lr_1x = self.layer_regex["lr1"]
        lr_10x = self.layer_regex["lr10"]
        dd = self.layer_regex['dd']
        seg_model = model[0].cuda()
        ssdd_model = model[1].cuda()
        # BatchNorm parameters ('bn' in the name) are excluded: the models
        # freeze them via set_bn_fix.
        self.param_lr_1x = [param for name, param in seg_model.named_parameters() if bool(re.fullmatch(lr_1x, name)) and not 'bn' in name]
        self.param_lr_10x = [param for name, param in seg_model.named_parameters() if bool(re.fullmatch(lr_10x, name)) and not 'bn' in name]
        self.param_dd = [param for name, param in ssdd_model.named_parameters() if bool(re.fullmatch(dd, name)) and not 'bn' in name]
        # (removed a dead `lr = 1e-3` local; the real lr is computed per step)
        self.seg_model = nn.DataParallel(seg_model)
        self.ssdd_model = nn.DataParallel(ssdd_model)

    def train_model(self, train_dataset):
        """Train for config.EPOCHS epochs over `train_dataset`."""
        epochs = self.config.EPOCHS
        # Data generators
        self.train_set = SssddData(train_dataset, self.config)
        train_generator = torch.utils.data.DataLoader(self.train_set, batch_size=self.config.BATCH, shuffle=True, num_workers=8, pin_memory=True)
        # Rampdown runs 20% past the last epoch so the lr never reaches zero.
        self.config.LR_RAMPDOWN_EPOCHS = int(epochs * 1.2)
        self.seg_model.train()
        self.ssdd_model.train()
        for epoch in range(0, epochs):
            print("Epoch {}/{}.".format(epoch, epochs))
            # Training
            self.train_epoch(train_generator, epoch)
            # Save both models every second epoch, skipping epoch 0
            # (was `&`; `and` is the boolean operator intended here).
            if (epoch % 2 == 0) and (epoch > 0):
                torch.save(self.seg_model.state_dict(), self.checkpoint_path_seg.format(epoch))
                torch.save(self.ssdd_model.state_dict(), self.checkpoint_path_ssdd.format(epoch))
        torch.cuda.empty_cache()

    def train_epoch(self, datagenerator, epoch):
        """Run one pass over the loader, tracking step/epoch counters."""
        # (removed an unused `learning_rate` local)
        self.cnt = 0
        self.steps = len(datagenerator)
        self.step = 0
        self.epoch = epoch
        end = time.time()
        for inputs in datagenerator:
            self.train_step(inputs, end)
            end = time.time()
            self.step += 1

    def train_step(self, inputs, end):
        """One optimization step.

        `end` is the previous step's end timestamp, used to report data
        loading time in the progress line.
        """
        start = time.time()
        # adjust learning rate (cosine rampdown over epochs and steps)
        lr = utils.adjust_learning_rate(self.config.LEARNING_RATE, self.epoch, self.config.LR_RAMPDOWN_EPOCHS, self.step, self.steps)
        # NOTE(review): rebuilding SGD every step discards momentum state;
        # this looks intentional as a cheap per-step lr schedule -- confirm.
        self.optimizer = torch.optim.SGD([
            {'params': self.param_lr_1x, 'lr': lr * 1, 'weight_decay': self.config.WEIGHT_DECAY},
            {'params': self.param_lr_10x, 'lr': lr * 10, 'weight_decay': self.config.WEIGHT_DECAY},
        ], lr=lr, momentum=self.config.LEARNING_MOMENTUM, weight_decay=self.config.WEIGHT_DECAY)
        self.optimizer_dd = torch.optim.SGD([
            {'params': self.param_dd, 'lr': lr * 10, 'weight_decay': self.config.WEIGHT_DECAY},
        ], lr=lr, momentum=self.config.LEARNING_MOMENTUM, weight_decay=self.config.WEIGHT_DECAY)
        # input items
        img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg, psa_mask, psa_crf_mask = inputs
        img_norm = Variable(img_norm).cuda().float()
        img_org = Variable(img_org).cuda().float()
        gt_class_mlabel = Variable(gt_class_mlabel).cuda().float()
        gt_class_mlabel_bg = Variable(gt_class_mlabel_bg).cuda().float()
        # forward
        seg_outs = self.seg_model((img_norm, img_org, gt_class_mlabel_bg))
        dd_outs = self.ssdd_model((seg_outs, psa_mask, psa_crf_mask, gt_class_mlabel))
        # get loss
        loss_seg, loss_dd = self.compute_loss(seg_outs, dd_outs, inputs)
        # backward: segmentation first, then difference detection
        self.optimizer.zero_grad()
        loss_seg.backward()
        self.optimizer.step()
        self.optimizer_dd.zero_grad()
        loss_dd.backward()
        self.optimizer_dd.step()
        # (two earlier dead reassignments of forward_time removed; the
        # reported value was always measured after both backward passes)
        forward_time = time.time()
        if (self.step % 10 == 0):
            prefix = "{}/{}/{}/{}".format(self.epoch, self.cnt, self.step + 1, self.steps)
            suffix = "forward_time: {:.3f} time: {:.3f} data {:.3f} seg: {:.3f}".format(
                forward_time - start, (time.time() - start), (start - end), loss_seg.item())
            print('\r%s %s' % (prefix, suffix), end='\n')

    def compute_loss(self, seg_outs, dd_outs, inputs):
        """Return (segmentation loss, difference-detection loss).

        Also dumps debug images every 30 steps: input image (i), refined
        mask (D), seed mask (K), CRF seed mask (A) and the two sigmoid DD
        maps (da, dk).
        """
        seg_outs_main, feats = seg_outs
        seg_main, seg_prob_main, seg_mask_main, _ = seg_outs_main
        dd_outs0 = dd_outs
        img_norm, img_org, gt_class_mlabel, gt_class_mlabel_bg, psa_mask, psa_crf_mask = inputs
        (dd00, dd01, ignore_flags0, refine_mask0) = dd_outs0
        psa_mask = Variable(psa_mask).cuda().long()
        psa_crf_mask = Variable(psa_crf_mask).cuda().long()
        # compute losses
        # segmentation loss against the seed mask (255 = ignore label)
        loss_seg_main = F.cross_entropy(seg_main, psa_mask, ignore_index=255)
        loss_seg = loss_seg_main
        # difference detection loss: target is where the two seeds disagree
        psa_diff = psa_mask != psa_crf_mask
        loss_dd00 = ssddF.get_ddloss(dd00, psa_diff, ignore_flags0)
        loss_dd01 = ssddF.get_ddloss(dd01, psa_diff, ignore_flags0)
        loss_dd = (loss_dd00 + loss_dd01) / 2
        # save temporary outputs
        if (self.step % 30 == 0):
            sid = '_' + self.phase + '_' + self.saveid + '_' + str(self.epoch) + '_' + str(self.cnt)
            img_org = img_org.data.cpu().numpy()[..., ::-1]  # RGB -> BGR for cv2
            saven = self.log_dir_img + '/i' + sid + '.jpg'
            cv2.imwrite(saven, img_org[0])
            saven = self.log_dir_img + '/D' + sid + '.png'
            mask_png = utils.mask2png(saven, refine_mask0[0].squeeze().data.cpu().numpy())
            saven = self.log_dir_img + '/K' + sid + '.png'
            mask_png = utils.mask2png(saven, psa_mask[0].data.cpu().numpy().astype(np.float32))
            saven = self.log_dir_img + '/A' + sid + '.png'
            mask_png = utils.mask2png(saven, psa_crf_mask[0].squeeze().data.cpu().numpy())

            # BUG FIX: these two paths were missing the '/' separator, so
            # the files landed next to (not inside) the imgs directory.
            # (F.sigmoid is deprecated; torch.sigmoid is the same function.)
            saven = self.log_dir_img + '/da' + sid + '.png'
            tmp = torch.sigmoid(dd00)[0].squeeze().data.cpu().numpy()
            cv2.imwrite(saven, tmp * 255)
            saven = self.log_dir_img + '/dk' + sid + '.png'
            tmp = torch.sigmoid(dd01)[0].squeeze().data.cpu().numpy()
            cv2.imwrite(saven, tmp * 255)
            self.cnt += 1
        return loss_seg, loss_dd

    def set_log_dir(self, phase, saveid):
        """Create log/model/image directories and the checkpoint path
        templates (the '{:04d}' slot is later filled with the epoch)."""
        self.epoch = 0
        self.phase = phase
        self.saveid = saveid
        self.log_dir = os.path.join(self.model_dir, "{}_{}".format(phase, saveid))
        self.log_dir_model = self.log_dir + '/' + 'models'
        if not os.path.exists(self.log_dir_model):
            os.makedirs(self.log_dir_model)
        self.log_dir_img = self.log_dir + '/' + 'imgs'
        if not os.path.exists(self.log_dir_img):
            os.makedirs(self.log_dir_img)
        # (dropped a no-op ''.format() call on these literals)
        self.checkpoint_path_seg = os.path.join(self.log_dir_model, "seg_*epoch*.pth")
        self.checkpoint_path_seg = self.checkpoint_path_seg.replace("*epoch*", "{:04d}")
        self.checkpoint_path_ssdd = os.path.join(self.log_dir_model, "ssdd_*epoch*.pth")
        self.checkpoint_path_ssdd = self.checkpoint_path_ssdd.replace("*epoch*", "{:04d}")
250 |
def models(config, weight_file=None):
    """Build and initialize the (segmentation, difference-detection) pair.

    `weight_file` is forwarded to the segmentation model's ResNet-38
    weight loader.
    """
    segmentation = SegModel(config=config)
    segmentation.initialize_weights()
    segmentation.load_resnet38_weights(weight_file)
    diff_detector = SSDDModel(config=config)
    diff_detector.initialize_weights()
    return (segmentation, diff_detector)
258 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn.functional as F
4 | from torch.autograd import Variable
5 | from PIL import Image
def adjust_learning_rate(lr, epoch, lr_rampdown_epochs, step_in_epoch, total_steps_in_epoch):
    """Cosine learning-rate rampdown (SGDR, https://arxiv.org/abs/1608.03983).

    The fractional epoch position (epoch plus step fraction) is mapped
    onto a half cosine that scales `lr` from 1.0 at position 0 down to
    0.0 at `lr_rampdown_epochs`.
    """
    position = epoch + step_in_epoch / total_steps_in_epoch
    assert 0 <= position <= lr_rampdown_epochs
    scale = float(0.5 * (np.cos(np.pi * position / lr_rampdown_epochs) + 1))
    return lr * scale
14 |
def get_labeled_tensor(tensor, class_label):
    """Gather the per-class slices of `tensor` for classes marked present.

    The original implementation referenced undefined names (class_mlabel,
    gt_class_mlabel, j), shadowed its loop variable and never returned;
    this is the reconstruction of the evident intent.

    Args:
        tensor: tensor of shape (batch, num_classes, ...).
        class_label: multi-hot tensor of shape (batch, num_classes); entry
            (i, j) == 1 marks class j as present in sample i.

    Returns:
        Concatenation along dim 0 of the slices tensor[i:i+1, j:j+1] for
        every (i, j) with class_label[i, j] == 1 (slicing keeps both
        leading dims of size one).

    Raises:
        RuntimeError: if no label is set (torch.cat of an empty list).
    """
    selected = []
    for i in range(len(tensor)):
        for j in range(class_label.shape[1]):
            if class_label[i, j].item() == 1:
                selected.append(tensor[i:i + 1, j:j + 1])
    return torch.cat(selected)
22 |
def mask2png(saven, mask):
    """Write `mask` (a 2-D array of class indices) to `saven` as a
    palettized PNG using the 256-entry VOC-style palette."""
    indexed = Image.fromarray(mask.astype(np.uint8))
    indexed.putpalette(get_palette(256))
    indexed.save(saven)
28 |
def get_palette(num_cls):
    """Return a flat [R0, G0, B0, R1, G1, B1, ...] palette of `num_cls`
    colors, generated by the standard PASCAL VOC bit-spreading scheme:
    successive 3-bit groups of the label index are spread across the
    high bits of the three channels."""
    palette = []
    for label in range(num_cls):
        r = g = b = 0
        value, shift = label, 7
        while value:
            r |= (value & 1) << shift
            g |= ((value >> 1) & 1) << shift
            b |= ((value >> 2) & 1) << shift
            value >>= 3
            shift -= 1
        palette.extend((r, g, b))
    return palette
45 |
--------------------------------------------------------------------------------