├── meancal.py ├── dataloader.py ├── test.py ├── model.py └── main.py

--------------------------------------------------------------------------------
/meancal.py:
--------------------------------------------------------------------------------
from PIL import Image
import numpy as np

# Accumulate the per-channel (RGB) mean over all training images listed in
# train_cls.txt; the result is the mean vector subtracted in dataloader.py.
means = np.zeros(3, dtype=np.float64)

with open('./data/train_cls.txt', 'r') as f:
    lines = f.readlines()

num = 0
for line in lines:
    name = line.split()[0]
    try:
        img = Image.open('data/Pascal VOC dataset/VOCdevkit/VOC2012/JPEGImages/' + name + '.jpg')
        img = np.array(img)
        img_mean = img.mean(axis=(0, 1))
        means += img_mean
        print(img_mean)
    except FileNotFoundError:
        print(name + ' is missing')
        num += 1

means /= (len(lines) - num)
print(num)
print(means)

--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
from os.path import join
from random import shuffle

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image

class JPEGLoader(data.Dataset):
    """PASCAL VOC multi-label dataset. Each line of txt_path is
    '<image_name> <class_idx> <class_idx> ...' over the 20 VOC categories."""

    def __init__(self, txt_path, img_dir, transform=None):
        with open(txt_path, 'r') as f:
            self.info = f.readlines()
        #shuffle(self.info)

        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.info)

    def __getitem__(self, idx):
        data = self.info[idx].split()

        # convert('RGB') guards against the few grayscale JPEGs in VOC.
        img = Image.open(join(self.img_dir, data[0] + '.jpg')).convert('RGB')
        if self.transform:
            img = self.transform(img)

        # Subtract the per-channel RGB means computed by meancal.py, then HWC -> CHW.
        img = np.asarray(img).astype(np.float32) - np.array([116.62341813, 111.51273588, 103.14803339])
        img = torch.Tensor(np.transpose(img, [2, 0, 1]))

        # Multi-hot label vector over the 20 categories.
        labels = torch.zeros(20, dtype=torch.int64)
        for cls_idx in data[1:]:
            labels[int(cls_idx)] = 1

        return img, labels
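
if __name__ == '__main__':
    # Minimal smoke test (a sketch -- the txt file and image directory below
    # mirror the paths used in main.py; point them at your own VOC layout).
    from torchvision import transforms
    tf = transforms.Compose([transforms.Resize(256), transforms.RandomCrop(224)])
    ds = JPEGLoader('data/train_cls.txt', 'data/VOCdataset/VOCdevkit/VOC2012/JPEGImages', transform=tf)
    img, labels = ds[0]
    print(img.shape)      # torch.Size([3, 224, 224])
    print(labels.sum())   # number of categories present in the first image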
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import sys

import numpy as np
import cv2
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab

caffe_root = '../../'
sys.path.insert(0, caffe_root + 'python')

import caffe

# Uncomment the following two lines to test on a GPU:
#caffe.set_mode_gpu()
#caffe.set_device(0)  # choose which GPU you want to use
caffe.SGDSolver.display = 0

# load net
deploy_file = 'deploy_seenet.prototxt'
model_file = 'seenet_final.caffemodel'
net = caffe.Net(deploy_file, model_file, caffe.TEST)

# images for testing: (path, list of ground-truth class indices)
im_lst = [('samples/2007_000039.jpg', [19]),
          ('samples/2007_000063.jpg', [8, 11]),
          ('samples/2007_000738.jpg', [0]),
          ('samples/2007_001185.jpg', [4, 7, 10, 14])]

cats = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
        'dining table', 'dog', 'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

colormaps = ['#000000', '#7F0000', '#007F00', '#7F7F00', '#00007F', '#7F007F', '#007F7F', '#7F7F7F', '#3F0000', '#BF0000', '#3F7F00',
             '#BF7F00', '#3F007F', '#BF007F', '#3F7F7F', '#BF7F7F', '#003F00', '#7F3F00', '#00BF00', '#7FBF00', '#003F7F']

def colormap(index):
    return mpl.colors.LinearSegmentedColormap.from_list('cmap', [colormaps[0], colormaps[index+1], '#FFFFFF'], 256)

def resize(im, size):
    h, w = im.shape[:2]
    im = cv2.resize(im, (size, size), interpolation=cv2.INTER_CUBIC)
    im -= np.array((104.007, 116.669, 122.679))  # BGR means of the Caffe model
    im = im.transpose((2, 0, 1))                 # HWC -> CHW
    return im, h, w

def forward(net, im, label):
    im, height, width = resize(im, test_size)
    net.blobs['data'].reshape(1, *im.shape)
    net.blobs['data'].data[...] = im
    net.blobs['label'].reshape(1, 1, 1, 20)
    net.blobs['label'].data[...] = 0  # clear labels left over from a previous call
    net.blobs['label'].data[0, 0, 0, label] = 1
    net.forward()
    # Fuse the attention maps of the two branches by pixelwise maximum.
    att1 = net.blobs['score_b1'].data[0][label]
    att2 = net.blobs['score_b2'].data[0][label]
    att1 = cv2.resize(att1, (width, height), interpolation=cv2.INTER_CUBIC)
    att2 = cv2.resize(att2, (width, height), interpolation=cv2.INTER_CUBIC)
    att1[att1 < 0] = 0
    att2[att2 < 0] = 0
    att1 = att1 / (np.max(att1) + 1e-8)
    att2 = att2 / (np.max(att2) + 1e-8)
    att = np.maximum(att1, att2)
    return att

# Visualization
def plot_atts(att_lst, label_lst, size):
    pylab.rcParams['figure.figsize'] = size, size/2
    plt.figure()
    for i in range(len(att_lst)):
        s = plt.subplot(1, len(att_lst), i+1)

        if label_lst[i] == 'Source':
            s.set_xlabel(label_lst[i], fontsize=18)
            plt.imshow(att_lst[i])
        else:
            s.set_xlabel(cats[int(label_lst[i])], fontsize=18)
            plt.imshow(att_lst[i], cmap=colormap(int(label_lst[i])))
        s.set_xticklabels([])
        s.set_yticklabels([])
        s.yaxis.set_ticks_position('none')
        s.xaxis.set_ticks_position('none')
    plt.tight_layout()
    plt.savefig('img%d' % img_id)
    plt.close()

# input image
test_size = 256
with_flip = True

img_id = 3  # 0-3
im_name, im_labels = im_lst[img_id]
img = cv2.imread(im_name)
img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)  # for visualizing
att_maps = []
img = img.astype('float')
for label in im_labels:
    att = forward(net, img, label)
    if with_flip:
        img_flip = img[:, ::-1, :]  # horizontal flip
        att_flip = forward(net, img_flip, label)
        att = np.maximum(att, att_flip[:, ::-1])  # un-mirror, then fuse
    att = att * 0.8 + img_gray / 255. * 0.2  # blend attention with the gray image
    att_maps.append(att)

res_lst = [img[:, :, ::-1].astype(np.uint8)] + att_maps  # BGR -> RGB for display
label_lst = ['Source'] + im_labels
plot_atts(res_lst, label_lst, 20)
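
# Worked example of the flip fusion above (toy 2x2 attention maps):
#   att      = [[0.2, 0.9], [0.0, 0.4]]     # from the original image
#   att_flip = [[0.1, 0.3], [0.8, 0.5]]     # from the mirrored image
#   np.maximum(att, att_flip[:, ::-1])  ->  [[0.3, 0.9], [0.5, 0.8]]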
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import scipy.io as sio
import numpy as np

class train_net(nn.Module):
    """VGG16 trunk with three classification branches; branches 2 and 3 see
    feature maps masked by branch 1's attention (see mask_b2/mask_b3 below)."""

    def __init__(self):
        super(train_net, self).__init__()

        self.conv1_1 = nn.Conv2d(3, 64, 3, 1, 1)
        self.relu1_1 = nn.ReLU()
        self.conv1_2 = nn.Conv2d(64, 64, 3, 1, 1)
        self.relu1_2 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(2, 2)

        self.conv2_1 = nn.Conv2d(64, 128, 3, 1, 1)
        self.relu2_1 = nn.ReLU()
        self.conv2_2 = nn.Conv2d(128, 128, 3, 1, 1)
        self.relu2_2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(2, 2)

        self.conv3_1 = nn.Conv2d(128, 256, 3, 1, 1)
        self.relu3_1 = nn.ReLU()
        self.conv3_2 = nn.Conv2d(256, 256, 3, 1, 1)
        self.relu3_2 = nn.ReLU()
        self.conv3_3 = nn.Conv2d(256, 256, 3, 1, 1)
        self.relu3_3 = nn.ReLU()
        self.maxpool3 = nn.MaxPool2d(2, 2)

        self.conv4_1 = nn.Conv2d(256, 512, 3, 1, 1)
        self.relu4_1 = nn.ReLU()
        self.conv4_2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu4_2 = nn.ReLU()
        self.conv4_3 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu4_3 = nn.ReLU()
        self.maxpool4 = nn.MaxPool2d(2, 2)

        self.conv5_1 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu5_1 = nn.ReLU()
        self.conv5_2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu5_2 = nn.ReLU()
        self.conv5_3 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu5_3 = nn.ReLU()

        self.conv1_b1 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu1_b1 = nn.ReLU()
        self.conv2_b1 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu2_b1 = nn.ReLU()
        self.dropout_b1 = nn.Dropout(p=0.7)  # Not sure about this layer
        self.score_b1 = nn.Conv2d(512, 20, 1, 1, 0)
        self.GAP_b1 = nn.AvgPool2d(14, 1)  # Global Average Pooling

        self.conv1_b2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu1_b2 = nn.ReLU()
        self.conv2_b2 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu2_b2 = nn.ReLU()
        self.score_b2 = nn.Conv2d(512, 20, 1, 1, 0)
        self.GAP_b2 = nn.AvgPool2d(14, 1)  # Global Average Pooling

        self.conv1_b3 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu1_b3 = nn.ReLU()
        self.conv2_b3 = nn.Conv2d(512, 512, 3, 1, 1)
        self.relu2_b3 = nn.ReLU()
        self.score_b3 = nn.Conv2d(512, 20, 1, 1, 0)
        self.GAP_b3 = nn.AvgPool2d(14, 1)  # Global Average Pooling

    def forward(self, x, labels):

        x = self.maxpool1(self.relu1_2(self.conv1_2(self.relu1_1(self.conv1_1(x)))))
        x = self.maxpool2(self.relu2_2(self.conv2_2(self.relu2_1(self.conv2_1(x)))))
        x = self.maxpool3(self.relu3_3(self.conv3_3(self.relu3_2(self.conv3_2(self.relu3_1(self.conv3_1(x)))))))
        x = self.maxpool4(self.relu4_3(self.conv4_3(self.relu4_2(self.conv4_2(self.relu4_1(self.conv4_1(x)))))))
        features = self.relu5_3(self.conv5_3(self.relu5_2(self.conv5_2(self.relu5_1(self.conv5_1(x))))))

        score_b1 = self.score_b1(self.dropout_b1(self.relu2_b1(self.conv2_b1(self.relu1_b1(self.conv1_b1(features))))))
        prob1 = torch.sigmoid(self.GAP_b1(score_b1).squeeze())

        features_b2 = mask_b2(features, score_b1, labels)
        score_b2 = self.score_b2(self.relu2_b2(self.conv2_b2(self.relu1_b2(self.conv1_b2(features_b2)))))
        prob2 = torch.sigmoid(self.GAP_b2(score_b2).squeeze())

        features_b3 = mask_b3(features, score_b1, labels)
        score_b3 = self.score_b3(self.relu2_b3(self.conv2_b3(self.relu1_b3(self.conv1_b3(features_b3)))))
        prob3 = torch.sigmoid(self.GAP_b3(score_b3).squeeze())

        return prob1, prob2, prob3, score_b1, score_b2
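
# mask_b2: for each labeled class, min-max normalize branch 1's (ReLU'd)
# attention per sample; erase feature positions whose peak attention exceeds
# `maxt` so that branch 2 must find complementary evidence, and flip the sign
# of confident-background positions (below `mint`).
# mask_b3: keep only the background -- every position attended above `thres`
# is erased; branch 3 is then trained toward an all-zero label in main.py,
# penalizing attention that leaks into the background.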
def mask_b2(features, score, labels, maxt=0.8, mint=0.05):

    mask = score.clone()
    mask[mask < 0] = 0
    features = features.clone()

    for i in range(20):
        if torch.all(labels[:, i] == 0):
            mask[:, i, :, :] = 0
        else:
            # Min-max normalize the attention of class i per labeled sample.
            bs_label = torch.nonzero(labels[:, i]).reshape(-1)
            ma, _ = mask[bs_label, i, :, :].reshape(-1, 14*14).max(dim=1)
            ma = ma.reshape(-1, 1, 1)
            mi, _ = mask[bs_label, i, :, :].reshape(-1, 14*14).min(dim=1)
            mi = mi.reshape(-1, 1, 1)
            tmp = (mask[bs_label, i, :, :] - mi) / (ma - mi + 1e-8)
            mask[:, i, :, :] = 0
            mask[bs_label, i, :, :] = tmp

    # Collapse over classes, then erase highly attended positions from all
    # feature channels and negate confident-background positions.
    mask, _ = mask.max(dim=1)
    pos = torch.nonzero(mask > maxt).transpose(1, 0)
    neg = torch.nonzero(mask < mint).transpose(1, 0)
    features[pos[0], :, pos[1], pos[2]] = 0
    features[neg[0], :, neg[1], neg[2]] = -1 * features[neg[0], :, neg[1], neg[2]]

    return features

def mask_b3(features, score, labels, thres=0.3):

    mask = score.clone()
    mask[mask < 0] = 0
    features = features.clone()

    for i in range(20):
        if torch.all(labels[:, i] == 0):
            mask[:, i, :, :] = 0
        else:
            # Min-max normalize the attention of class i per labeled sample.
            bs_label = torch.nonzero(labels[:, i]).reshape(-1)
            ma, _ = mask[bs_label, i, :, :].reshape(-1, 14*14).max(dim=1)
            ma = ma.reshape(-1, 1, 1)
            mi, _ = mask[bs_label, i, :, :].reshape(-1, 14*14).min(dim=1)
            mi = mi.reshape(-1, 1, 1)
            tmp = (mask[bs_label, i, :, :] - mi) / (ma - mi + 1e-8)
            mask[:, i, :, :] = 0
            mask[bs_label, i, :, :] = tmp

    # Keep only the background: erase every position attended above `thres`.
    mask, _ = mask.max(dim=1)
    pos = torch.nonzero(mask > thres).transpose(1, 0)
    features[pos[0], :, pos[1], pos[2]] = 0

    return features

def load_vgg16pretrain(model, vggmodel='vgg16convs.mat'):
    vgg16 = sio.loadmat(vggmodel)
    torch_params = model.state_dict()
    for k in vgg16.keys():
        # Keys of the form '<layer>-<param>' map onto state_dict names.
        name_par = k.split('-')
        size = len(name_par)
        if size == 2:
            name_space = name_par[0] + '.' + name_par[1]
            data = np.squeeze(vgg16[k])
            torch_params[name_space] = torch.from_numpy(data)
    model.load_state_dict(torch_params)

def weights_init(m):
    if isinstance(m, nn.Conv2d):
        # xavier(m.weight.data)
        m.weight.data.normal_(0, 0.01)
        if m.bias is not None:
            m.bias.data.zero_()
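
if __name__ == '__main__':
    # Minimal shape check (a sketch; random weights, no .mat file needed).
    net = train_net()
    net.apply(weights_init)
    x = torch.randn(2, 3, 224, 224)  # 224 / 2^4 = 14, matching the 14x14 GAP kernels
    labels = torch.zeros(2, 20, dtype=torch.int64)
    labels[:, 11] = 1  # pretend both images contain class 11 ('dog')
    prob1, prob2, prob3, s1, s2 = net(x, labels)
    print(prob1.shape, s1.shape)  # torch.Size([2, 20]) torch.Size([2, 20, 14, 14])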
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import argparse
import os
import pickle
import time
from os.path import isdir, join

import cv2
import numpy as np
import torch
from torch import optim
import torch.nn as nn
from torch.nn import DataParallel
from torch.utils import data
from torchvision import transforms
import matplotlib.pylab as pylab
import matplotlib.pyplot as plt
import matplotlib as mpl
from vltools import Logger

from dataloader import JPEGLoader
from model import train_net, load_vgg16pretrain, weights_init

parser = argparse.ArgumentParser(description='PyTorch Implementation of SeeNet.')
parser.add_argument('--bs', type=int, help='batch size', default=16)
# optimizer parameters
parser.add_argument('--lr', type=float, help='base learning rate', default=5e-3)
parser.add_argument('--momentum', type=float, help='momentum', default=0.9)
parser.add_argument('--stepsize', type=float, help='step size (epochs)', default=8)
parser.add_argument('--gamma', type=float, help='gamma', default=0.1)
parser.add_argument('--wd', type=float, help='weight decay', default=2e-4)
parser.add_argument('--maxepoch', type=int, help='max epoch', default=40)
# general parameters
parser.add_argument('--print_freq', type=int, help='print frequency', default=10)
parser.add_argument('--save_freq', type=int, help='save frequency', default=100)
parser.add_argument('--cuda', type=str, help='cuda', default='3')
parser.add_argument('--checkpoint', type=str, help='checkpoint prefix', default=None)
parser.add_argument('--imgsize', type=int, help='image size fed into network', default=224)
# datasets
parser.add_argument('--tmp', type=str, default='tmp', help='root of saving images')
args = parser.parse_args()

def save_obj(obj, name):
    with open(name + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


def load_obj(name):
    with open(name + '.pkl', 'rb') as f:
        return pickle.load(f)

class AverageMeter(object):
    """Computes the current value and a sliding-window average over the
    last `maxlen` updates."""
    def __init__(self, bs, maxlen=100):
        self.reset()
        self.maxlen = maxlen
        self.bs = bs

    def reset(self):
        self.memory = []
        self.avg = 0
        self.val = 0
        self.count = 0

    def update(self, val):
        if self.count >= self.maxlen:
            self.memory.pop(0)
            self.count -= 1
        self.memory.append(val)
        self.val = val
        self.sum = sum(self.memory)
        self.count += 1
        self.avg = self.sum / self.count
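
# Toy run of the sliding window (hypothetical values, maxlen=3):
#   m = AverageMeter(bs=1, maxlen=3)
#   for v in [1.0, 2.0, 3.0, 4.0]: m.update(v)
#   m.val == 4.0, m.avg == 3.0  (mean of the last three values 2.0, 3.0, 4.0)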
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda

cats = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow',
        'dining table', 'dog', 'horse', 'motorbike', 'person', 'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']

colormaps = ['#000000', '#7F0000', '#007F00', '#7F7F00', '#00007F', '#7F007F', '#007F7F', '#7F7F7F', '#3F0000', '#BF0000', '#3F7F00',
             '#BF7F00', '#3F007F', '#BF007F', '#3F7F7F', '#BF7F7F', '#003F00', '#7F3F00', '#00BF00', '#7FBF00', '#003F7F']

TMP_DIR = "/media/data1/SeeNet_result/" + args.tmp
if not isdir(TMP_DIR):
    os.mkdir(TMP_DIR)
if not isdir(join(TMP_DIR, 'checkpoint')):
    os.mkdir(join(TMP_DIR, 'checkpoint'))
log = Logger(TMP_DIR + '/log.txt')

transform = transforms.Compose([transforms.Resize(256),
                                transforms.RandomCrop(args.imgsize)])

training_dataset = JPEGLoader('data/train_cls.txt', 'data/VOCdataset/VOCdevkit/VOC2012/JPEGImages', transform=transform)
training_dataloader = data.DataLoader(training_dataset, batch_size=args.bs,
                                      shuffle=True, num_workers=8, pin_memory=True)

model = train_net()
model.apply(weights_init)
load_vgg16pretrain(model)
model = DataParallel(model).cuda()

weight = []
bias = []

for name, p in model.named_parameters():
    if 'weight' in name:
        weight.append(p)
    else:
        bias.append(p)

# Caffe-style parameter groups: weights take the weight decay, biases get
# twice the base learning rate and no decay.
optimizer = optim.SGD([{"params": weight, "lr": args.lr, "weight_decay": args.wd},
                       {"params": bias, "lr": 2*args.lr, "weight_decay": 0}],
                      momentum=args.momentum)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.stepsize, gamma=args.gamma)
criterion = nn.BCELoss()

losses = AverageMeter(args.bs)
times = AverageMeter(1)
record = {'avg': [], 'val': []}
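
# The model applies torch.sigmoid internally, so plain BCELoss is used above
# rather than BCEWithLogitsLoss. train_epoch() below combines three BCE terms:
# branches 1 and 2 are trained against the image-level labels, while branch 3
# (which only sees background features) is trained toward an all-zero target.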
def plot_atts(att_lst, label_lst, size, savedir):
    pylab.rcParams['figure.figsize'] = size, size/2
    plt.figure()
    for i in range(len(att_lst)):
        s = plt.subplot(1, len(att_lst), i+1)

        if label_lst[i] == 'Source':
            s.set_xlabel(label_lst[i], fontsize=18)
            plt.imshow(att_lst[i])
        else:
            s.set_xlabel(cats[int(label_lst[i])], fontsize=18)
            plt.imshow(att_lst[i], cmap=colormap(int(label_lst[i])))
        s.set_xticklabels([])
        s.set_yticklabels([])
        s.yaxis.set_ticks_position('none')
        s.xaxis.set_ticks_position('none')
    plt.tight_layout()
    plt.savefig(savedir)
    plt.close()

def colormap(index):
    return mpl.colors.LinearSegmentedColormap.from_list('cmap', [colormaps[0], colormaps[index+1], '#FFFFFF'], 256)

def train_epoch(epoch):
    if not isdir(join(TMP_DIR, 'epoch%d' % epoch)):
        os.mkdir(join(TMP_DIR, 'epoch%d' % epoch))
    model.train()

    for batch_idx, (img, labels) in enumerate(training_dataloader):
        start_time = time.time()
        bs = img.size(0)
        img, labels = img.cuda(), labels.cuda()
        prob1, prob2, prob3, score_b1, score_b2 = model(img, labels)
        background = torch.zeros(bs, 20, dtype=torch.float).cuda()

        #print(prob1[0, :]);print(prob2[0, :]);print(prob3[0, :])

        loss1 = criterion(prob1, labels.float())
        loss2 = criterion(prob2, labels.float())
        loss3 = criterion(prob3, background)

        loss = loss1 + loss2 + loss3

        losses.update(loss.item())  # store the Python float, not the graph-holding tensor
        record['avg'].append(losses.avg)
        record['val'].append(losses.val)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        times.update(time.time() - start_time)

        if batch_idx % args.print_freq == 0:
            log.info("Tr|Ep %03d Bt %03d/%03d: sec/bt: %.2fsec, loss=%.3f (avg=%.3f)" \
                     % (epoch, batch_idx, len(training_dataloader), times.val, losses.val, losses.avg))

        if batch_idx % args.save_freq == 0:
            # Visualize the attention maps of one random image from the batch.
            rn = np.random.choice(bs, 1, replace=False)
            for i in rn:
                i_image = img[i, :, :, :].cpu().detach().numpy()
                i_image = (np.transpose(i_image, [1, 2, 0]) + np.array([116.62341813, 111.51273588, 103.14803339])).astype('uint8')
                image_gray = cv2.cvtColor(i_image, cv2.COLOR_BGR2GRAY)
                i_label = torch.nonzero(labels[i, :]).reshape(-1)
                att_maps = []

                for j in i_label:
                    att1, att2 = score_b1[i, j, :, :].cpu().detach().numpy(), score_b2[i, j, :, :].cpu().detach().numpy()
                    att1 = cv2.resize(att1, (args.imgsize, args.imgsize), interpolation=cv2.INTER_CUBIC)
                    att2 = cv2.resize(att2, (args.imgsize, args.imgsize), interpolation=cv2.INTER_CUBIC)
                    att1[att1 < 0] = 0
                    att2[att2 < 0] = 0
                    att1 = att1 / (np.max(att1) + 1e-8)
                    att2 = att2 / (np.max(att2) + 1e-8)
                    att = np.maximum(att1, att2)
                    #att = att * 0.8 + image_gray / 255. * 0.2
                    att_maps.append(att)

                res_lst = [i_image[:, :, ::-1],] + att_maps
                label_lst = ['Source',] + i_label.cpu().numpy().tolist()
                plot_atts(res_lst, label_lst, 16, join(TMP_DIR, 'epoch%d' % epoch, 'iter%d.jpg' % batch_idx))

    torch.save(model.state_dict(), TMP_DIR + '/checkpoint/epoch%d' % epoch)
    log.info('checkpoint has been created!')


def main():
    for epoch in range(args.maxepoch):
        scheduler.step()  # adjust the learning rate once per epoch
        train_epoch(epoch)

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))

    axes[0].plot(record['avg'])
    axes[0].legend(['Loss_avg'], loc="upper right")
    axes[0].grid(alpha=0.5, linestyle='dotted', linewidth=2, color='black')
    axes[0].set_xlabel("Iter")
    axes[0].set_ylabel("Loss_avg")

    axes[1].plot(record['val'])
    axes[1].grid(alpha=0.5, linestyle='dotted', linewidth=2, color='black')
    axes[1].legend(["Loss_val"], loc="upper right")
    axes[1].set_xlabel("Iter")
    axes[1].set_ylabel("Loss_val")

    plt.tight_layout()
    plt.savefig(TMP_DIR + '/record.pdf')
    plt.close(fig)


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
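
A sketch of restoring a saved checkpoint for later use (the path follows the
default TMP_DIR layout above and the epoch number is hypothetical; checkpoints
hold a DataParallel state_dict, hence the wrapper):

    import torch
    from torch.nn import DataParallel
    from model import train_net

    model = DataParallel(train_net()).cuda()
    model.load_state_dict(torch.load('/media/data1/SeeNet_result/tmp/checkpoint/epoch39'))
    model.eval()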